Trains models and prints debug info.
```python
tfl.test_utils.run_training_loop(
    config,
    training_data,
    keras_model,
    input_dtype=np.float32,
    label_dtype=np.float32
)
```
| Args | |
| --- | --- |
| `config` | Dictionary of test case parameters. See the tests for TensorFlow Lattice layers for examples. |
| `training_data` | Tuple `(training_inputs, labels)`, where `training_inputs` and `labels` are data suitable for training the model passed via `keras_model`. |
| `keras_model` | Keras model to train on `training_data`. |
| `input_dtype` | dtype to which the training inputs are converted (default `np.float32`). |
| `label_dtype` | dtype to which the labels are converted (default `np.float32`). |
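A minimal sketch of preparing the arguments, assuming the dtype arguments are applied like `np.asarray(x).astype(dtype)`. The toy data and the `"num_training_epoch"` config key are hypothetical; the keys a real test needs depend on the layer under test.

```python
import numpy as np

# Hypothetical toy data: 8 examples with 2 features each, as plain lists.
training_inputs = [[i / 7.0, 1.0 - i / 7.0] for i in range(8)]
labels = [[2.0 * a + b] for a, b in training_inputs]

# `training_data` is a (training_inputs, labels) tuple.
training_data = (training_inputs, labels)

# Assumed conversion controlled by `input_dtype` / `label_dtype`;
# np.float32 matches the documented defaults.
inputs_np = np.asarray(training_data[0]).astype(np.float32)
labels_np = np.asarray(training_data[1]).astype(np.float32)

print(inputs_np.dtype, inputs_np.shape)  # float32 (8, 2)
print(labels_np.dtype, labels_np.shape)  # float32 (8, 1)

# Hypothetical config: only a training-length entry is shown here.
config = {"num_training_epoch": 50}
```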
| Returns |
| --- |
| The loss measured on the training data, and the `tf.Session` object if one was initialized explicitly during training. |
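The function body is not reproduced on this page, so the following is only a rough mental model under stated assumptions: the data are converted with `np.asarray(...).astype(dtype)`, training runs through the model's Keras-style `fit` method in a few chunks of epochs, and the "debug info" is the loss logged after each chunk. `run_training_loop_sketch`, `MeanModel`, `History`, and the `"num_training_epoch"` config key are hypothetical names. A one-parameter model that predicts a constant stands in for a Keras model so the sketch runs with NumPy alone; like Keras `Model.fit`, its `fit` returns an object whose `history["loss"]` lists per-epoch losses.

```python
import logging

import numpy as np


class History:
    """Hypothetical stand-in for the Keras History object returned by fit():
    `history` maps metric names to per-epoch values."""

    def __init__(self, losses):
        self.history = {"loss": losses}


class MeanModel:
    """Hypothetical stand-in for a Keras model: predicts one constant for
    every example and fits it by full-batch gradient descent on MSE."""

    def __init__(self, learning_rate=0.5):
        self.bias = 0.0
        self.learning_rate = learning_rate

    def fit(self, x, y, batch_size=None, epochs=1, verbose=0):
        losses = []
        for _ in range(epochs):
            # d/db mean((b - y)^2) = 2 * mean(b - y)
            grad = 2.0 * float(np.mean(self.bias - y))
            self.bias -= self.learning_rate * grad
            losses.append(float(np.mean((self.bias - y) ** 2)))
        return History(losses)


def run_training_loop_sketch(config, training_data, model,
                             input_dtype=np.float32, label_dtype=np.float32):
    """Hypothetical re-creation: convert the data, train in chunks of epochs,
    log the loss after each chunk, and return the final training loss."""
    training_inputs, labels = training_data
    x = np.asarray(training_inputs).astype(input_dtype)
    y = np.asarray(labels).astype(label_dtype)

    total_epochs = config["num_training_epoch"]  # assumed config key
    num_chunks = 10
    loss = None
    logging.info("%-10s%-10s", "epoch", "loss")
    for chunk in range(num_chunks):
        begin = total_epochs * chunk // num_chunks
        end = total_epochs * (chunk + 1) // num_chunks
        if end == begin:
            continue  # fewer epochs than chunks: skip empty chunks
        history = model.fit(x, y, batch_size=len(x), epochs=end - begin, verbose=0)
        loss = history.history["loss"][-1]
        logging.info("%-10d%-10.6f", end, loss)  # the "debug info"
    return loss


logging.basicConfig(level=logging.INFO)
loss = run_training_loop_sketch(
    {"num_training_epoch": 20},
    ([[0.0], [1.0], [2.0], [3.0]], [[1.0], [2.0], [3.0], [4.0]]),
    MeanModel(),
)
print(loss)  # 1.25: the variance of the labels, the best a constant can do
```

With this learning rate, the constant predictor reaches the label mean after one step, so the returned loss equals the variance of the labels. Splitting training into chunks is what lets intermediate losses be logged without a custom callback.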