Help protect the Great Barrier Reef with TensorFlow on Kaggle Join Challenge

Migrate from TPUEstimator to TPUStrategy

View on TensorFlow.org View source on GitHub Download notebook

This guide demonstrates how to migrate your workflows running on TPUs from TensorFlow 1's TPUEstimator API to TensorFlow 2's TPUStrategy API.

For end-to-end TensorFlow 2 examples, check out the Use TPUs guide—namely, the Classification on TPUs section—and the Solve GLUE tasks using BERT on TPU tutorial. You may also find the Distributed training guide useful, which covers all TensorFlow distribution strategies, including TPUStrategy.

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.7) or chardet (2.3.0)/charset_normalizer (2.0.9) doesn't match a supported version!
  RequestsDependencyWarning)
features = [[1., 1.5]]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_labels = [[0.8]]

TensorFlow 1: Drive a model on TPUs with TPUEstimator

This section of the guide demonstrates how to perform training and evaluation with tf.compat.v1.estimator.tpu.TPUEstimator in TensorFlow 1.

To use a TPUEstimator, first define a few functions: an input function for the training data, an evaluation input function for the evaluation data, and a model function that tells the TPUEstimator how the training op is defined with the features and labels:

def _input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((features, labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _eval_input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((eval_features, eval_labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _model_fn(features, labels, mode, params):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

With those functions defined, create a tf.distribute.cluster_resolver.TPUClusterResolver that provides the cluster information, and a tf.compat.v1.estimator.tpu.RunConfig object. Along with the model function you have defined, you can now create a TPUEstimator. Here, you will simplify the flow by skipping checkpoint savings. Then, you will specify the batch size for both training and evaluation for the TPUEstimator.

cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
All devices:  []
tpu_config = tf1.estimator.tpu.TPUConfig(iterations_per_loop=10)
config = tf1.estimator.tpu.RunConfig(
    cluster=cluster_resolver,
    save_checkpoints_steps=None,
    tpu_config=tpu_config)
estimator = tf1.estimator.tpu.TPUEstimator(
    model_fn=_model_fn,
    config=config,
    train_batch_size=8,
    eval_batch_size=8)
WARNING:tensorflow:Estimator's model_fn (<function _model_fn at 0x7f085b6eae18>) includes params argument, but params are not passed to Estimator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpb4lechmg
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpb4lechmg', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.240.1.10:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.240.1.10:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.240.1.10:8470', '_evaluation_master': 'grpc://10.240.1.10:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=10, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=2, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7f085d335d30>}
INFO:tensorflow:_TPUContext: eval_on_tpu True

Call TPUEstimator.train to begin training the model:

estimator.train(_input_fn, steps=1)
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.10:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3647470094639287222)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, -4023478403302218909)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 5381605205224317798)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 4519355694646195693)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, -6519377304537534401)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -7515633098448116369)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 5378185555831876363)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 152453060555458768)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 7866248824725174753)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 3490000913007233511)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -5503803846741663680)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Bypassing TPUEstimator hook
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:758: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Installing graceful shutdown hook.
INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0']
INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR

INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 7 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:loss = 0.18597336, step = 1
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Loss for final step: 0.18597336.
INFO:tensorflow:training_loop marked as finished
<tensorflow_estimator.python.estimator.tpu.tpu_estimator.TPUEstimator at 0x7f085b6d43c8>

Then, call TPUEstimator.evaluate to evaluate the model using the evaluation data:

estimator.evaluate(_eval_input_fn, steps=1)
INFO:tensorflow:Could not find trained model in model_dir: /tmp/tmpb4lechmg, running initialization to evaluate.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3406: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-12-04T14:10:30
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 10 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Inference Time : 11.00412s
INFO:tensorflow:Finished evaluation at 2021-12-04-14:10:41
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 39.13212
INFO:tensorflow:evaluation_loop marked as finished
{'loss': 39.13212, 'global_step': 1}

TensorFlow 2: Drive a model on TPUs with Keras Model.fit and TPUStrategy

In TensorFlow 2, to train on the TPU workers, use tf.distribute.TPUStrategy together with the Keras APIs for model definition and training/evaluation. (Refer to the Use TPUs guide for more examples of training with Keras Model.fit and a custom training loop (with tf.function and tf.GradientTape).)

Since you need to perform some initialization work to connect to the remote cluster and initialize the TPU workers, start by creating a TPUClusterResolver to provide the cluster information and connect to the cluster. (Learn more in the TPU initialization section of the Use TPUs guide.)

cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.10:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]

Next, once your data is prepared, you will create a TPUStrategy, define a model, metrics, and an optimizer under the scope of this strategy.

To achieve comparable training speed with TPUStrategy, you should make sure to pick a number for steps_per_execution in Model.compile because it specifies the number of batches to run during each tf.function call, and is critical for performance. This argument is similar to iterations_per_loop used in a TPUEstimator. If you are using custom training loops, you should make sure multiple steps are run within the tf.function-ed training function. Go to the Improving performance with multiple steps inside tf.function section of the Use TPUs guide for more information.

tf.distribute.TPUStrategy can support bounded dynamic shapes, which is the case that the upper bound of the dynamic shape computation can be inferred. But dynamic shapes may introduce some performance overhead compared to static shapes. So, it is generally recommended to make your input shapes static if possible, especially in training. One common op that returns a dynamic shape is tf.data.Dataset.batch(batch_size), since the number of samples remaining in a stream might be less than the batch size. Therefore, when training on the TPU, you should use tf.data.Dataset.batch(..., drop_remainder=True) for best training performance.

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).shuffle(10).repeat().batch(
        8, drop_remainder=True).prefetch(2)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1, drop_remainder=True)

strategy = tf.distribute.TPUStrategy(cluster_resolver)
with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  model.compile(optimizer, "mse", steps_per_execution=10)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

With that, you are ready to train the model with the training dataset:

model.fit(dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5
10/10 [==============================] - 1s 144ms/step - loss: 0.0708
Epoch 2/5
10/10 [==============================] - 0s 3ms/step - loss: 5.3889e-04
Epoch 3/5
10/10 [==============================] - 0s 3ms/step - loss: 5.5338e-06
Epoch 4/5
10/10 [==============================] - 0s 3ms/step - loss: 5.6931e-08
Epoch 5/5
10/10 [==============================] - 0s 3ms/step - loss: 5.8623e-10
<keras.callbacks.History at 0x7f064fed1860>

Finally, evaluate the model using the evaluation dataset:

model.evaluate(eval_dataset, return_dict=True)
1/1 [==============================] - 2s 2s/step - loss: 0.7638
{'loss': 0.7638343572616577}

Next steps

To learn more about TPUStrategy in TensorFlow 2, consider the following resources:

To learn more about customizing your training, refer to:

TPUs—Google's specialized ASICs for machine learning—are available through Google Colab, the TPU Research Cloud, and Cloud TPU.