Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Estimadores

Ver en TensorFlow.org Ver fuente en GitHub Descargar cuaderno

Este documento presenta tf.estimator una API de TensorFlow de alto nivel. Los estimadores encapsulan las siguientes acciones:

  • formación
  • evaluación
  • predicción
  • exportar para servir

TensorFlow implementa varios Estimadores prediseñados. Los estimadores personalizados todavía se admiten, pero principalmente como una medida de compatibilidad con versiones anteriores. Los estimadores personalizados no deben usarse para código nuevo . Todos los estimadores, ya sean prefabricados o personalizados, son clases basadas en la clase tf.estimator.Estimator .

Para obtener un ejemplo rápido, pruebe los tutoriales de Estimator . Para obtener una descripción general del diseño de la API, consulte el informe técnico .

Preparar

 pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Ventajas

Similar a tf.keras.Model , un estimator es una abstracción a nivel de modelo. El tf.estimator proporciona algunas capacidades actualmente en desarrollo para tf.keras . Estos son:

  • Entrenamiento basado en servidor de parámetros
  • Integración TFX completa.

Capacidades de los estimadores

Los estimadores brindan los siguientes beneficios:

  • Puede ejecutar modelos basados ​​en Estimator en un host local o en un entorno distribuido de varios servidores sin cambiar su modelo. Además, puede ejecutar modelos basados ​​en Estimator en CPU, GPU o TPU sin recodificar su modelo.
  • Los estimadores proporcionan un ciclo de entrenamiento distribuido seguro que controla cómo y cuándo:
    • Cargar datos
    • manejar excepciones
    • crear archivos de puntos de control y recuperarse de fallas
    • guardar resúmenes para TensorBoard

Al escribir una aplicación con Estimadores, debe separar la canalización de entrada de datos del modelo. Esta separación simplifica los experimentos con diferentes conjuntos de datos.

Usando estimadores prediseñados

Los Estimadores prediseñados le permiten trabajar a un nivel conceptual mucho más alto que las API básicas de TensorFlow. Ya no tiene que preocuparse por crear el gráfico o las sesiones computacionales, ya que los estimadores se encargan de toda la "plomería" por usted. Además, los estimadores prefabricados le permiten experimentar con diferentes arquitecturas de modelos haciendo solo cambios mínimos en el código. tf.estimator.DNNClassifier , por ejemplo, es una clase de Estimator prefabricada que entrena modelos de clasificación basados ​​en redes neuronales densas y de avance.

Un programa de TensorFlow que se basa en un Estimador prefabricado generalmente consta de los siguientes cuatro pasos:

1. Escribe una función de entrada

Por ejemplo, puede crear una función para importar el conjunto de entrenamiento y otra función para importar el conjunto de prueba. Los estimadores esperan que sus entradas tengan el formato de un par de objetos:

  • Un diccionario en el que las claves son nombres de funciones y los valores son tensores (o SparseTensors) que contienen los datos de funciones correspondientes.
  • Un tensor que contiene una o más etiquetas

input_fn debería devolver un tf.data.Dataset que produzca pares en ese formato.

Por ejemplo, el siguiente código crea un tf.data.Dataset partir del archivo train.csv del conjunto de datos Titanic:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

input_fn se ejecuta en un tf.Graph y también puede devolver directamente un par (features_dics, labels) contiene tensores de gráfico, pero esto es propenso a errores fuera de casos simples como devolver constantes.

2. Defina las columnas de características.

Cada tf.feature_column identifica un nombre de función, su tipo y cualquier preprocesamiento de entrada.

Por ejemplo, el siguiente fragmento crea tres columnas de funciones.

  • El primero usa la característica de age directamente como una entrada de punto flotante.
  • El segundo usa la característica de class como entrada categórica.
  • El tercero usa embark_town como entrada categórica, pero usa el hashing trick para evitar la necesidad de enumerar las opciones y establecer el número de opciones.

Para obtener más información, consulte el tutorial de columnas de funciones .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Cree una instancia del Estimador prefabricado relevante.

Por ejemplo, aquí hay una instanciación de muestra de un Estimador LinearClassifier llamado LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_2fgw1gd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Para obtener más información, consulte el tutorial del clasificador lineal .

4. Llame a un método de capacitación, evaluación o inferencia.

