# tfr.keras.losses.MixtureEMLoss

*Last updated 2023-08-18 UTC.*

[View source on GitHub](https://github.com/tensorflow/ranking/blob/v0.5.3/tensorflow_ranking/python/keras/losses.py#L1307-L1382)

Computes mixture EM loss between `y_true` and `y_pred`.

```python
tfr.keras.losses.MixtureEMLoss(
    reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
    name: Optional[str] = None,
    lambda_weight: Optional[losses_impl._LambdaWeight] = None,
    temperature: float = 1.0,
    alpha: float = 1.0,
    ragged: bool = False
)
```

Implementation of the mixture Expectation-Maximization loss
([Yan et al., 2022](https://research.google/pubs/pub51296/)). This loss assumes that the
clicks in a session are generated by one of the mixture models.

**Note:** This loss should be called with a `logits` tensor of shape
`[batch_size, list_size, model_num]`. The elements in the last dimension of `logits`
represent the models to be mixed.

#### Standalone usage:

```python
y_true = [[1., 0.]]
y_pred = [[[0.6, 0.9], [0.8, 0.2]]]
loss = tfr.keras.losses.MixtureEMLoss()
loss(y_true, y_pred).numpy()
# 1.3198698

# Using ragged tensors
y_true = tf.ragged.constant([[1., 0.], [0., 1., 0.]])
y_pred = tf.ragged.constant([[[0.6, 0.9], [0.8, 0.2]],
                             [[0.5, 0.9], [0.8, 0.2], [0.4, 0.8]]])
loss = tfr.keras.losses.MixtureEMLoss(ragged=True)
loss(y_true, y_pred).numpy()
# 1.909512
```

Usage with the `compile()` API:

```python
model.compile(optimizer='sgd', loss=tfr.keras.losses.MixtureEMLoss())
```

#### References

- [Revisiting two tower models for unbiased learning to rank, Yan et al., 2022](https://research.google/pubs/pub51296/)
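To make the note's shape requirement concrete, the standalone example's data can be annotated with the expected dimensions. Plain Python is used here so the sketch runs without TensorFlow; the variable names are illustrative only.

```python
# MixtureEMLoss expects logits of shape [batch_size, list_size, model_num].
# Labels have no model axis: their shape is [batch_size, list_size].
y_true = [[1.0, 0.0]]                # batch_size=1, list_size=2
y_pred = [[[0.6, 0.9], [0.8, 0.2]]]  # one score per mixture model, per item

batch_size = len(y_pred)
list_size = len(y_pred[0])
model_num = len(y_pred[0][0])        # number of models being mixed
print(batch_size, list_size, model_num)  # 1 2 2
```

With `model_num=1` the last axis is a single score per item, i.e. the usual ranking-logits layout with one trailing dimension added.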
||\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ---- ||\n|-----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| `reduction` | (Optional) The [`tf.keras.losses.Reduction`](https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction) to use (see [`tf.keras.losses.Loss`](https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss)). |\n| `name` | (Optional) The name for the op. |\n| `lambda_weight` | (Optional) A lambdaweight to apply to the loss. Can be one of [`tfr.keras.losses.DCGLambdaWeight`](../../../tfr/keras/losses/DCGLambdaWeight), [`tfr.keras.losses.NDCGLambdaWeight`](../../../tfr/keras/losses/NDCGLambdaWeight), or, [`tfr.keras.losses.PrecisionLambdaWeight`](../../../tfr/keras/losses/PrecisionLambdaWeight). |\n| `temperature` | (Optional) The temperature to use for scaling the logits. |\n| `alpha` | (Optional) The smooth factor of the probability. |\n| `ragged` | (Optional) If True, this loss will accept ragged tensors. If False, this loss will accept dense tensors. |\n\n\u003cbr /\u003e\n\nMethods\n-------\n\n### `from_config`\n\n[View source](https://github.com/tensorflow/ranking/blob/v0.5.3/tensorflow_ranking/python/keras/losses.py#L742-L752) \n\n @classmethod\n from_config(\n config, custom_objects=None\n )\n\nInstantiates a `Loss` from its config (output of `get_config()`).\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ||\n|----------|---------------------------|\n| `config` | Output of `get_config()`. |\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Returns ||\n|---|---|\n| A `Loss` instance. 
||\n\n\u003cbr /\u003e\n\n### `get_config`\n\n[View source](https://github.com/tensorflow/ranking/blob/v0.5.3/tensorflow_ranking/python/keras/losses.py#L1377-L1382) \n\n get_config() -\u003e Dict[str, Any]\n\nReturns the config dictionary for a `Loss` instance.\n\n### `__call__`\n\n[View source](https://github.com/tensorflow/ranking/blob/v0.5.3/tensorflow_ranking/python/keras/losses.py#L262-L270) \n\n __call__(\n y_true: ../../../tfr/keras/model/TensorLike,\n y_pred: ../../../tfr/keras/model/TensorLike,\n sample_weight: Optional[utils.TensorLike] = None\n ) -\u003e tf.Tensor\n\nSee tf.keras.losses.Loss."]]