A Keras layer for accelerating embedding lookups for large tables with TPU.

Feature and table configuration

When creating an instance of this layer, you must specify:

  1. The complete set of embedding tables,
  2. The features you expect to look up in those tables, and
  3. The optimizer(s) you wish to use on the tables.

See the documentation of tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig for more details on the complete set of options. We will cover the basic usage here.

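table_config_one = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=..., dim=...)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=..., dim=...)
feature_config = {
    # feature_one and feature_two share table_config_one.
    'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_two)}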


Optimizers

An optimizer can be specified globally by passing one of the following types of input to the optimizer argument:

  1. A string, one of 'sgd', 'adagrad' or 'adam', which uses the given optimizer with the default parameters.
  2. An instance of a Keras optimizer.
  3. An instance of an optimizer class from the tf.tpu.experimental.embedding module.

You may also specify an optimizer at the table level via the optimizer argument of tf.tpu.experimental.embedding.TableConfig. This completely overrides the global optimizer for that table. For performance reasons, it is recommended that you minimize the total number of distinct optimizers.
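The override rule can be illustrated without TensorFlow. The sketch below is only a plain-Python model of the behavior described above, not the layer's implementation; `resolve_optimizers`, the table names, and the use of strings in place of optimizer instances are all hypothetical.

```python
def resolve_optimizers(global_optimizer, table_optimizers):
    """Return the optimizer each table ends up with.

    table_optimizers maps a table name to its table-level optimizer, or to
    None when the table does not set one. A table-level optimizer replaces
    the global optimizer for that table.
    """
    return {
        name: global_optimizer if override is None else override
        for name, override in table_optimizers.items()
    }


# Hypothetical setup: strings stand in for optimizer instances.
effective = resolve_optimizers(
    "adagrad",
    {"table_one": None, "table_two": "sgd", "table_three": None},
)
distinct = set(effective.values())
print(effective)
print(len(distinct), "distinct optimizers")
```

Counting the distinct values in this way is a quick check on how many different optimizers a configuration actually uses.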

Dynamic Learning Rate

A dynamic learning rate is supported for all optimizers; all other hyperparameters are static. There are two ways of specifying a dynamic learning rate in your optimizer:

  1. One of the objects in the tf.keras.optimizers.schedules namespace.
  2. A Python callable taking no parameters which returns a scalar tensor of type tf.float32.


The first method, a learning rate schedule, is only possible when using a Keras optimizer. In this case, set the learning rate of the optimizer to your desired tf.keras.optimizers.schedules object.


The second method, a callable, can be used with a Keras optimizer or with one of the optimizer classes in the tf.tpu.experimental.embedding namespace.

In either case you should create a callable that returns a tensor. This function will be called once, but the ops it generates will be reevaluated each step. It is therefore recommended that you either create a tf.Variable representing your current step counter or use the iterations property of an optimizer on which you call apply_gradients each training step.

with strategy.scope():
  step = tf.Variable(
      initial_value=0, trainable=False, dtype=tf.int64)

Model creation

For a functional style Keras model:

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding_inputs = {
      'feature_one': tf.keras.Input(batch_size=1024, shape=(),
                                    dtype=tf.int32),
      'feature_two': tf.keras.Input(batch_size=1024, shape=(),
                                    dtype=tf.int32, ragged=True),
      'feature_three': tf.keras.Input(batch_size=1024, shape=(),
                                      dtype=tf.int32)}
  # embedding, feature_config and embedding_inputs all have the same nested
  # structure.
  embedding = tpu_embedding_layer.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))(embedding_inputs)
  # Concatenate the per-feature activations along the feature axis.
  logits = tf.keras.layers.Dense(1)(
      tf.concat(tf.nest.flatten(embedding), axis=1))
  model = tf.keras.Model(embedding_inputs, logits)

For a subclass style model:

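class ModelWithEmbeddings(tf.keras.Model):
  def __init__(self):
    super().__init__()
    # Create layers once here so their weights are reused on every call.
    self.dense = tf.keras.layers.Dense(1)
    self.embedding_layer = tpu_embedding_layer.TPUEmbedding(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))

  def call(self, inputs):
    embedding = self.embedding_layer(inputs)
    logits = self.dense(tf.concat(tf.nest.flatten(embedding), axis=1))
    return logits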

with strategy.scope():
  model = ModelWithEmbeddings()

Input data

When creating a distributed dataset to be used with a model that contains a TPUEmbedding layer, a special option must be specified when calling any of the dataset distribution methods of TPUStrategy:

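distributed_dataset = strategy.distribute_datasets_from_function(
    dataset_fn=...,
    options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)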

Different feature inputs can have different shapes. For dense and sparse tensors, rank 2 and above is supported. For ragged tensors, only rank 2 input is supported, but you can specify an output shape of rank 2 or above. The output shape specified in the FeatureConfig has the highest priority. The input shape passed to the build method has second priority, and the input shape automatically detected from the input feature has the lowest priority. The latter two are converted to output shapes by omitting the last dimension. If a lower-priority source yields an output shape that does not match a higher-priority one, a ValueError is raised. A lower-priority source can supply the output shape only when the higher-priority one leaves it undefined.
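The priority order above can be sketched in plain Python. This is a simplified model of the rule as described, not the layer's actual code: `resolve_output_shape` is a hypothetical helper, shapes are assumed to be fully defined lists of ints, and the ragged case (where the configured output shape may have higher rank than the input) is not modelled.

```python
def resolve_output_shape(config_output_shape=None,
                         build_input_shape=None,
                         detected_input_shape=None):
    """Pick an output shape following the priority order described above.

    Every argument is a list of ints, or None when that source is undefined.
    """
    def drop_last(shape):
        # An input shape becomes an output shape by omitting its last dimension.
        return None if shape is None else list(shape[:-1])

    # Ordered from highest to lowest priority.
    candidates = [
        ("FeatureConfig output shape", config_output_shape),
        ("build input shape", drop_last(build_input_shape)),
        ("detected input shape", drop_last(detected_input_shape)),
    ]
    chosen_source, chosen = None, None
    for source, shape in candidates:
        if shape is None:
            continue  # undefined sources never conflict
        if chosen is None:
            chosen_source, chosen = source, list(shape)
        elif list(shape) != chosen:
            raise ValueError(
                f"{source} gives {list(shape)}, but {chosen_source} "
                f"gives {chosen}")
    return chosen


print(resolve_output_shape(None, [1024, 8], [1024, 8]))  # [1024]
print(resolve_output_shape([1024], None, [1024, 8]))     # [1024]
```

Calling `resolve_output_shape([512], [1024, 8])` raises a ValueError in this sketch, because the configured shape and the shape derived from the build input disagree.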

Training and evaluation

To use this API on TPU you should use a custom training loop. Below is an example of a training and evaluation step:

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    with tf.GradientTape() as tape:
      model_output = model(features)
      loss = ...  # some function of labels and model_output

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

@tf.function
def evaluation_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    model_output = model(features)
    # Insert your evaluation code here.

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

In the above examples, we assume that the dataset returns a tuple whose second element matches the structure of what was passed as the feature_config argument to the object initializer. We also use tf.range so that the loop becomes a tf.while_loop, which improves performance.
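To make the structure requirement concrete, here is a small plain-Python check that two nested containers line up. It is a hand-rolled stand-in for illustration only: `same_structure` is a hypothetical helper, the strings are placeholders for real configs and tensors, and unlike TensorFlow's own `tf.nest.assert_same_structure` it treats lists and tuples as interchangeable.

```python
def same_structure(a, b):
    """Return True when two nested dict/list/tuple structures line up."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(
            same_structure(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(
            same_structure(x, y) for x, y in zip(a, b))
    # Anything else is a leaf; a leaf only matches another leaf.
    containers = (dict, list, tuple)
    return not isinstance(a, containers) and not isinstance(b, containers)


# Placeholder leaves: strings stand in for FeatureConfig objects and tensors.
feature_config = {"feature_one": "config", "feature_two": "config"}
good_features = {"feature_one": "ids", "feature_two": "ids"}
bad_features = {"feature_one": "ids"}
print(same_structure(feature_config, good_features))  # True
print(same_structure(feature_config, bad_features))   # False
```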

Checkpointing

The embedding layer does not affect checkpointing; simply checkpoint your model as normal. Remember that if you passed either a Keras optimizer or an optimizer converted from a Keras optimizer via translate_keras_optimizer, you must also checkpoint the optimizer to ensure that your slot variables are saved.

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint.save(...)


Serving

Serving is accomplished through the tf.saved_model API. The model may be exported directly from training.

First we write a tf.function that represents the serving graph. Typically this takes as input a string tensor containing protos that are parsed into tensors and then passed to the model. For example:

@tf.function(input_signature=[{
    'examples': tf.TensorSpec(
        shape=[None], dtype=tf.string, name='examples')}])
def serve_examples(examples):
  input_data = ...  # parse the examples tensor to produce input tensors.
  return model(input_data)

tf.saved_model.save(model,
                    export_dir=...,
                    signatures={'serving': serve_examples})

The exported model can now be loaded (in Python or C) and used for serving:

imported = tf.saved_model.load(...)
predict_fn = imported.signatures['serving']

Using this layer on CPU

This layer can also be instantiated under a CPU strategy and used for local testing or training. Models created this way are checkpoint compatible with models created under TPUStrategy. To achieve checkpoint compatibility, you must use Keras optimizers (or ones converted by translate_keras_optimizer) as your optimizers.

In the simplest case, where you use the same optimizer for your embedding and dense layers, the training_step above will function exactly the same in both situations.

If you use a separate Keras optimizer for your embedding layers (e.g. you want a different hyperparameter setting or an entirely different algorithm), special care must be taken to keep things the same. To understand why, there are a few technical details you need to know:

When created under TPUStrategy the underlying table variables are not considered trainable and are not available under model.trainable_variables. The main reason for this is that the table variables are just a stand-in for the real data which lives in the HBM of the TPU. These variables are stale and are only updated when saving and restoring checkpoints.

Because of this, a standard optimizer.apply_gradients will not work on these variables. Instead, a separate virtual trainable variable is added to the list of trainable variables, and simply computing the gradient of this variable causes the gradient for the embeddings to be computed and the optimizer to be applied.

When created under a CPU strategy, the table variables are created normally and are part of the model's trainable variables. In this case, if you are using a different optimizer for the embedding tables, you must manually partition the variables and gradients so that the Keras optimizer you created for the embedding tables is applied to the tables.


class ModelWithSeparateOptimizer(tf.keras.Model):
  def __init__(self, optimizer):
    super().__init__()
    self.dense = tf.keras.layers.Dense(1)
    self.embedding_layer = tpu_embedding_layer.TPUEmbedding(
        feature_config=feature_config,
        optimizer=optimizer)

  def call(self, inputs):
    embedding = self.embedding_layer(inputs)
    logits = self.dense(tf.concat(tf.nest.flatten(embedding), axis=1))
    return logits

with strategy.scope():
  embedding_optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
  dense_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
  model = ModelWithSeparateOptimizer(embedding_optimizer)

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    with tf.GradientTape() as tape:
      model_output = model(features)
      loss = ...  # some function of labels and model_output

    gradients = tape.gradient(loss, model.trainable_variables)
    # Materialize the pairs: a bare zip would be exhausted after the first
    # list comprehension below.
    grads_and_vars = list(zip(gradients, model.trainable_variables))

    # Note the use of 'id' here: 'x in y' uses x's equality method and if x is
    # a tensor this is tf.math.equal rather than Python object equality.
    embedding_var_ids = [
        id(v) for v in model.embedding_layer.trainable_variables]
    dense_grads_and_vars = [
        (g, v) for g, v in grads_and_vars
        if id(v) not in embedding_var_ids]
    dense_optimizer.apply_gradients(dense_grads_and_vars)

    embedding_grads_and_vars = [
        (g, v) for g, v in grads_and_vars
        if id(v) in embedding_var_ids]
    embedding_optimizer.apply_gradients(embedding_grads_and_vars)

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

The above training step works both on TPU and on CPU.
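The partitioning step can be tried in isolation with plain Python. The sketch below is illustrative only: `ElementwiseEq` is a hypothetical stand-in for a tensor-like object whose `==` does not return a plain bool (here it raises, so that any accidental use of equality is visible), and `partition_by_identity` is a hypothetical helper. It also shows why the (gradient, variable) pairs should be turned into a list first, since a zip object can only be iterated once.

```python
class ElementwiseEq:
    """Stand-in for a tensor-like object whose == is not a plain bool."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Fail loudly if anything compares these objects with ==.
        raise TypeError("ambiguous truth value")


def partition_by_identity(grads_and_vars, embedding_vars):
    """Split (gradient, variable) pairs into dense and embedding groups."""
    grads_and_vars = list(grads_and_vars)  # a zip can only be read once
    embedding_ids = {id(v) for v in embedding_vars}
    dense = [(g, v) for g, v in grads_and_vars if id(v) not in embedding_ids]
    embedding = [(g, v) for g, v in grads_and_vars if id(v) in embedding_ids]
    return dense, embedding


kernel, bias, table = (ElementwiseEq(n) for n in ("kernel", "bias", "table"))
pairs = zip(["g_kernel", "g_bias", "g_table"], [kernel, bias, table])
dense, embedding = partition_by_identity(pairs, [table])
print([v.name for _, v in dense])      # ['kernel', 'bias']
print([v.name for _, v in embedding])  # ['table']
```

Because membership is tested on integer ids, the objects' own equality method is never invoked.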

Using this layer on TPU without an embedding lookup accelerator

This layer can also be initialized on TPUs without embedding lookup accelerators. No change to the client code is required, as the layer automatically switches between different mid-level APIs based on the TPU hardware. You can also force the layer to run without acceleration by overriding the embedding feature to "UNSUPPORTED". This might be helpful when your table is relatively small.

Note that in this case the table is replicated across devices instead of being sharded across them.

Args

feature_config: A nested structure of tf.tpu.experimental.embedding.FeatureConfig configs.
optimizer: An instance of one of tf.tpu.experimental.embedding.SGD, tf.tpu.experimental.embedding.Adagrad or tf.tpu.experimental.embedding.Adam, a Keras optimizer, or a string name of an optimizer (see tf.keras.optimizers.get). Alternatively, if not created under a TPU strategy, None, which avoids creating the optimizer slot variables to reduce memory consumption during export.
pipeline_execution_with_tensor_core: If True, the TPU embedding computations will overlap with the TensorCore computations (and hence will be one step old, with potential correctness drawbacks). Set to True for improved performance.
batch_size: Batch size of the input feature. Deprecated; retained for backward compatibility.
embedding_feature: EmbeddingFeature enum, indicating which version of TPU hardware the layer should run on.

Attributes

embedding_tables: A mapping from table configs to tables.

When instantiated under a TPU strategy, this returns a sharded variable. This variable is strictly a placeholder used for saving and restoring. Attempting to assign values to this variable will not update the actual embedding tables, and reading it may return a stale copy of the table. It should not be used for actual computation, only for exporting the model for serving.




Calling the layer looks up features in the embedding tables and combines them using weights.

Args

features: A nested structure of Tensors, SparseTensors or RaggedTensors with the same structure as feature_config. These tensors are used as ids to look up rows in the embedding tables using the config specified in the corresponding entry of feature_config. You can mix Tensors and SparseTensors, or Tensors and RaggedTensors, but not SparseTensors and RaggedTensors.
weights: None, or a nested structure of Tensors, SparseTensors or RaggedTensors or None matching features. These are the weights used when combining the looked up rows for a given feature and example. If None, weights of 1 will be used.
serving_config: A nested structure of tf.tpu.experimental.embedding.FeatureConfig objects. If not None, this layer uses CPU based lookup using serving_config and the current set of embedding tables.

Returns

The combined embedding activations for the input ids passed in via features.

Raises

RuntimeError: If the layer is not created under a TPU strategy and is called under a TPU strategy.
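As a rough illustration of what looking up rows and combining them with weights means for a single example, the sketch below uses a toy table and forms a weighted sum. It is plain Python, not the layer's code: the `combine` helper, the table values, and the choice of a sum combiner are assumptions made for illustration (the combiner actually applied is set on the table's TableConfig).

```python
def combine(table, ids, weights=None):
    """Weighted sum of the rows of `table` selected by `ids`."""
    if weights is None:
        weights = [1.0] * len(ids)  # missing weights default to 1
    dim = len(table[0])
    out = [0.0] * dim
    for row_id, w in zip(ids, weights):
        for j in range(dim):
            out[j] += w * table[row_id][j]
    return out


toy_table = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]
print(combine(toy_table, [0, 2]))               # [5.0, 4.0]
print(combine(toy_table, [0, 2], [0.5, 0.25]))  # [1.5, 1.0]
```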