XlaSendTPUEmbeddingGradients

public final class XlaSendTPUEmbeddingGradients

An op that performs gradient updates of embedding tables.

The gradients argument is a TensorList with the same length and shapes as the value returned by XlaRecvTPUEmbeddingActivations, but it contains the gradients of the model's loss with respect to the embedding activations. The embedding tables are updated from these gradients using the optimizer specified in the TPUEmbeddingConfiguration proto passed to tpu.initialize_system.
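The length-and-shape contract above can be checked before the op is built. The following is a minimal pure-Java sketch, assuming each tensor's shape is represented as a long[] of dimension sizes. GradientShapeCheck and shapesMatch are hypothetical names used for illustration only; they are not part of TensorFlow Java, and the example does not call the real op.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, NOT part of TensorFlow Java: checks the documented
// contract that the gradients list mirrors the activations list returned by
// XlaRecvTPUEmbeddingActivations in length and in per-element shape.
public final class GradientShapeCheck {

    private GradientShapeCheck() {}

    /** Returns true if both lists have the same length and pairwise-equal shapes. */
    public static boolean shapesMatch(List<long[]> activationShapes, List<long[]> gradientShapes) {
        if (activationShapes.size() != gradientShapes.size()) {
            return false; // the two lists differ in length
        }
        for (int i = 0; i < activationShapes.size(); i++) {
            // Arrays.equals compares rank and every dimension size element by element.
            if (!Arrays.equals(activationShapes.get(i), gradientShapes.get(i))) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Illustrative shapes only; real shapes come from the embedding configuration.
        List<long[]> activations = List.of(new long[] {128, 64}, new long[] {128, 32});
        List<long[]> goodGrads = List.of(new long[] {128, 64}, new long[] {128, 32});
        List<long[]> shortGrads = List.of(new long[] {128, 64});
        System.out.println(shapesMatch(activations, goodGrads));  // true
        System.out.println(shapesMatch(activations, shortGrads)); // false
    }
}
```

Comparing shapes with Arrays.equals keeps the check independent of any tensor library, so it can run on the host before a graph is constructed.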

Public Methods

static XlaSendTPUEmbeddingGradients
create(Scope scope, Iterable<Operand<Float>> gradients, Iterable<Operand<Float>> learningRates, Operand<?> deduplicationData, String config)
Factory method to create a class wrapping a new XlaSendTPUEmbeddingGradients operation.


Public Methods

public static XlaSendTPUEmbeddingGradients create (Scope scope, Iterable<Operand<Float>> gradients, Iterable<Operand<Float>> learningRates, Operand<?> deduplicationData, String config)

Factory method to create a class wrapping a new XlaSendTPUEmbeddingGradients operation.

Parameters
scope current scope
gradients A TensorList of gradients with which to update embedding tables.
learningRates A TensorList of learning rates used for updating the embedding tables via the optimizer. The length of the TensorList must be equal to the number of dynamic learning rate tags specified in the TPUEmbeddingConfiguration proto.
deduplicationData A Tensor with type=DT_VARIANT containing the deduplication data. The tensor is an XLA nested tuple containing N elements, where N is the ratio of the number of embedding cores to the number of tensor cores per TPU chip. Each element of the nested tuple is a tuple of rank-1 tensors. Each tensor contains either indices (DT_UINT32) for the embedding lookup on the TensorCore or weights (DT_FLOAT) to apply to the output of the embedding lookup operation.
config Serialized TPUEmbeddingConfiguration proto.
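The learningRates constraint can likewise be enforced on the host before calling create. This is a minimal pure-Java sketch, assuming the caller already knows how many dynamic learning rate tags the TPUEmbeddingConfiguration defines. LearningRateCheck and requireCount are hypothetical names for illustration; they are not TensorFlow Java APIs.

```java
import java.util.List;

// Hypothetical pre-flight check, NOT part of TensorFlow Java: enforces that the
// number of learning rates equals the number of dynamic learning rate tags
// declared in the TPUEmbeddingConfiguration proto.
public final class LearningRateCheck {

    private LearningRateCheck() {}

    /** Throws IllegalArgumentException if the counts differ. */
    public static void requireCount(int dynamicLearningRateTags, List<?> learningRates) {
        if (learningRates.size() != dynamicLearningRateTags) {
            throw new IllegalArgumentException(
                "expected " + dynamicLearningRateTags
                    + " learning rates (one per dynamic learning rate tag), got "
                    + learningRates.size());
        }
    }

    public static void main(String[] args) {
        // Assumed: the configuration declares two dynamic learning rate tags.
        requireCount(2, List.of(0.01f, 0.001f)); // counts match, no exception
        try {
            requireCount(2, List.of(0.01f));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Failing fast on the host gives a clearer error message than a mismatch surfacing later from the compiled TPU program.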
Returns
  • a new instance of XlaSendTPUEmbeddingGradients