Converts a tf.keras.Model to a tff.learning.models.FunctionalModel.
tff.learning.models.functional_model_from_keras(
    keras_model: Union[tf.keras.Model, Callable[[], tf.keras.Model]],
    loss_fn: tf.keras.losses.Loss,
    input_spec: Union[Sequence[Any], Mapping[str, Any]]
) -> tff.learning.models.FunctionalModel
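Conceptually, the conversion turns a model that owns its weights into one whose weights are passed in explicitly, loosely mirroring the initial_weights and predict_on_batch members of tff.learning.models.FunctionalModel. The plain-Python sketch below illustrates only that idea; StatefulLinear, functional_predict, and to_functional are hypothetical names, not TFF or Keras APIs, and the real conversion operates on TensorFlow graphs rather than Python lists.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical weights for a one-feature linear model: (kernel, bias).
Weights = Tuple[float, float]


class StatefulLinear:
    """Keras-like style: the model object owns its weights as mutable attributes."""

    def __init__(self, kernel: float, bias: float) -> None:
        self.kernel = kernel
        self.bias = bias

    def __call__(self, xs: Sequence[float]) -> List[float]:
        return [self.kernel * x + self.bias for x in xs]


def functional_predict(weights: Weights, xs: Sequence[float]) -> List[float]:
    """Functional style: weights are an explicit argument and nothing is mutated."""
    kernel, bias = weights
    return [kernel * x + bias for x in xs]


def to_functional(
    model: StatefulLinear,
) -> Tuple[Weights, Callable[[Weights, Sequence[float]], List[float]]]:
    """Snapshot the model's current weights and pair them with the pure predict function."""
    initial_weights = (model.kernel, model.bias)
    return initial_weights, functional_predict
```

For example, after initial_weights, predict = to_functional(StatefulLinear(2.0, 1.0)), the call predict(initial_weights, [0.0, 1.0]) returns [1.0, 3.0], and the same function can be evaluated with any other weight tuple without touching the original model object.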
Note: This method only supports models for which calling the model with training=True and with training=False produces the same graph. Keras layers such as batch normalization will fail, because they must update internal state when training=True, which is not supported.
Important: The returned model must only be used in a graph context (for example, inside a tff.tf_computation decorated callable). It will raise an error otherwise.
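The batch normalization restriction in the Note comes from the side effect on the training=True path. The framework-free sketch below uses a hypothetical ToyBatchNorm class, a simplified 1-D batch normalization without learned scale or offset that is not Keras's implementation, to show the difference: with training=False it only reads its moving statistics, while with training=True it also overwrites them, so the two modes cannot share a single side-effect-free graph.

```python
from typing import List


class ToyBatchNorm:
    """Simplified 1-D batch normalization (no learned scale or offset)."""

    def __init__(self, momentum: float = 0.9, epsilon: float = 1e-3) -> None:
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0  # internal state
        self.moving_var = 1.0   # internal state

    def __call__(self, xs: List[float], training: bool) -> List[float]:
        if training:
            # Normalize with the statistics of the current batch ...
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            # ... and update the stored statistics: a side effect that only
            # exists on the training=True path.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Inference only reads the stored statistics and changes nothing.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / (var + self.epsilon) ** 0.5 for x in xs]
```

Calling bn([1.0, 3.0], training=False) leaves bn.moving_mean at 0.0, while the same call with training=True moves it toward the batch mean of 2.0 (to roughly 0.2 with momentum=0.9). A layer such as Dense computes the same function in both modes, which is why it does not hit this restriction.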
Args
keras_model
A tf.keras.Model object, or a no-argument callable that returns one. The model should be uncompiled; if it is compiled, its metrics, optimizer, and loss function are ignored. Note: for models with multiple outputs, all outputs are passed to loss_fn.
loss_fn
A tf.keras.losses.Loss object.
input_spec
A structure of tf.TensorSpec defining the input to the model.
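Because a model with multiple outputs sends all of them to loss_fn, the loss has to reduce them together into one scalar. The plain-Python sketch below assumes a hypothetical two-output model whose predictions arrive as a pair in output order, with labels structured the same way; the function names and the 0.5 weighting are illustrative only. In real code this logic would live in the call(y_true, y_pred) method of a tf.keras.losses.Loss subclass operating on tensors.

```python
from typing import Sequence, Tuple


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error for a single output."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def combined_two_output_loss(
    y_true: Tuple[Sequence[float], Sequence[float]],
    y_pred: Tuple[Sequence[float], Sequence[float]],
    second_output_weight: float = 0.5,
) -> float:
    """Receives *all* model outputs at once and reduces them to a single scalar."""
    true_a, true_b = y_true
    pred_a, pred_b = y_pred
    return mean_squared_error(true_a, pred_a) + second_output_weight * mean_squared_error(
        true_b, pred_b
    )
```

For instance, combined_two_output_loss(([1.0, 2.0], [0.0]), ([1.0, 4.0], [2.0])) is 2.0 + 0.5 * 4.0 = 4.0. Weighting the outputs inside one loss is a design choice; any reduction works as long as it accepts every output.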
Raises
KerasFunctionalModelError
If the model has a batch normalization layer.
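Since conversion fails for models containing batch normalization, a caller may want to check before converting. The duck-typed sketch below is an assumption, not part of TFF: iter_layers and has_layer_of_type are hypothetical helpers, and they assume nested models expose their children through a .layers attribute, as Keras models do.

```python
from typing import Any, Iterator


def iter_layers(model: Any) -> Iterator[Any]:
    """Yield every layer reachable through `.layers`, descending into nested models."""
    for layer in getattr(model, "layers", ()):
        yield layer
        # Nested models (e.g. a Sequential used as a layer) expose their own `.layers`.
        yield from iter_layers(layer)


def has_layer_of_type(model: Any, layer_type: type) -> bool:
    """Return True if any reachable layer is an instance of `layer_type`."""
    return any(isinstance(layer, layer_type) for layer in iter_layers(model))
```

With a Keras model, the check would be has_layer_of_type(model, tf.keras.layers.BatchNormalization). Layers hidden inside custom layers that do not expose .layers are not visited, so a False result does not guarantee that conversion will succeed.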