Canned (or Premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For those canned estimators that do not have built-in TensorFlow 2 substitutes, you can still build your own replacement fairly easily.
This guide walks through a few examples of direct equivalents and custom substitutions to demonstrate how TensorFlow 1's tf.estimator-derived models can be migrated to TensorFlow 2 with Keras.
Namely, this guide includes examples for migrating:
- From tf.estimator's LinearEstimator, Classifier or Regressor in TensorFlow 1 to Keras tf.compat.v1.keras.models.LinearModel in TensorFlow 2
- From tf.estimator's DNNEstimator, Classifier or Regressor in TensorFlow 1 to a custom Keras DNN model in TensorFlow 2
- From tf.estimator's DNNLinearCombinedEstimator, Classifier or Regressor in TensorFlow 1 to tf.compat.v1.keras.models.WideDeepModel in TensorFlow 2
- From tf.estimator's BoostedTreesEstimator, Classifier or Regressor in TensorFlow 1 to tfdf.keras.GradientBoostedTreesModel in TensorFlow 2
A common precursor to the training of a model is feature preprocessing, which is done for TensorFlow 1 Estimator models with tf.feature_column. For more information on feature preprocessing in TensorFlow 2, see this guide on migrating feature columns.
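For instance, a minimal sketch of replacing a numeric feature column with a Keras preprocessing layer (assuming TensorFlow 2.6+, where tf.keras.layers.Normalization is available; the adapt values are illustrative, not the real Titanic column):
import tensorflow as tf
# TF1: tf1.feature_column.numeric_column('age') handed to an Estimator.
# TF2: a Normalization layer adapted to the training data plays a similar role.
age_normalizer = tf.keras.layers.Normalization(axis=None)
age_normalizer.adapt([22.0, 38.0, 26.0, 35.0])  # illustrative 'age' values
model = tf.keras.Sequential([age_normalizer, tf.keras.layers.Dense(1)])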
Setup
Start by installing TensorFlow Decision Forests and adding a few necessary TensorFlow imports,
pip install tensorflow_decision_forests
import keras
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_decision_forests as tfdf
prepare some simple data for demonstration from the standard Titanic dataset,
x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
x_train['sex'].replace(('male', 'female'), (0, 1), inplace=True)
x_eval['sex'].replace(('male', 'female'), (0, 1), inplace=True)
x_train['alone'].replace(('n', 'y'), (0, 1), inplace=True)
x_eval['alone'].replace(('n', 'y'), (0, 1), inplace=True)
x_train['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True)
x_eval['class'].replace(('First', 'Second', 'Third'), (1, 2, 3), inplace=True)
x_train.drop(['embark_town', 'deck'], axis=1, inplace=True)
x_eval.drop(['embark_town', 'deck'], axis=1, inplace=True)
y_train = x_train.pop('survived')
y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator`
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)
FEATURE_NAMES = [
    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'
]

feature_columns = []
for fn in FEATURE_NAMES:
  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)
  feature_columns.append(feat_col)
and create a function to instantiate a simple sample optimizer to use with the various TensorFlow 1 Estimator and TensorFlow 2 Keras models below.
def create_sample_optimizer(tf_version):
  if tf_version == 'tf1':
    # The TF1 branch returns a callable: canned Estimators invoke it inside
    # their model_fn, where `tf1.train.get_global_step()` can be resolved.
    optimizer = lambda: tf.keras.optimizers.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf1.train.exponential_decay(
            learning_rate=0.1,
            global_step=tf1.train.get_global_step(),
            decay_steps=10000,
            decay_rate=0.9))
  elif tf_version == 'tf2':
    optimizer = tf.keras.optimizers.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))
  return optimizer
Example 1: Migrating from LinearEstimator
TF1: Using LinearEstimator
In TensorFlow 1, you can use tf.estimator.LinearEstimator to create a baseline linear model for regression and classification problems.
linear_estimator = tf.estimator.LinearEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    optimizer=create_sample_optimizer('tf1'))
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpowoz1d3v INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpowoz1d3v', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
linear_estimator.train(input_fn=_input_fn, steps=100)
linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/ftrl.py:153: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpowoz1d3v/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-0.meta INFO:tensorflow:300 INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-0.index INFO:tensorflow:300 INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-0.data-00000-of-00001 INFO:tensorflow:300 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20... INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpowoz1d3v/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-20.index INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-20.data-00000-of-00001 INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpowoz1d3v/model.ckpt-20.meta INFO:tensorflow:300 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20... INFO:tensorflow:Loss for final step: 0.552688. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-06-08T01:32:43 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpowoz1d3v/model.ckpt-20 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Inference Time : 0.51412s INFO:tensorflow:Finished evaluation at 2022-06-08-01:32:43 INFO:tensorflow:Saving dict for global step 20: accuracy = 0.70075756, accuracy_baseline = 0.625, auc = 0.75472915, auc_precision_recall = 0.65362054, average_loss = 0.5759378, global_step = 20, label/mean = 0.375, loss = 0.5704811, precision = 0.6388889, prediction/mean = 0.41331065, recall = 0.46464646 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpowoz1d3v/model.ckpt-20 {'accuracy': 0.70075756, 'accuracy_baseline': 0.625, 'auc': 0.75472915, 'auc_precision_recall': 0.65362054, 'average_loss': 0.5759378, 'label/mean': 0.375, 'loss': 0.5704811, 'precision': 0.6388889, 'prediction/mean': 0.41331065, 'recall': 0.46464646, 'global_step': 20}
TF2: Using Keras LinearModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.LinearModel, which substitutes for tf.estimator.LinearEstimator. The tf.compat.v1.keras path is used to signify that the pre-made model exists for compatibility.
linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10)
linear_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 0s 2ms/step - loss: 6.3420 - accuracy: 0.6443 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2180 - accuracy: 0.6555 Epoch 3/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2066 - accuracy: 0.6762 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2055 - accuracy: 0.6826 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1916 - accuracy: 0.6970 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1827 - accuracy: 0.7337 Epoch 7/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1755 - accuracy: 0.7544 Epoch 8/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1716 - accuracy: 0.7671 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1674 - accuracy: 0.7847 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1667 - accuracy: 0.8022 9/9 [==============================] - 0s 2ms/step - loss: 0.1929 - accuracy: 0.7576 {'loss': 0.19286951422691345, 'accuracy': 0.7575757503509521}
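One behavioral difference worth noting, as a hedged usage sketch: Estimator.predict consumes an input_fn and returns a generator, while a trained Keras model predicts directly on in-memory data such as the evaluation DataFrame above:
# Predict directly on the pandas DataFrame; no input_fn is required.
predictions = linear_model.predict(x_eval)
print(predictions[:5])  # raw model outputs for the first five passengers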
Example 2: Migrating from DNNEstimator
TF1: Using DNNEstimator
In TensorFlow 1, you can use tf.estimator.DNNEstimator to create a baseline DNN model for regression and classification problems.
dnn_estimator = tf.estimator.DNNEstimator(
    head=tf.estimator.BinaryClassHead(),
    feature_columns=feature_columns,
    hidden_units=[128],
    activation_fn=tf.nn.relu,
    optimizer=create_sample_optimizer('tf1'))
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpeo0dgnz2 INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpeo0dgnz2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
dnn_estimator.train(input_fn=_input_fn, steps=100)
dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 2022-06-08 01:32:45.153411: W tensorflow/core/common_runtime/forward_type_inference.cc:231] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpeo0dgnz2/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-0.meta INFO:tensorflow:200 INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-0.index INFO:tensorflow:200 INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-0.data-00000-of-00001 INFO:tensorflow:200 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.579392, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20... INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpeo0dgnz2/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-20.index INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-20.data-00000-of-00001 INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpeo0dgnz2/model.ckpt-20.meta INFO:tensorflow:200 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20... INFO:tensorflow:Loss for final step: 0.56485385. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-06-08T01:32:46 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpeo0dgnz2/model.ckpt-20 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Inference Time : 0.46785s INFO:tensorflow:Finished evaluation at 2022-06-08-01:32:47 INFO:tensorflow:Saving dict for global step 20: accuracy = 0.6818182, accuracy_baseline = 0.625, auc = 0.7034282, auc_precision_recall = 0.61065406, average_loss = 0.60924685, global_step = 20, label/mean = 0.375, loss = 0.60097164, precision = 0.5862069, prediction/mean = 0.40602148, recall = 0.5151515 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpeo0dgnz2/model.ckpt-20 {'accuracy': 0.6818182, 'accuracy_baseline': 0.625, 'auc': 0.7034282, 'auc_precision_recall': 0.61065406, 'average_loss': 0.60924685, 'label/mean': 0.375, 'loss': 0.60097164, 'precision': 0.5862069, 'prediction/mean': 0.40602148, 'recall': 0.5151515, 'global_step': 20}
TF2: Using Keras to Create a Custom DNN Model
In TensorFlow 2, you can create a custom DNN model to substitute for one generated by tf.estimator.DNNEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).
A similar workflow can be used to replace tf.estimator.experimental.RNNEstimator with a Keras RNN model. Keras provides a number of built-in, customizable choices by way of tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU; see here for more details.
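A hedged sketch of such an RNN substitute (the sequence shape is illustrative and unrelated to the Titanic data used elsewhere in this guide):
# A minimal Keras RNN model standing in for RNNEstimator on sequence data.
rnn_model = tf.keras.Sequential([
    tf.keras.layers.LSTM(64, input_shape=(None, 8)),  # (timesteps, features)
    tf.keras.layers.Dense(1)
])
rnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'))
Returning to the DNN migration: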
dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10)
dnn_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 0s 2ms/step - loss: 350.4538 - accuracy: 0.4848 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.5715 - accuracy: 0.5375 Epoch 3/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2917 - accuracy: 0.6252 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2502 - accuracy: 0.6411 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2560 - accuracy: 0.6619 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2157 - accuracy: 0.6746 Epoch 7/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2144 - accuracy: 0.6874 Epoch 8/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1924 - accuracy: 0.7193 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1905 - accuracy: 0.7368 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1793 - accuracy: 0.7528 9/9 [==============================] - 0s 2ms/step - loss: 0.1921 - accuracy: 0.7348 {'loss': 0.19207113981246948, 'accuracy': 0.7348484992980957}
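When migrating, you may also need a replacement for Estimator.export_saved_model. A minimal sketch, assuming the default Keras SavedModel format and an illustrative path:
dnn_model.save('saved_dnn')  # writes a SavedModel directory
restored_dnn = tf.keras.models.load_model('saved_dnn')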
Example 3: Migrating from DNNLinearCombinedEstimator
TF1: Using DNNLinearCombinedEstimator
In TensorFlow 1, you can use tf.estimator.DNNLinearCombinedEstimator to create a baseline combined model for regression and classification problems, with customization capacity for both its linear and DNN components.
optimizer = create_sample_optimizer('tf1')

combined_estimator = tf.estimator.DNNLinearCombinedEstimator(
    head=tf.estimator.BinaryClassHead(),
    # Wide settings
    linear_feature_columns=feature_columns,
    linear_optimizer=optimizer,
    # Deep settings
    dnn_feature_columns=feature_columns,
    dnn_hidden_units=[128],
    dnn_optimizer=optimizer)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpc_ft_xly INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpc_ft_xly', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
combined_estimator.train(input_fn=_input_fn, steps=100)
combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpc_ft_xly/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-0.meta INFO:tensorflow:400 INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-0.index INFO:tensorflow:400 INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-0.data-00000-of-00001 INFO:tensorflow:400 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.4982232, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 20... INFO:tensorflow:Saving checkpoints for 20 into /tmpfs/tmp/tmpc_ft_xly/model.ckpt. INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-20.index INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-20.data-00000-of-00001 INFO:tensorflow:0 INFO:tensorflow:/tmpfs/tmp/tmpc_ft_xly/model.ckpt-20.meta INFO:tensorflow:400 INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 20... INFO:tensorflow:Loss for final step: 0.5787579. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-06-08T01:32:51 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpc_ft_xly/model.ckpt-20 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Inference Time : 0.53733s INFO:tensorflow:Finished evaluation at 2022-06-08-01:32:51 INFO:tensorflow:Saving dict for global step 20: accuracy = 0.7083333, accuracy_baseline = 0.625, auc = 0.752158, auc_precision_recall = 0.6428529, average_loss = 0.5928443, global_step = 20, label/mean = 0.375, loss = 0.5804225, precision = 0.6666667, prediction/mean = 0.36947888, recall = 0.44444445 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmpfs/tmp/tmpc_ft_xly/model.ckpt-20 {'accuracy': 0.7083333, 'accuracy_baseline': 0.625, 'auc': 0.752158, 'auc_precision_recall': 0.6428529, 'average_loss': 0.5928443, 'label/mean': 0.375, 'loss': 0.5804225, 'precision': 0.6666667, 'prediction/mean': 0.36947888, 'recall': 0.44444445, 'global_step': 20}
TF2: Using Keras WideDeepModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.models.WideDeepModel to substitute for one generated by tf.estimator.DNNLinearCombinedEstimator, with similar levels of user-specified customization (for instance, as in the previous example, the ability to customize a chosen model optimizer).
This WideDeepModel is constructed on the basis of a constituent LinearModel and a custom DNN model, both of which are discussed in the preceding two examples. A custom linear model can also be used in place of the built-in Keras LinearModel if desired, as sketched below.
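A hedged sketch of such a custom linear model: a single Dense unit with no activation, which computes the same affine transform as the built-in LinearModel:
custom_linear_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])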
If you would like to build your own model instead of using a canned estimator, check out how to build a keras.Sequential model. For more information on custom training and optimizers, you can also check out this guide.
# Create LinearModel and DNN Model as in Examples 1 and 2
optimizer = create_sample_optimizer('tf2')

linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10, verbose=0)

dnn_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(128, activation='relu'),
     tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

combined_model = tf.compat.v1.keras.experimental.WideDeepModel(
    linear_model, dnn_model)
combined_model.compile(
    optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])
combined_model.fit([x_train, x_train], y_train, epochs=10)
combined_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 0s 2ms/step - loss: 609.6476 - accuracy: 0.5694 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.6053 - accuracy: 0.6188 Epoch 3/10 20/20 [==============================] - 0s 2ms/step - loss: 0.3424 - accuracy: 0.6699 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2523 - accuracy: 0.6794 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1985 - accuracy: 0.7018 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2015 - accuracy: 0.7225 Epoch 7/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1756 - accuracy: 0.7305 Epoch 8/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1663 - accuracy: 0.7544 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1603 - accuracy: 0.7640 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1645 - accuracy: 0.7879 9/9 [==============================] - 0s 2ms/step - loss: 0.1766 - accuracy: 0.7197 {'loss': 0.17664867639541626, 'accuracy': 0.7196969985961914}
Example 4: Migrating from BoostedTreesEstimator
TF1: Using BoostedTreesEstimator
In TensorFlow 1, you could use tf.estimator.BoostedTreesEstimator to create a baseline Gradient Boosting model using an ensemble of decision trees for regression and classification problems. This functionality is no longer included in TensorFlow 2.
bt_estimator = tf1.estimator.BoostedTreesEstimator(
    head=tf.estimator.BinaryClassHead(),
    n_batches_per_layer=1,
    max_depth=10,
    n_trees=1000,
    feature_columns=feature_columns)

bt_estimator.train(input_fn=_input_fn, steps=1000)
bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)
TF2: Using TensorFlow Decision Forests
In TensorFlow 2, tf.estimator.BoostedTreesEstimator is replaced by tfdf.keras.GradientBoostedTreesModel from the TensorFlow Decision Forests package.
TensorFlow Decision Forests provides various advantages over tf.estimator.BoostedTreesEstimator, notably regarding quality, speed, ease of use, and flexibility. To learn about TensorFlow Decision Forests, start with the beginner colab.
The following example shows how to train a Gradient Boosted Trees model using TensorFlow 2:
Install TensorFlow Decision Forests.
pip install tensorflow_decision_forests
Create a TensorFlow dataset. Note that Decision Forests natively support many types of features and do not need pre-processing.
train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Convert the Pandas Dataframes into TensorFlow datasets.
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label="survived")
eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label="survived")
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:2542: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only. features_dataframe = dataframe.drop(label, 1) /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_decision_forests/keras/core.py:2542: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only. features_dataframe = dataframe.drop(label, 1)
Train the model on the train_dataset.
# Use the default hyper-parameters of the model.
gbt_model = tfdf.keras.GradientBoostedTreesModel()
gbt_model.fit(train_dataset)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmpgg84bam8 as temporary training directory Reading training dataset... Training dataset read in 0:00:02.959102. Found 627 examples. Training model... Model trained in 0:00:00.197456 Compiling model... [INFO kernel.cc:1176] Loading model from path /tmpfs/tmp/tmpgg84bam8/model/ with prefix 20efacc67dbb4ba5 [INFO abstract_model.cc:1246] Engine "GradientBoostedTreesQuickScorerExtended" built [INFO kernel.cc:1022] Use fast generic engine WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fbddbc85790> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: could not get source code To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fbddbc85790> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: could not get source code To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fbddbc85790> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: could not get source code To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert Model compiled. <keras.callbacks.History at 0x7fbd04072a00>
Evaluate the quality of the model on the eval_dataset.
gbt_model.compile(metrics=['accuracy'])
gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True)
print(gbt_evaluation)
1/1 [==============================] - 0s 268ms/step - loss: 0.0000e+00 - accuracy: 0.8295 {'loss': 0.0, 'accuracy': 0.8295454382896423}
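If you want to move beyond the defaults, hyper-parameters can be passed to the model constructor. A minimal sketch (num_trees and max_depth are GradientBoostedTreesModel constructor arguments; the values here are illustrative, not tuned):
tuned_gbt_model = tfdf.keras.GradientBoostedTreesModel(num_trees=300, max_depth=6)
tuned_gbt_model.fit(train_dataset)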
Gradient Boosted Trees is just one of the many decision forest algorithms available in TensorFlow Decision Forests. For example, Random Forests (available as tfdf.keras.RandomForestModel) is very resistant to overfitting, while CART (available as tfdf.keras.CartModel) is great for model interpretation.
In the next example, we train and evaluate a Random Forest model.
# Train a Random Forest model
rf_model = tfdf.keras.RandomForestModel()
rf_model.fit(train_dataset)
# Evaluate the Random Forest model
rf_model.compile(metrics=['accuracy'])
rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True)
print(rf_evaluation)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmpbvdmlbv0 as temporary training directory Reading training dataset... Training dataset read in 0:00:00.168655. Found 627 examples. Training model... Model trained in 0:00:00.157815 Compiling model... [INFO kernel.cc:1176] Loading model from path /tmpfs/tmp/tmpbvdmlbv0/model/ with prefix 9dee2fe9a8754340 [INFO kernel.cc:1022] Use fast generic engine Model compiled. 1/1 [==============================] - 0s 117ms/step - loss: 0.0000e+00 - accuracy: 0.8333 {'loss': 0.0, 'accuracy': 0.8333333134651184}
Finally, in the next example, we train and plot a CART model.
# Train a CART model
cart_model = tfdf.keras.CartModel()
cart_model.fit(train_dataset)
# Plot the CART model
tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmp9t2wkic1 as temporary training directory Reading training dataset... Training dataset read in 0:00:00.284037. Found 627 examples. Training model... Model trained in 0:00:00.008466 Compiling model... Model compiled. [INFO kernel.cc:1176] Loading model from path /tmpfs/tmp/tmp9t2wkic1/model/ with prefix 011284f1818c47a2 [INFO kernel.cc:1022] Use fast generic engine
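As a final note, trained TensorFlow Decision Forests models can also be inspected programmatically. A minimal sketch using make_inspector on the Random Forest model trained above:
inspector = rf_model.make_inspector()
print(inspector.variable_importances())  # which features the forest relies on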