
tff.learning.models.FunctionalModel

A model whose forward pass is parameterized by model weights that are passed in as an argument.

Args

initial_weights A 2-tuple (trainable, non_trainable) whose two elements are sequences of weights. Weights must be values convertible to tf.Tensor (e.g. numpy.ndarray, Python sequences, etc.), but not tf.Tensor values themselves.
forward_pass_fn A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; batch_input, a nested structure of tensors matching input_spec; and training, a boolean determining whether the call is part of a training pass (relevant for layers such as Dropout and BatchNormalization).
predict_on_batch_fn A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; x, the first element of batch_input (matching the first element of input_spec); and training, a boolean determining whether the call is part of a training pass (relevant for layers such as Dropout and BatchNormalization).
input_spec A 2-tuple (x, y) where each element is a nested structure of tf.TensorSpec defining the shapes and dtypes of the batch_input passed to forward_pass_fn. x corresponds to the batched model inputs and y to the batched labels for those inputs.
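To make the expected structures concrete, here is a framework-free sketch of the constructor arguments for a one-feature linear regression model. It uses NumPy in place of TensorFlow, so the callables are ordinary Python functions and input_spec is approximated by hypothetical (shape, dtype) pairs with a hypothetical matches_spec check; a real FunctionalModel expects tf.function-decorated callables and tf.TensorSpec values. The dictionary returned by forward_pass_fn (keys loss, predictions, num_examples) is an illustrative assumption loosely modeled on a batch-output structure, not the exact TFF return type.

```python
import numpy as np

# initial_weights: a 2-tuple (trainable, non_trainable). Each element is a
# sequence of array-like values (NumPy arrays here), not tf.Tensor values.
initial_weights = (
    [np.zeros((1, 1), dtype=np.float32),  # kernel
     np.zeros((1,), dtype=np.float32)],   # bias
    [],                                   # no non-trainable weights
)

# Hypothetical stand-in for input_spec: (shape, dtype) pairs for x and y,
# where None marks the variable batch dimension. TFF expects tf.TensorSpec.
input_spec = (((None, 1), np.float32), ((None, 1), np.float32))


def matches_spec(value, spec):
    """Hypothetical check: same dtype, same rank, equal non-None dimensions."""
    shape, dtype = spec
    return (value.dtype == dtype and value.ndim == len(shape)
            and all(d is None or d == s for d, s in zip(shape, value.shape)))


def predict_on_batch_fn(model_weights, x, training):
    """Maps batched inputs x to predictions using only the given weights."""
    trainable, _ = model_weights
    kernel, bias = trainable
    del training  # A linear model behaves the same in training and inference.
    return x @ kernel + bias


def forward_pass_fn(model_weights, batch_input, training):
    """Runs predict_on_batch_fn on x and computes a mean squared error on y."""
    x, y = batch_input
    predictions = predict_on_batch_fn(model_weights, x, training)
    loss = float(np.mean((predictions - y) ** 2))
    return {"loss": loss, "predictions": predictions, "num_examples": len(x)}


x = np.array([[1.0], [2.0]], dtype=np.float32)
y = np.array([[2.0], [4.0]], dtype=np.float32)
assert matches_spec(x, input_spec[0]) and matches_spec(y, input_spec[1])
out = forward_pass_fn(initial_weights, (x, y), training=True)
```

Moving this to TensorFlow would mean decorating both callables with @tf.function and replacing the (shape, dtype) pairs with values such as tf.TensorSpec(shape=[None, 1], dtype=tf.float32).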

Attributes

initial_weights

input_spec

Methods

forward_pass


Runs the forward pass (forward_pass_fn) on a batch of input with the given model weights and returns its results.
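Because the weights are an explicit argument rather than state stored inside the model, the same model object can be evaluated at any set of weights, for example before and after an update computed elsewhere. The sketch below mimics that calling pattern with NumPy; ToyFunctionalModel and mse_forward_pass are hypothetical names, not TFF classes or functions, and the training=True default is an assumption of this sketch.

```python
import numpy as np


class ToyFunctionalModel:
    """Hypothetical stand-in that stores functions and initial weights, not variables."""

    def __init__(self, initial_weights, forward_pass_fn):
        self.initial_weights = initial_weights
        self._forward_pass_fn = forward_pass_fn

    def forward_pass(self, model_weights, batch_input, training=True):
        # Delegates to the user-supplied function; nothing is read from or
        # written to the model object except the stored callable.
        return self._forward_pass_fn(model_weights, batch_input, training)


def mse_forward_pass(model_weights, batch_input, training):
    """Linear model y_hat = x * w + b, returning a mean-squared-error loss."""
    (w, b), _ = model_weights  # (trainable, non_trainable)
    x, y = batch_input
    predictions = x * w + b
    return float(np.mean((predictions - y) ** 2))


model = ToyFunctionalModel(
    initial_weights=((np.float32(0.0), np.float32(0.0)), ()),
    forward_pass_fn=mse_forward_pass,
)
batch = (np.array([1.0, 2.0, 3.0]), np.array([2.0, 4.0, 6.0]))

# Evaluate the same model object at two different sets of weights.
loss_initial = model.forward_pass(model.initial_weights, batch)
loss_updated = model.forward_pass(((np.float32(2.0), np.float32(0.0)), ()), batch)
```

Keeping the model stateless in this way is what lets the weights be treated as plain values that can be sent, averaged, or replaced independently of the model code.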

predict_on_batch


Runs predict_on_batch_fn on a batch of model inputs and returns tensor(s) interpretable by the loss function.
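One practical reading of "interpretable by the loss function" is that the output format of predict_on_batch has to agree with whatever loss the model uses. A common pairing is a classifier that returns raw logits together with a loss that expects logits rather than probabilities. The NumPy sketch below illustrates that pairing; the function names and the choice of logits are assumptions for illustration, not part of the TFF API.

```python
import numpy as np


def predict_on_batch_fn(model_weights, x, training):
    """Hypothetical logistic-regression predictor that returns raw logits."""
    (w, b), _ = model_weights  # (trainable, non_trainable)
    del training  # No Dropout or BatchNormalization in this toy model.
    return x @ w + b


def sigmoid_cross_entropy_from_logits(labels, logits):
    """Mean binary cross-entropy computed directly from logits."""
    # max(z, 0) - z * y + log(1 + exp(-|z|)) is the numerically stable form
    # of -y * log(sigmoid(z)) - (1 - y) * log(1 - sigmoid(z)).
    per_example = (np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return float(np.mean(per_example))


weights = ((np.array([[1.0], [-1.0]]), np.array([0.0])), ())
x = np.array([[2.0, 0.0], [0.0, 2.0]])
y = np.array([[1.0], [0.0]])

logits = predict_on_batch_fn(weights, x, training=False)  # shape (2, 1)
loss = sigmoid_cross_entropy_from_logits(y, logits)
```

If predict_on_batch instead returned probabilities, this loss would silently compute the wrong value, which is why the two have to be designed together.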