Todos los estimadores proporcionan métodos de train , evaluate y predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6098593.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:25:18Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63935s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:25:19
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.7, accuracy_baseline = 0.603125, auc = 0.70968133, auc_precision_recall = 0.6162292, average_loss = 0.6068252, global_step = 100, label/mean = 0.396875, loss = 0.6068252, precision = 0.6962025, prediction/mean = 0.3867289, recall = 0.43307087
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp_2fgw1gd/model.ckpt-100
accuracy : 0.7
accuracy_baseline : 0.603125
auc : 0.70968133
auc_precision_recall : 0.6162292
average_loss : 0.6068252
label/mean : 0.396875
loss : 0.6068252
precision : 0.6962025
prediction/mean : 0.3867289
recall : 0.43307087
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [0.6824188]
logistic : [0.6642783]
probabilities : [0.33572164 0.6642783 ]
class_ids : [1]
classes : [b'1']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Beneficios de los estimadores prefabricados

Los Estimadores prediseñados codifican las mejores prácticas y brindan los siguientes beneficios:

  • Mejores prácticas para determinar dónde deben ejecutarse las diferentes partes del gráfico computacional, implementando estrategias en una sola máquina o en un clúster.
  • Mejores prácticas para la redacción de eventos (resúmenes) y resúmenes de utilidad universal.

Si no utiliza Estimadores prediseñados, debe implementar las funciones anteriores usted mismo.

Estimadores personalizados

El corazón de cada Estimador, ya sea prefabricado o personalizado, es su función de modelo , model_fn , que es un método que crea gráficos para entrenamiento, evaluación y predicción. Cuando usa un Estimador prediseñado, alguien más ya implementó la función de modelo. Cuando confíe en un Estimador personalizado, debe escribir la función del modelo usted mismo.

Cree un estimador a partir de un modelo de Keras

Puede convertir modelos Keras existentes en Estimadores con tf.keras.estimator.model_to_estimator . Esto es útil si desea modernizar el código de su modelo, pero su canal de capacitación aún requiere Estimadores.

Cree una instancia de un modelo Keras MobileNet V2 y compile el modelo con el optimizador, la pérdida y las métricas para entrenar con:

importar tensorflow como tf importar tensorflow_datasets como tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Cree un Estimator partir del modelo compilado de Keras. El estado del modelo inicial del modelo de Keras se conserva en el Estimator creado:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpy7tafocp
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpy7tafocp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Trate el Estimator derivado como lo haría con cualquier otro Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Para entrenar, llame a la función de tren de Estimator:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteMUGW8X/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.68286127, step = 0

INFO:tensorflow:loss = 0.68286127, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.70231926.

INFO:tensorflow:Loss for final step: 0.70231926.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c4d4cf8>

De manera similar, para evaluar, llame a la función de evaluación del Estimador:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 1.92025s

INFO:tensorflow:Inference Time : 1.92025s

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

{'accuracy': 0.565625, 'loss': 0.6713216, 'global_step': 50}

Para obtener más detalles, consulte la documentación de tf.keras.estimator.model_to_estimator .

Guardar puntos de control basados ​​en objetos con Estimator

Los estimadores guardan de forma predeterminada los puntos de control con nombres de variables en lugar del gráfico de objetos que se describe en la guía de puntos de control . tf.train.Checkpoint leerá los puntos de control basados ​​en nombres, pero los nombres de las variables pueden cambiar cuando se mueven partes de un modelo fuera del model_fn del Estimador. Para compatibilidad con versiones posteriores, guardar puntos de control basados ​​en objetos hace que sea más fácil entrenar un modelo dentro de un Estimador y luego usarlo fuera de uno.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 4.5633583, step = 0

INFO:tensorflow:loss = 4.5633583, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 37.95615.

INFO:tensorflow:Loss for final step: 37.95615.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c477630>

tf.train.Checkpoint puede cargar los puntos de control del Estimador desde su model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Modelos guardados de estimadores

Los estimadores exportan modelos guardados a través de tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.35876164.

INFO:tensorflow:Loss for final step: 0.35876164.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f1a6c448b00>

Para guardar un Estimator , debe crear un serving_input_receiver . Esta función crea una parte de un tf.Graph que analiza los datos brutos recibidos por SavedModel.

El módulo tf.estimator.export contiene funciones para ayudar a construir estos receivers .

El siguiente código crea un receptor, basado en feature_columns , que acepta tf.Example protocolo tf.Example serializados, que se utilizan a menudo con tf-serve .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

También puede cargar y ejecutar ese modelo, desde python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.43590346, 0.5640965 ]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2578045]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5640965]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7984398 , 0.20156018]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.3765715]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2015602]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn permite crear funciones de entrada que toman tensores sin procesar en lugar de tf.train.Example s.

