A model whose forward pass is parameterized by model weights.
tff.learning.models.FunctionalModel(
    initial_weights: ModelWeights,
    forward_pass_fn: Callable[[ModelWeights, Any, bool], model_lib.BatchOutput],
    predict_on_batch_fn: Callable[[ModelWeights, Any, bool], Any],
    input_spec
)
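
The idea of parameterizing the forward pass by weights can be illustrated without TFF. The following is a minimal NumPy sketch, not the TFF API: `predict` and `sgd_step` are hypothetical helpers for a linear model whose weights use the `(trainable, non_trainable)` layout. Because the weights are explicit arguments, the same function can be evaluated under different weights, and an update produces a new weight tuple instead of mutating model state.

```python
import numpy as np

def predict(model_weights, x):
    """Pure forward function: the output depends only on the arguments."""
    (kernel, bias), _non_trainable = model_weights
    return x @ kernel + bias

def sgd_step(model_weights, x, y, lr=0.1):
    """Returns *new* weights after one squared-error gradient step; inputs are not mutated."""
    (kernel, bias), non_trainable = model_weights
    err = predict(model_weights, x) - y            # shape (batch, 1)
    grad_kernel = 2.0 * x.T @ err / len(x)
    grad_bias = 2.0 * err.mean(axis=0)
    return ((kernel - lr * grad_kernel, bias - lr * grad_bias), non_trainable)

# Two weight tuples with the same (trainable, non_trainable) structure.
weights_a = ((np.array([[1.0], [1.0]]), np.array([0.0])), ())
weights_b = ((np.array([[2.0], [0.0]]), np.array([0.5])), ())

x = np.array([[1.0, 2.0]])
out_a = predict(weights_a, x)   # 1*1 + 2*1 + 0.0 = 3.0
out_b = predict(weights_b, x)   # 1*2 + 2*0 + 0.5 = 2.5

# The update returns a fresh tuple; weights_a itself is left unchanged.
updated = sgd_step(weights_a, x, y=np.array([[1.0]]))
```

Keeping the model stateless in this way is what allows the same forward function to be reused with whatever weights are supplied at call time.
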

Args | |
---|---|
`initial_weights` | A 2-tuple `(trainable, non_trainable)` where the two elements are sequences of weights. Weights must be values convertible to `tf.Tensor` (e.g. `numpy.ndarray`, Python sequences, etc.), but not `tf.Tensor` values. |
`forward_pass_fn` | A `tf.function`-decorated callable that takes three arguments: `model_weights`, with the same structure as `initial_weights`; `batch_input`, a nested structure of tensors matching `input_spec`; and `training`, a boolean indicating whether the call is made during a training pass (e.g. for Dropout, BatchNormalization, etc.). |
`predict_on_batch_fn` | A `tf.function`-decorated callable that takes three arguments: `model_weights`, with the same structure as `initial_weights`; `x`, the first element of `batch_input` (or `input_spec`); and `training`, a boolean indicating whether the call is made during a training pass (e.g. for Dropout, BatchNormalization, etc.). |
`input_spec` | A 2-tuple `(x, y)` where each element is a nested structure of `tf.TensorSpec` that defines the shapes and dtypes of `batch_input` passed to `forward_pass_fn`. `x` corresponds to batched model inputs and `y` corresponds to batched labels for those inputs. |

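
The sketch below shows, in plain NumPy rather than TensorFlow, how the arguments described above fit together: a `(trainable, non_trainable)` weight tuple, an `(x, y)` input specification, and a `training` flag that switches behaviour on and off. The names, the normalization-plus-dropout model, and the tuple-of-shapes stand-in for `input_spec` are all assumptions for illustration; a real `predict_on_batch_fn` would be `tf.function`-decorated and receive tensors.

```python
import numpy as np

# Hypothetical weights in the (trainable, non_trainable) layout. The non-trainable part
# holds per-feature mean and variance, similar to what BatchNormalization tracks.
initial_weights = (
    (np.array([[0.5], [0.25]]), np.array([0.0])),  # trainable: kernel, bias
    (np.array([1.0, 3.0]), np.array([1.0, 4.0])),  # non-trainable: mean, variance
)

# Shapes standing in for input_spec = (x_spec, y_spec); None marks the batch dimension.
input_spec = ((None, 2), (None, 1))

DROPOUT_RATE = 0.5

def predict_on_batch_fn(model_weights, x, training, seed=0):
    """Hypothetical predict function whose behaviour depends on `training`."""
    (kernel, bias), (mean, var) = model_weights
    h = (x - mean) / np.sqrt(var)           # normalize with the non-trainable statistics
    if training:                            # dropout is active only during training
        rng = np.random.default_rng(seed)
        keep = rng.random(h.shape) >= DROPOUT_RATE
        h = np.where(keep, h / (1.0 - DROPOUT_RATE), 0.0)  # inverted dropout
    return h @ kernel + bias

x = np.array([[3.0, 7.0], [2.0, 1.0]])
assert x.shape[1:] == input_spec[0][1:]     # feature shape matches the x spec

eval_out = predict_on_batch_fn(initial_weights, x, training=False)   # deterministic
train_out = predict_on_batch_fn(initial_weights, x, training=True)   # some units dropped
```

The `seed` parameter exists only to make this sketch reproducible; it is not part of the signature documented above.
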

Attributes | |
---|---|
`initial_weights` | |
`input_spec` | |

Methods
forward_pass
@tf.function
forward_pass(model_weights: ModelWeights, batch_input: Any, training: bool = True) -> tff.learning.BatchOutput
Runs the forward pass and returns results.
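
As a rough sketch of what a forward pass computes, the hypothetical `forward_pass` below runs a linear model on an `(x, y)` batch and returns a result record. `BatchOutputSketch` is an illustrative stand-in; its `loss`, `predictions`, and `num_examples` fields are assumed to mirror those of `tff.learning.BatchOutput`, so check the `BatchOutput` reference before relying on them.

```python
from typing import NamedTuple

import numpy as np

class BatchOutputSketch(NamedTuple):
    """Hypothetical stand-in for tff.learning.BatchOutput (field names assumed)."""
    loss: float
    predictions: np.ndarray
    num_examples: int

def forward_pass(model_weights, batch_input, training=True):
    """Linear model + mean squared error; weights are passed in, not stored."""
    (kernel, bias), _non_trainable = model_weights
    x, y = batch_input                  # same (x, y) layout as input_spec
    predictions = x @ kernel + bias
    loss = float(np.mean((predictions - y) ** 2))
    # `training` is unused: this toy model has no dropout or normalization layers.
    return BatchOutputSketch(loss=loss, predictions=predictions, num_examples=len(x))

weights = ((np.array([[1.0], [2.0]]), np.array([0.5])), ())
batch = (
    np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),  # x
    np.array([[1.0], [2.0], [3.0]]),                 # y
)
out = forward_pass(weights, batch, training=False)
```
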
predict_on_batch
@tf.function
predict_on_batch(model_weights: ModelWeights, x: Any, training: bool = True)
Returns tensor(s) interpretable by the loss function.
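
A minimal sketch of the division of labour implied by "interpretable by the loss function", assuming a hypothetical linear classifier whose outputs are logits: `predict_on_batch` returns raw scores and never computes a loss, while a separate, hand-written softmax cross-entropy (standing in for whatever loss the model is trained with) interprets those scores.

```python
import numpy as np

def predict_on_batch(model_weights, x, training=True):
    """Hypothetical classifier: returns unnormalized logits, one row per example."""
    (kernel, bias), _non_trainable = model_weights
    return x @ kernel + bias

def softmax_cross_entropy(labels, logits):
    """Mean cross-entropy that treats its input as logits (applies log-softmax itself)."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))

weights = ((np.zeros((2, 3)), np.zeros(3)), ())  # zero weights: all 3 classes equally likely
x = np.array([[1.0, 2.0], [3.0, 4.0]])
labels = np.array([0, 2])

logits = predict_on_batch(weights, x, training=False)
loss = softmax_cross_entropy(labels, logits)     # equals log(3) for uniform predictions
```

If the predictions were probabilities instead of logits, this loss would need to change accordingly; the output format of the predict function and the loss have to agree.
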