tf.contrib.learn.train

View source on GitHub:
https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/contrib/learn/python/learn/graph_actions.py#L128-L235

Train a model. (deprecated)

    tf.contrib.learn.train(
        graph, output_dir, train_op, loss_op, global_step_tensor=None, init_op=None,
        init_feed_dict=None, init_fn=None, log_every_steps=10, supervisor_is_chief=True,
        supervisor_master='', supervisor_save_model_secs=600, keep_checkpoint_max=5,
        supervisor_save_summaries_steps=100, feed_fn=None, steps=None,
        fail_on_nan_loss=True, monitors=None, max_steps=None
    )

Warning: THIS FUNCTION IS DEPRECATED. It will be removed after 2017-02-15.
Instructions for updating: graph_actions.py will be deleted. Use tf.train.*
utilities instead. You can use learn/estimators/estimator.py as an example.

Given graph, a directory to write outputs to (output_dir), and some ops,
run a training loop. The given train_op performs one step of training on the
model and is expected to increment global_step_tensor, a scalar integer
tensor counting training steps. The loss_op represents the objective
function of the training. This function uses Supervisor to initialize the
graph (from a checkpoint if one is available in output_dir), write summaries
defined in the graph, and write regular checkpoints as defined by
supervisor_save_model_secs.
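A minimal call sketch, assuming a TensorFlow 1.x environment where tf.contrib
is still available; build_and_train and the tiny linear model inside it are
hypothetical illustrations, not part of the API:

```python
def build_and_train(output_dir):
    """Sketch only: build a small graph and hand it to tf.contrib.learn.train.

    Requires TensorFlow 1.x with contrib; returns the final loss value.
    """
    import tensorflow as tf  # TF 1.x assumed

    graph = tf.Graph()
    with graph.as_default():
        global_step = tf.train.get_or_create_global_step()
        x = tf.placeholder(tf.float32, [None, 1])
        y = tf.placeholder(tf.float32, [None, 1])
        w = tf.Variable(tf.zeros([1, 1]))
        loss_op = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
        # minimize() with global_step increments the step counter,
        # which the training loop relies on to decide when to stop.
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss_op, global_step=global_step)

    return tf.contrib.learn.train(
        graph=graph,
        output_dir=output_dir,
        train_op=train_op,
        loss_op=loss_op,
        max_steps=100,
        feed_fn=lambda: {x: [[1.0]], y: [[2.0]]})
```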
Training continues until global_step_tensor evaluates to max_steps, or, if
fail_on_nan_loss is set, until loss_op evaluates to NaN. In the latter case
the program is terminated with exit code 1.
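The stopping and NaN-handling behavior can be sketched in plain Python. This
is a conceptual model only; training_loop, run_train_step, and the local
NanLossDuringTrainingError class are stand-ins, not the TensorFlow
implementation:

```python
import math


class NanLossDuringTrainingError(RuntimeError):
    """Stand-in for the error raised when the loss becomes NaN."""


def training_loop(run_train_step, max_steps=None, fail_on_nan_loss=True):
    # run_train_step() stands in for one session.run over (train_op, loss_op);
    # it returns the incremented global step and the current loss value.
    step, loss = 0, None
    while max_steps is None or step < max_steps:
        step, loss = run_train_step()
        if math.isnan(loss):
            if fail_on_nan_loss:
                raise NanLossDuringTrainingError("NaN loss at step %d" % step)
            # else: continue training as if nothing happened
    return loss  # the final loss value, as train() returns
```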
Args
graph
A graph to train. It is expected that this graph is not in use
elsewhere.
output_dir
A directory to write outputs to.
train_op
An op that performs one training step when run.
loss_op
A scalar loss tensor.
global_step_tensor
A tensor representing the global step. If none is given,
one is extracted from the graph using the same logic as in Supervisor.
init_op
An op that initializes the graph. If None, use Supervisor's
default.
init_feed_dict
A dictionary that maps Tensor objects to feed values.
This feed dictionary will be used when init_op is evaluated.
init_fn
Optional callable passed to Supervisor to initialize the model.
log_every_steps
Output logs every log_every_steps steps when training. The logs contain
timing data and the current loss.
supervisor_is_chief
Whether the current process is the chief supervisor in
charge of restoring the model and running standard services.
supervisor_master
The master string to use when preparing the session.
supervisor_save_model_secs
Save a checkpoint every
supervisor_save_model_secs seconds when training.
keep_checkpoint_max
The maximum number of recent checkpoint files to
keep. As new files are created, older files are deleted. If None or 0,
all checkpoint files are kept. This is simply passed as the max_to_keep
arg to the tf.compat.v1.train.Saver constructor.
supervisor_save_summaries_steps
Save summaries every supervisor_save_summaries_steps training steps.
feed_fn
A function that is called every iteration to produce a feed_dict
passed to session.run calls. Optional.
steps
Trains for this many steps (i.e. until the global step reaches its current
value plus steps).
fail_on_nan_loss
If True, raise NanLossDuringTrainingError if loss_op
evaluates to NaN. If False, continue training as if nothing happened.
monitors
List of BaseMonitor subclass instances. Used for callbacks
inside the training loop.
max_steps
Number of total steps for which to train the model. If None, train forever.
Two calls to fit(steps=100) mean 200 training iterations. By contrast, with
two calls to fit(max_steps=100), the second call performs no iterations,
since the first call already did all 100 steps.
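The keep_checkpoint_max retention policy above can be sketched in plain
Python; prune_checkpoints is a hypothetical illustration of what the Saver's
max_to_keep setting does, not its implementation:

```python
def prune_checkpoints(checkpoint_paths, max_to_keep):
    """Keep only the max_to_keep most recent checkpoints.

    Older files are deleted as new ones are created; None or 0 keeps all.
    checkpoint_paths is assumed ordered oldest to newest.
    """
    if not max_to_keep:  # None or 0: keep everything
        return list(checkpoint_paths)
    return list(checkpoint_paths)[-max_to_keep:]
```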
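The feed_fn contract above is just a zero-argument callable that returns a
fresh feed_dict on every iteration. A minimal batching sketch, where
make_feed_fn and the "x"/"y" keys are hypothetical (in real use the keys
would be placeholder tensors from the graph):

```python
def make_feed_fn(xs, ys, batch_size):
    """Return a feed_fn that cycles through (xs, ys) in fixed-size batches."""
    i = 0

    def feed_fn():
        nonlocal i
        # Wrap around so training can run for more steps than len(xs).
        batch_x = [xs[(i + k) % len(xs)] for k in range(batch_size)]
        batch_y = [ys[(i + k) % len(ys)] for k in range(batch_size)]
        i = (i + batch_size) % len(xs)
        return {"x": batch_x, "y": batch_y}

    return feed_fn
```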
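The steps vs. max_steps distinction above can be sketched as follows;
resolve_target_step is a hypothetical helper, not part of the API, and the
ValueError mirrors the rule that the two arguments are mutually exclusive:

```python
def resolve_target_step(current_global_step, steps=None, max_steps=None):
    """Turn the steps/max_steps arguments into an absolute stopping step."""
    if steps is not None and max_steps is not None:
        raise ValueError("Can not provide both steps and max_steps.")
    if steps is not None:
        return current_global_step + steps  # relative: always trains more
    return max_steps  # absolute: may already be reached (None = forever)
```

Two fit(steps=100) calls target steps 100 then 200 (200 iterations total);
two fit(max_steps=100) calls both target step 100, so the second is a no-op.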
Returns
The final loss value.
Raises
ValueError
If output_dir, train_op, loss_op, or global_step_tensor
is not provided. See tf.contrib.framework.get_global_step for how we
look up the latter if not provided explicitly.
ValueError
If both steps and max_steps are not None.
NanLossDuringTrainingError
If fail_on_nan_loss is True, and loss ever
evaluates to NaN.
Last updated 2020-10-01 UTC.