Usando tf.distribute.Strategy con Estimator (soporte limitado)

Consulte la guía de formación distribuida para obtener más información.

tf.estimator es una API de TensorFlow de entrenamiento distribuida que originalmente admitía el enfoque de servidor de parámetros asíncronos. tf.estimator ahora es compatible con tf.distribute.Strategy . Si está utilizando tf.estimator , puede cambiar al entrenamiento distribuido con muy pocos cambios en su código. Con esto, los usuarios de Estimator ahora pueden realizar entrenamiento distribuido sincrónico en múltiples GPU y múltiples trabajadores, así como también usar TPU. Sin embargo, este soporte en Estimator es limitado. Consulte la sección Qué se admite ahora a continuación para obtener más detalles.

El uso de tf.distribute.Strategy con Estimator es ligeramente diferente al caso de Keras. En lugar de utilizar strategy.scope , ahora pasamos el objeto de estrategia a RunConfig para el Estimator.

Aquí hay un fragmento de código que muestra esto con un Estimator LinearRegressor y MirroredStrategy LinearRegressor :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Not using Distribute Coordinator.

INFO:tensorflow:Not using Distribute Coordinator.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Aquí usamos un Estimador prediseñado, pero el mismo código también funciona con un Estimador personalizado. train_distribute determina cómo se distribuirá el entrenamiento y eval_distribute determina cómo se distribuirá la evaluación. Esta es otra diferencia con Keras, donde usamos la misma estrategia tanto para el entrenamiento como para la evaluación.

Ahora podemos entrenar y evaluar este Estimador con una función de entrada:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning: AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning: AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 0.23888s

INFO:tensorflow:Inference Time : 0.23888s

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Otra diferencia a destacar aquí entre Estimator y Keras es el manejo de entrada. En Keras, mencionamos que cada lote del conjunto de datos se divide automáticamente en las múltiples réplicas. En Estimator, sin embargo, no dividimos automáticamente el lote ni dividimos automáticamente los datos entre diferentes trabajadores. Tiene control total sobre cómo desea que se distribuyan sus datos entre los trabajadores y los dispositivos, y debe proporcionar un input_fn para especificar cómo distribuir sus datos.

Su input_fn se llama una vez por trabajador, lo que proporciona un conjunto de datos por trabajador. Luego, un lote de ese conjunto de datos se alimenta a una réplica en ese trabajador, consumiendo así N lotes para N réplicas en 1 trabajador. En otras palabras, el conjunto de datos devuelto por input_fn debe proporcionar lotes de tamaño PER_REPLICA_BATCH_SIZE . Y el tamaño de lote global para un paso se puede obtener como PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

Al realizar la capacitación de varios trabajadores, debe dividir sus datos entre los trabajadores o mezclarlos con una semilla aleatoria en cada uno. Puede ver un ejemplo de cómo hacer esto en la Capacitación para varios trabajadores con Estimator .

Y de manera similar, también puede usar estrategias de servidor de parámetros y de trabajadores múltiples. El código sigue siendo el mismo, pero debe usar tf.estimator.train_and_evaluate y configurar las variables de entorno TF_CONFIG para cada binario que se ejecuta en su clúster.

¿Qué es compatible ahora?

Hay soporte limitado para entrenar con Estimator usando todas las estrategias excepto TPUStrategy . La capacitación y la evaluación básicas deberían funcionar, pero algunas características avanzadas como v1.train.Scaffold no lo hacen. También puede haber una serie de errores en esta integración. En este momento, no planeamos mejorar activamente este soporte, sino que nos centramos en Keras y el soporte de bucle de entrenamiento personalizado. Si es posible, debería preferir utilizar tf.distribute con esas API.

API de entrenamiento EspejoEstrategia TPUStrategy MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
API de estimador Soporte limitado No soportado Soporte limitado Soporte limitado Soporte limitado

Ejemplos y tutoriales

A continuación, se muestran algunos ejemplos que muestran el uso integral de varias estrategias con Estimator:

  1. Capacitación de varios trabajadores con Estimator para capacitar a MNIST con varios trabajadores utilizando MultiWorkerMirroredStrategy .
  2. Ejemplo de extremo a extremo para la capacitación de varios trabajadores en tensorflow / ecosistema utilizando plantillas de Kubernetes. Este ejemplo comienza con un modelo de Keras y lo convierte en un Estimator usando la API tf.keras.estimator.model_to_estimator .
  3. Modelo oficial de ResNet50 , que se puede entrenar utilizando MirroredStrategy o MultiWorkerMirroredStrategy .