tfrs.layers.dcn.Cross
Cross Layer in Deep & Cross Network to learn explicit feature interactions.
tfrs.layers.dcn.Cross(
    projection_dim: Optional[int] = None,
    diag_scale: Optional[float] = 0.0,
    use_bias: bool = True,
    preactivation: Optional[Union[str, tf.keras.layers.Activation]] = None,
    kernel_initializer: Union[Text, tf.keras.initializers.Initializer] = 'truncated_normal',
    bias_initializer: Union[Text, tf.keras.initializers.Initializer] = 'zeros',
    kernel_regularizer: Union[Text, None, tf.keras.regularizers.Regularizer] = None,
    bias_regularizer: Union[Text, None, tf.keras.regularizers.Regularizer] = None,
    **kwargs
)
A layer that creates explicit and bounded-degree feature interactions
efficiently. The call method accepts two input tensors. The first input, x0,
is the base layer that contains the original features (usually the embedding
layer); the second input, xi, is the output of the previous Cross layer in
the stack, i.e., the i-th Cross layer. For the first Cross layer in the
stack, x0 = xi.

The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi,
where .* designates elementwise multiplication, W is either a full-rank
matrix or a low-rank matrix U*V (to reduce the computational cost), and
diag_scale increases the diagonal of W to improve training stability
(especially for the low-rank case).
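As a rough illustration of the formula above, the following plain-Python sketch computes one cross step for a single example (no batching, full-rank W). The helper names matvec and cross_step and the numeric values are hypothetical and are not part of TensorFlow Recommenders. The last lines also check that adding diag_scale * xi gives the same result as folding diag_scale into the diagonal of W.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def cross_step(x0, xi, W, bias, diag_scale=0.0):
    """One step of x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi."""
    wx = matvec(W, xi)
    inner = [w + b + diag_scale * v for w, b, v in zip(wx, bias, xi)]
    return [a * s + v for a, s, v in zip(x0, inner, xi)]


x0 = [1.0, 2.0]                    # original features
W = [[0.5, 0.0], [1.0, -1.0]]      # illustrative full-rank kernel
bias = [0.0, 1.0]

x1 = cross_step(x0, x0, W, bias)   # first layer: xi = x0
x2 = cross_step(x0, x1, W, bias)   # second layer: xi = x1
print(x1)  # [1.5, 2.0]
print(x2)  # [2.25, 3.0]

# diag_scale * xi is equivalent to using the kernel W + diag_scale * I.
with_diag = cross_step(x0, x0, W, bias, diag_scale=0.25)
W_shifted = [[W[r][c] + (0.25 if r == c else 0.0) for c in range(2)]
             for r in range(2)]
folded = cross_step(x0, x0, W_shifted, bias)
print(with_diag == folded)  # True
```

Note that the output keeps the same length as the input at every step, which is what allows Cross layers to be stacked.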
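Example
# After an embedding layer in a functional model:
import tensorflow as tf
import tensorflow_recommenders as tfrs

inputs = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
# The Embedding layer must be called on the input to produce the tensor x0.
x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)(inputs)
x1 = tfrs.layers.dcn.Cross()(x0, x0)
x2 = tfrs.layers.dcn.Cross()(x0, x1)
logits = tf.keras.layers.Dense(units=10)(x2)
model = tf.keras.Model(inputs, logits)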
Args
projection_dim: Projection dimension used to reduce the computational cost.
  Default is None, in which case a full (input_dim by input_dim) matrix W
  is used. If set, a low-rank matrix W = U*V is used, where U is of size
  input_dim by projection_dim and V is of size projection_dim by
  input_dim. projection_dim needs to be smaller than input_dim / 2 to
  improve the model efficiency. In practice, we've observed that
  projection_dim = input_dim / 4 consistently preserved the accuracy of a
  full-rank version.
diag_scale: A non-negative float used to increase the diagonal of the
  kernel W by diag_scale, that is, W + diag_scale * I, where I is an
  identity matrix.
use_bias: Whether to add a bias term for this layer. If set to False, no
  bias term will be used.
preactivation: Activation applied to the output matrix of the layer, before
  multiplication with the input. Can be used to control the scale of the
  layer's outputs and improve stability.
kernel_initializer: Initializer to use on the kernel matrix.
bias_initializer: Initializer to use on the bias vector.
kernel_regularizer: Regularizer to use on the kernel matrix.
bias_regularizer: Regularizer to use on the bias vector.
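To see why projection_dim needs to stay below input_dim / 2, the sketch below compares kernel sizes for the two variants. It counts kernel weights only and ignores the bias (an assumption made for simplicity); the function names are hypothetical and not part of the library.

```python
def full_rank_params(input_dim):
    """Weights in a full (input_dim by input_dim) kernel W."""
    return input_dim * input_dim


def low_rank_params(input_dim, projection_dim):
    """Weights in U (input_dim by projection_dim) plus V (projection_dim by input_dim)."""
    return 2 * input_dim * projection_dim


d = 64
print(full_rank_params(d))           # 4096
for p in (16, 32, 48):
    print(p, low_rank_params(d, p))  # 16 2048, then 32 4096, then 48 6144
```

With input_dim = 64, the low-rank kernel is smaller only for projection_dim < 32; at projection_dim = input_dim / 4 it uses half as many weights as the full-rank kernel.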
Input shape: A tuple of 2 (batch_size, input_dim) dimensional inputs.
Output shape: A single (batch_size, input_dim) dimensional output.
Methods
call
call(
    x0: tf.Tensor, x: Optional[tf.Tensor] = None
) -> tf.Tensor
Computes the feature cross.
Args
x0: The input tensor.
x: Optional second input tensor. If provided, the layer will compute
  crosses between x0 and x; if not provided, the layer will compute
  crosses between x0 and itself.

Returns
Tensor of crosses.
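The default for x can be sketched with a scalar stand-in for the layer (input_dim = 1, no diag_scale). This is only an illustration of the documented behavior, not the library's implementation; the function cross and the values of w and b are hypothetical. It also shows how stacking raises the polynomial degree of the interaction by one per layer.

```python
def cross(x0, x=None, w=1.0, b=0.0):
    """Scalar cross step: x0 * (w * x + b) + x, with x defaulting to x0."""
    if x is None:
        x = x0  # crosses x0 with itself, as for the first layer in a stack
    return x0 * (w * x + b) + x


x0 = 2.0
x1 = cross(x0)        # same as cross(x0, x0): x0**2 + x0
x2 = cross(x0, x1)    # x0**3 + 2 * x0**2 + x0
print(x1)  # 6.0
print(x2)  # 18.0
```

With w = 1 and b = 0, the first layer produces a degree-2 polynomial in x0 and the second a degree-3 polynomial, which matches the bounded-degree description of the layer.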
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.