tf.compat.v1.estimator.tpu.TPUEstimatorSpec

Ops and objects returned from a model_fn and passed to TPUEstimator.

Migrate to TF2

TPU Estimator manages its own TensorFlow graph and session, so it is not compatible with TF2 behaviors. We recommend that you migrate to the newer tf.distribute.TPUStrategy. See the TPU guide for details.

Description

See EstimatorSpec for mode, predictions, loss, train_op, and export_outputs.

For evaluation, eval_metrics is a tuple of metric_fn and tensors. metric_fn runs on the CPU to generate metrics. tensors represents the Tensors transferred from the TPU system to the CPU host and passed to metric_fn. TPU evaluation therefore expects a slightly different signature from tf.estimator.Estimator: EstimatorSpec.eval_metric_ops expects a dict, but TPUEstimatorSpec.eval_metrics is a tuple of metric_fn and tensors.

The tensors can be a list of Tensors or a dict mapping names to Tensors. They usually specify the model logits, which are transferred back from the TPU system to the CPU host. All tensors must be batch-major, i.e., the batch size is the first dimension. Once the tensors from all shards are available on the CPU host, they are concatenated (on the CPU) and passed to metric_fn. They are passed as positional arguments if tensors is a list, or as keyword arguments if tensors is a dict. metric_fn takes the tensors and returns a dict that maps each metric name (a string) to the result of calling a metric function, namely a (metric_tensor, update_op) tuple. See TPUEstimator for an MNIST example of how to specify eval_metrics.
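The host-side calling convention for eval_metrics can be mimicked in plain Python. The sketch below is a simulation under stated assumptions, not TensorFlow code:

- Each "tensor" is a batch-major Python list.
- `concat_batch_major`, `run_metric_fn_on_host` and `accuracy_metric_fn` are hypothetical names.
- The metric function returns plain floats instead of the (metric_tensor, update_op) tuples that a real metric_fn must return.

```python
from typing import Any, Callable, Dict, List, Sequence, Union

# Hypothetical stand-in: a "tensor" is a batch-major Python list
# (its first dimension is the batch).
Tensor = List[Any]
Shard = Union[List[Tensor], Dict[str, Tensor]]


def concat_batch_major(parts: Sequence[Tensor]) -> Tensor:
    """Concatenate per-shard batches along the first (batch) dimension."""
    merged: Tensor = []
    for part in parts:
        merged.extend(part)
    return merged


def run_metric_fn_on_host(metric_fn: Callable[..., Dict[str, Any]],
                          shard_tensors: Sequence[Shard]) -> Dict[str, Any]:
    """Merge every shard's tensors, then call metric_fn once on the host.

    A list of tensors becomes positional arguments; a dict becomes keyword
    arguments, mirroring the eval_metrics convention described above.
    """
    first = shard_tensors[0]
    if isinstance(first, dict):
        merged = {name: concat_batch_major([s[name] for s in shard_tensors])
                  for name in first}
        return metric_fn(**merged)
    merged_list = [concat_batch_major([s[i] for s in shard_tensors])
                   for i in range(len(first))]
    return metric_fn(*merged_list)


def accuracy_metric_fn(labels: Tensor, logits: Tensor) -> Dict[str, float]:
    """Hypothetical metric_fn: argmax accuracy over the full merged batch."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}


# Two shards, each holding a batch of 2 examples.
shard0 = {"labels": [1, 0], "logits": [[0.1, 0.9], [0.8, 0.2]]}
shard1 = {"labels": [1, 1], "logits": [[0.7, 0.3], [0.2, 0.8]]}
print(run_metric_fn_on_host(accuracy_metric_fn, [shard0, shard1]))
# {'accuracy': 0.75}
```

Passing each shard as a `[labels, logits]` list instead of a dict gives the same result. The merged lists are then passed positionally in that order.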

scaffold_fn is a function running on CPU to generate the Scaffold. This function should not capture any Tensors in model_fn.

host_call is a tuple of a function and a list or dictionary of tensors to pass to that function; the function returns a list of Tensors. host_call currently works for train() and evaluate(). The function's returned Tensors are evaluated on the CPU on every step, so sending tensors from the TPU to the CPU adds communication overhead. To reduce this overhead, try reducing the size of the tensors. The tensors are concatenated along their major (batch) dimension, so they must be at least rank 1. host_call is useful for writing summaries with tf.contrib.summary.create_file_writer.
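The rank-1 requirement is easiest to see in a simulation. The plain-Python sketch below shows what happens when the host side concatenates each tensor across shards and then calls the host function. It makes these assumptions:

- `run_host_call` and `mean_loss_host_fn` are hypothetical names.
- Tensors are modeled as Python lists.
- Only the list form of the tensors is handled.

A scalar, such as a loss or the global step, has no batch dimension to concatenate along. It must first be reshaped to shape [1]; in real TensorFlow code this is commonly done with `tf.reshape(x, [1])`.

```python
from typing import Any, Callable, Dict, List, Sequence


def run_host_call(host_fn: Callable[..., Any],
                  shard_tensors: Sequence[List[Any]]) -> Any:
    """Simulate one step of host_call on the CPU host.

    Each shard supplies the same ordered list of tensors. Tensor i from every
    shard is concatenated along its first (batch) dimension, so each one must
    be at least rank 1 (modeled here as a Python list).
    """
    merged = []
    for i in range(len(shard_tensors[0])):
        column: List[Any] = []
        for shard in shard_tensors:
            value = shard[i]
            if not isinstance(value, list):
                raise ValueError(f"tensor {i} is rank 0; reshape it to [1] first")
            column.extend(value)
        merged.append(column)
    return host_fn(*merged)


def mean_loss_host_fn(global_step: List[int], loss: List[float]) -> Dict[str, float]:
    """Hypothetical host function: averages the per-shard losses.

    A real host function would typically write summaries here; this one just
    returns the values it would log.
    """
    return {"step": global_step[0], "mean_loss": sum(loss) / len(loss)}


# Each shard has reshaped its scalar step and loss to shape [1].
shards = [[[100], [0.5]], [[100], [1.5]]]
print(run_host_call(mean_loss_host_fn, shards))
# {'step': 100, 'mean_loss': 1.0}
```

Passing bare scalars, for example `[[100, 0.5]]`, raises `ValueError` in this sketch. This mirrors the requirement that host_call tensors be at least rank 1.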

Attributes

mode A namedtuple alias for field number 0
predictions A namedtuple alias for field number 1
loss A namedtuple alias for field number 2
train_op A namedtuple alias for field number 3
eval_metrics A namedtuple alias for field number 4
export_outputs A namedtuple alias for field number 5
scaffold_fn A namedtuple alias for field number 6
host_call A namedtuple alias for field number 7
training_hooks A namedtuple alias for field number 8
evaluation_hooks A namedtuple alias for field number 9
prediction_hooks A namedtuple alias for field number 10
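"Alias for field number N" means that each attribute name and its positional index refer to the same tuple slot. The sketch below illustrates this with a hypothetical stand-in, `SpecFields`. It uses the same field order as the list above. The real class additionally supplies defaults and validates its arguments.

```python
from collections import namedtuple

# Hypothetical stand-in with the same field order as TPUEstimatorSpec;
# the real class adds default values and argument validation.
SpecFields = namedtuple("SpecFields", [
    "mode", "predictions", "loss", "train_op", "eval_metrics",
    "export_outputs", "scaffold_fn", "host_call", "training_hooks",
    "evaluation_hooks", "prediction_hooks",
])

spec = SpecFields("train", None, 0.25, "train_op", None, None,
                  None, None, [], [], [])
print(spec.loss == spec[2])                    # True: loss is field number 2
print(SpecFields._fields.index("host_call"))   # 7
```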

Methods

as_estimator_spec

Creates an equivalent EstimatorSpec used by CPU train/eval.
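Conceptually, the conversion turns the TPU-specific fields into their EstimatorSpec counterparts:

- eval_metrics becomes an eval_metric_ops dict, produced by calling metric_fn directly on its tensors.
- scaffold_fn is invoked to produce the scaffold.

The plain-Python sketch below illustrates only that idea. `SimpleEstimatorSpec`, `to_cpu_spec` and `count_metric_fn` are hypothetical names, and this is not the library's implementation. The real method also handles fields such as host_call, predictions, and the hooks, which are omitted here.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union


@dataclass
class SimpleEstimatorSpec:
    """Hypothetical stand-in for the CPU-side EstimatorSpec fields used here."""
    mode: str
    loss: Any = None
    train_op: Any = None
    eval_metric_ops: Optional[Dict[str, Any]] = None
    scaffold: Any = None


def to_cpu_spec(
    mode: str,
    loss: Any = None,
    train_op: Any = None,
    eval_metrics: Optional[Tuple[Callable[..., Dict[str, Any]], Union[list, dict]]] = None,
    scaffold_fn: Optional[Callable[[], Any]] = None,
) -> SimpleEstimatorSpec:
    """Sketch of a TPU-spec to CPU-spec conversion (not the library's code).

    On CPU there is no TPU-to-host transfer, so metric_fn is called directly
    on its tensors to build eval_metric_ops, and scaffold_fn is invoked to
    obtain the scaffold.
    """
    eval_metric_ops = None
    if eval_metrics is not None:
        metric_fn, tensors = eval_metrics
        if isinstance(tensors, dict):
            eval_metric_ops = metric_fn(**tensors)   # dict -> keyword args
        else:
            eval_metric_ops = metric_fn(*tensors)    # list -> positional args
    scaffold = scaffold_fn() if scaffold_fn is not None else None
    return SimpleEstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                               eval_metric_ops=eval_metric_ops, scaffold=scaffold)


def count_metric_fn(labels):
    """Hypothetical metric_fn returning a (value, update_op) pair per name."""
    return {"num_examples": (len(labels), None)}


cpu_spec = to_cpu_spec("eval", loss=0.3,
                       eval_metrics=(count_metric_fn, [[1, 0, 1]]),
                       scaffold_fn=lambda: "scaffold")
print(cpu_spec.eval_metric_ops)  # {'num_examples': (3, None)}
```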