Note: This class is intended for embedding tables that are trained on TPU and served on CPU, so it should only be initialized under a non-TPU strategy; otherwise an error is raised.

You can first train your model using the `TPUEmbedding` class and save a checkpoint, then use this class to restore the checkpoint for serving.
First train a model and save the checkpoint.
```python
model = model_fn(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))

# Your custom training code.

checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.save(...)
```
Then restore the checkpoint and serve the model.
```python
# Restore the model on CPU.
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))

checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

result = embedding(...)
tables = embedding.embedding_tables  # dict of tables, keyed by TableConfig
```
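The snippets above elide the configuration details. Here is a minimal sketch of what they might look like, assuming a single hypothetical table named `'video'` feeding one feature named `'watched'` (all names, sizes, and input values below are invented for illustration):

```python
import tensorflow as tf

# Hypothetical config: one table shared by one feature.
table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1024, dim=8, name='video')
feature_config = {
    'watched': tf.tpu.experimental.embedding.FeatureConfig(
        table=table, name='watched')}

# Created outside any TPU strategy, as required for serving. Passing
# optimizer=None skips the optimizer slot variables, which are not
# needed for inference.
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
    feature_config=feature_config, optimizer=None)
embedding.build()  # creates the table variables

# In a real serving path, checkpoint.restore(...) would populate the
# tables here; without it they keep their freshly initialized values,
# which is still enough to sanity-check shapes.
ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1]],
    values=[3, 7, 9],
    dense_shape=[2, 2])
activations = embedding({'watched': ids})
print(activations['watched'].shape)  # (2, 8): one combined row per example
```

The result mirrors the structure of `feature_config`, so a dict of features yields a dict of activations.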
Note: This class can also be used to do embedding training on CPU, but that requires converting between the Keras optimizer and the embedding optimizers so that the slot variables stay consistent between them.

The constructor signature is:

```python
tf.tpu.experimental.embedding.TPUEmbeddingForServing(
    feature_config: Union[tf.tpu.experimental.embedding.FeatureConfig, Iterable],
    optimizer: Optional[tpu_embedding_v2_utils._Optimizer],
    experimental_sparsecore_restore_info: Optional[Dict[str, Any]] = None
)
```

**Args**

- `feature_config`: A nested structure of `tf.tpu.experimental.embedding.FeatureConfig` configs.
- `optimizer`: An instance of one of `tf.tpu.experimental.embedding.SGD`, `tf.tpu.experimental.embedding.Adagrad` or `tf.tpu.experimental.embedding.Adam`. When not created under TPUStrategy, it may be set to `None` to avoid creating the optimizer slot variables, which is useful for reducing memory consumption when exporting the model for serving, where slot variables are not needed.
- `experimental_sparsecore_restore_info`: Information from SparseCore training that is required to restore from checkpoint for serving, such as the number of TPU devices used (`num_tpu_devices`).

**Raises**

- `RuntimeError`: If created under TPUStrategy.

**Attributes**

- `embedding_tables`: Returns a dict of embedding tables, keyed by `TableConfig`.

## Methods

### `build`

```python
build()
```

Create variables and slot variables for TPU embeddings.

### `embedding_lookup`

```python
embedding_lookup(
    features: Any, weights: Optional[Any] = None
) -> Any
```

Apply standard lookup ops on CPU.

**Args**

- `features`: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs will be downcast to `tf.int32`. Only one of `tf.SparseTensor` or `tf.RaggedTensor` is supported per call.
- `weights`: If not `None`, a nested structure of `tf.Tensor`s, `tf.SparseTensor`s or `tf.RaggedTensor`s matching `features`, except that the tensors should be of float type (they will be downcast to `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the same for the parallel entries from `features`, and similarly for `tf.RaggedTensor`s we assume the `row_splits` are the same. A weighted-lookup sketch follows the `__call__` method below.

**Returns**

A nested structure of Tensors with the same structure as the input features.

### `__call__`

```python
__call__(
    features: Any, weights: Optional[Any] = None
) -> Any
```

Call the mid level API to do embedding lookup.
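Continuing the hypothetical sketch from earlier (the `embedding` object and `'watched'` feature are the invented ones from above), the `weights` argument to `embedding_lookup` or `__call__` must mirror the structure of `features`; for `tf.SparseTensor` inputs that means reusing the same `indices`:

```python
# Same indices as the feature tensor; the float values are downcast to
# tf.float32 and weight each id's contribution in the combiner.
ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=[12, 47, 5],
    dense_shape=[2, 2])
weights = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],  # identical to ids.indices
    values=[0.5, 2.0, 1.0],
    dense_shape=[2, 2])

activations = embedding.embedding_lookup(
    features={'watched': ids}, weights={'watched': weights})
```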