![]() |
![]() |
Constructs an Estimator
instance from given keras model.
Aliases:
tf.keras.estimator.model_to_estimator(
keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None,
checkpoint_format='saver'
)
For usage example, please see: Creating estimators from Keras Models.
Sample Weights
Estimators returned by model_to_estimator
are configured to handle sample
weights (similar to keras_model.fit(x, y, sample_weights)
). To pass sample
weights when training or evaluating the Estimator, the first item returned by
the input function should be a dictionary with keys features
and
sample_weights
. Example below:
keras_model = tf.keras.Model(...)
keras_model.compile(...)
estimator = tf.keras.estimator.model_to_estimator(keras_model)
def input_fn():
return dataset_ops.Dataset.from_tensors(
({'features': features, 'sample_weights': sample_weights},
targets))
estimator.train(input_fn, steps=1)
Args:
keras_model
: A compiled Keras model object. This argument is mutually exclusive withkeras_model_path
.keras_model_path
: Path to a compiled Keras model saved on disk, in HDF5 format, which can be generated with thesave()
method of a Keras model. This argument is mutually exclusive withkeras_model
.custom_objects
: Dictionary for custom objects.model_dir
: Directory to saveEstimator
model parameters, graph, summary files for TensorBoard, etc.config
:RunConfig
to configEstimator
.checkpoint_format
: Sets the format of the checkpoint saved by the estimator when training. May besaver
orcheckpoint
, depending on whether to save checkpoints fromtf.train.Saver
ortf.train.Checkpoint
. This argument currently defaults tosaver
. When 2.0 is released, the default will becheckpoint
. Estimators use name-basedtf.train.Saver
checkpoints, while Keras models use object-based checkpoints fromtf.train.Checkpoint
. Currently, saving object-based checkpoints frommodel_to_estimator
is only supported by Functional and Sequential models.
Returns:
An Estimator from given keras model.
Raises:
ValueError
: if neither keras_model nor keras_model_path was given.ValueError
: if both keras_model and keras_model_path was given.ValueError
: if the keras_model_path is a GCS URI.ValueError
: if keras_model has not been compiled.ValueError
: if an invalid checkpoint_format was given.