Converts a tf.keras.Model to a tff.learning.models.FunctionalModel.
tff.learning.models.functional_model_from_keras(
keras_model: Union[tf.keras.Model, Callable[[], tf.keras.Model]],
loss_fn: tf.keras.losses.Loss,
input_spec: Union[Sequence[Any], Mapping[str, Any]],
metrics_constructor: Optional[Union[keras_utils.MetricConstructor, keras_utils.MetricsConstructor, keras_utils.MetricConstructors]] = None
) -> tff.learning.models.FunctionalModel
Note: This method only supports models where calling the model with training=True and training=False produces the same graph. Keras layers such as batch normalization will fail, because they update internal state when training=True, which is not supported.

Important: The returned model must only be used in a graph context (for example, inside a tff.tf_computation decorated callable). It raises an error otherwise.
Raises
KerasFunctionalModelError: if the model has a batch normalization layer.
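Because a batch normalization layer causes the conversion to fail, it can help to check a model before converting it. The helper below, contains_batch_normalization, is hypothetical and not part of TFF; it is a minimal sketch that walks model.layers recursively through nested tf.keras.Model objects. Sublayers hidden inside custom (non-Model) layers are not inspected, so a False result is not a guarantee that the conversion will succeed.

```python
import tensorflow as tf


def contains_batch_normalization(model: tf.keras.Model) -> bool:
    """Hypothetical pre-check: True if `model` or a nested Keras model
    contains a tf.keras.layers.BatchNormalization layer."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            return True
        # Recurse into nested models (e.g. a Sequential used as a layer).
        if isinstance(layer, tf.keras.Model) and contains_batch_normalization(layer):
            return True
    return False


# Example models (illustrative only).
plain_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
bn_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4), tf.keras.layers.BatchNormalization()])
nested_bn_model = tf.keras.Sequential([
    tf.keras.Sequential([tf.keras.layers.BatchNormalization()]),
    tf.keras.layers.Dense(1),
])
```

Running such a check before calling tff.learning.models.functional_model_from_keras surfaces the problem earlier, with a message the caller controls.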