tf_agents.train.triggers.PolicySavedModelTrigger
View source on GitHub: https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/train/triggers.py#L40-L198
Triggers saves of policy checkpoints for an agent's policy.
Inherits From: IntervalTrigger
tf_agents.train.triggers.PolicySavedModelTrigger(
    saved_model_dir: Text,
    agent: tf_agents.agents.TFAgent,
    train_step: tf.Variable,
    interval: int,
    async_saving: bool = False,
    metadata_metrics: Optional[Mapping[Text, py_metric.PyMetric]] = None,
    start: int = 0,
    extra_concrete_functions: Optional[Sequence[Tuple[str, policy_saver.def_function.Function]]] = None,
    batch_size: Optional[int] = None,
    use_nest_path_signatures: bool = True,
    save_greedy_policy=True,
    save_collect_policy=True,
    input_fn_and_spec: Optional[tf_agents.policies.policy_saver.InputFnAndSpecType] = None
)
Used in the notebooks
- SAC minitaur with the Actor-Learner API (https://www.tensorflow.org/agents/tutorials/7_SAC_minitaur_tutorial)
On construction this trigger generates a saved_model for a greedy_policy, a collect_policy, and a raw_policy. When triggered, a checkpoint is saved which can be used to update any of the saved_model policies.
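As a point of reference, construction typically looks like the minimal sketch below. Here agent and root_dir are hypothetical placeholders (an initialized TFAgent and a writable directory); the train_utils helper and learner constant follow the pattern used in the Actor-Learner tutorials.

import os

from tf_agents.train import learner, triggers
from tf_agents.train.utils import train_utils

# Placeholders: `agent` is an initialized TFAgent built with this same
# train_step counter, and `root_dir` is a writable directory.
train_step = train_utils.create_train_step()
saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)

# Construction itself writes the greedy, collect, and raw policy
# saved_models under `saved_model_dir`.
save_trigger = triggers.PolicySavedModelTrigger(
    saved_model_dir, agent, train_step, interval=10000)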
Args

saved_model_dir: Base dir where checkpoints will be saved.
agent: Agent to extract policies from.
train_step: tf.Variable which keeps track of the number of train steps.
interval: How often, in train steps, the trigger will save. Note that the event fires as soon as at least interval steps have passed since the last trigger; the current value is not necessarily interval steps away from the last triggered value.
async_saving: If True, saving is done asynchronously in a separate thread. Note that when this is enabled, the variable values in the saved checkpoints/models are not deterministic.
metadata_metrics: A dictionary of metrics whose result() method returns a scalar to be saved along with the policy. Currently only supported when async_saving is False.
start: Initial value for the trigger, passed directly to the base class. It controls from which train step the weights of the model are saved.
extra_concrete_functions: Optional sequence of extra concrete functions to register in the policy savers. The sequence should consist of tuples of a string name for the function and the tf.function to register. Note this does not support adding extra assets.
batch_size: The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a Python integer.
use_nest_path_signatures: If True, SavedModel spec signatures are created based on the structure of the specs. Otherwise all specs must have unique names.
save_greedy_policy: Disable when the agent's policy distribution method does not support mode().
save_collect_policy: Disable when not saving the collect policy.
input_fn_and_spec: An (input_fn, tensor_spec) tuple, where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple used as input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes (time_step, policy_state) as input.
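For orientation, the snippet below shows one way to reload a policy written by this trigger. It assumes the subdirectory layout implied by the tf_agents.train.learner constants (e.g. a greedy_policy subdirectory under saved_model_dir); treat the exact paths as an assumption of this sketch.

import os

from tf_agents.policies import py_tf_eager_policy
from tf_agents.train import learner

# `saved_model_dir` is the directory that was passed to the trigger.
# Assumed layout: one SavedModel subdirectory per policy, plus periodic
# variable checkpoints.
policy_dir = os.path.join(saved_model_dir, learner.GREEDY_POLICY_SAVED_MODEL_DIR)

# Load the SavedModel as an eager-mode Python policy, recovering the
# time_step/action specs from the pbtxt files saved alongside the model.
eval_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, load_specs_from_pbtxt=True)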
Methods
reset
View source: https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/train/interval_trigger.py#L67-L69
reset() -> None
Resets the trigger interval.
set_start
View source: https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/train/interval_trigger.py#L71-L72
set_start(start: int) -> None
__call__
View source: https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/train/interval_trigger.py#L50-L65
__call__(value: int, force_trigger: bool = False) -> None
Maybe trigger the event based on the interval.
Args

value: The value for triggering.
force_trigger: If True, the trigger is forced to fire unless the last trigger value is equal to value.
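In a hand-rolled training loop (the Learner invokes its triggers for you), the call looks roughly like this sketch, where train_one_iteration is a hypothetical stand-in that runs agent.train() and increments train_step.

# Hypothetical training loop using the `save_trigger` constructed above.
for _ in range(num_iterations):
    train_one_iteration()  # stand-in: runs agent.train(), bumps train_step
    # Fires a checkpoint save once >= `interval` steps have passed since
    # the last save; otherwise a no-op.
    save_trigger(train_step.numpy())

# Force a final save regardless of the interval (skipped only if the last
# triggered value already equals the current step).
save_trigger(train_step.numpy(), force_trigger=True)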