A `PolicySaver` allows you to save a `tf_policy.Policy` to a SavedModel.
```python
tf_agents.policies.PolicySaver(
    policy: tf_agents.policies.TFPolicy,
    batch_size: Optional[int] = None,
    use_nest_path_signatures: bool = True,
    seed: Optional[types.Seed] = None,
    train_step: Optional[tf.Variable] = None,
    input_fn_and_spec: Optional[tf_agents.policies.policy_saver.InputFnAndSpecType] = None,
    metadata: Optional[Dict[Text, tf.Variable]] = None
)
```
The `save()` method exports a saved model to the requested export location. The exported SavedModel can be loaded via `tf.compat.v2.saved_model.load` (or `tf.saved_model.load` in TF2). The following signatures (concrete functions) are available: `action`, `get_initial_state`, and `get_train_step`.

The attribute `model_variables` is also available on the loaded saved_model, giving access to the model variables so they can be updated in place if needed.
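As a minimal sketch of such an in-place update (the variables below are stand-ins; with a real export you would load `policy_0` and iterate over its `model_variables`, and the trained values here are hypothetical placeholders):

```python
import tensorflow as tf

# Stand-ins: with a real SavedModel these would come from
# loaded = tf.saved_model.load('policy_0'); loaded.model_variables
loaded_vars = [tf.Variable([0.0, 0.0]), tf.Variable(0.0)]

# Hypothetical freshly trained values to push into the loaded policy.
trained_values = [[1.0, 2.0], 3.0]

# assign() updates each loaded variable in place -- no reload needed.
for var, value in zip(loaded_vars, trained_values):
    var.assign(value)
```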
Usage:

```python
my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
    agent.train(...)
    if i % 100 == 0:
        saver.save('policy_%d' % global_step)
```
To load and use the saved policy directly:

```python
saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...
while True:
    policy_step = saved_policy.action(time_step, policy_state)
    policy_state = policy_step.state
    time_step = f(policy_step.action)
    ...
```
Or, to use the distributional form, e.g.:

```python
batch_size = 3
saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=batch_size)
time_step = ...
while True:
    policy_step = saved_policy.distribution(time_step, policy_state)
    policy_state = policy_step.state
    time_step = f(policy_step.action.sample(batch_size))
    ...
```
If using the flattened (signature) version, you will be limited to using dicts keyed by the specs' name fields:

```python
saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
while True:
    time_step_state = dict(time_step_dict)
    time_step_state.update(policy_state_dict)
    policy_step_dict = action_fn(time_step_state)
    policy_state_dict = extract_policy_state_fields(policy_step_dict)
    action_dict = extract_action_fields(policy_step_dict)
    time_step_dict = f(action_dict)
    ...
```
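Note that `extract_policy_state_fields` and `extract_action_fields` above are not part of the API; they stand for user code that splits the flat output dict back into state and action entries. A plain-Python sketch, under the hypothetical assumption that the spec names use `action` / `policy_state` prefixes (the real keys depend on your policy's specs):

```python
def _fields_with_prefix(output_dict, prefix):
    # Keep only entries whose spec name starts with the given prefix.
    return {k: v for k, v in output_dict.items() if k.startswith(prefix)}

def extract_policy_state_fields(policy_step_dict):
    return _fields_with_prefix(policy_step_dict, 'policy_state')

def extract_action_fields(policy_step_dict):
    return _fields_with_prefix(policy_step_dict, 'action')

# Illustration with placeholder values in place of tensors:
step = {'action': 1, 'policy_state_0': 2, 'policy_state_1': 3}
print(extract_action_fields(step))        # {'action': 1}
print(extract_policy_state_fields(step))  # {'policy_state_0': 2, 'policy_state_1': 3}
```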
Methods
get_metadata

```python
get_metadata() -> Dict[Text, tf.Variable]
```

Returns the metadata of the policy.

| Returns |
|---|
| A dictionary of `tf.Variable`. |
get_train_step

```python
get_train_step() -> tf_agents.typing.types.Int
```

Returns the train step of the policy.

| Returns |
|---|
| An integer. |
register_concrete_function

```python
register_concrete_function(
    name: str, fn: def_function.Function, assets: Optional[Any] = None
) -> None
```

Registers a function into the saved model.

This gives you the flexibility to register any kind of polymorphic function by creating the concrete function that you wish to register.

| Args | |
|---|---|
| `name` | Name of the attribute to use for the saved fn. |
| `fn` | Function to register. Must be a callable taking the input_spec as its single parameter. |
| `assets` | Any extra checkpoint dependencies that must be captured in the module. Note: variables are captured automatically. |
register_function

```python
register_function(
    name: str,
    fn: tf_agents.policies.policy_saver.InputFnType,
    input_spec: tf_agents.typing.types.NestedTensorSpec,
    outer_dims: tf_agents.typing.types.ShapeSequence = (None,)
) -> None
```

Registers a function into the saved model.

| Args | |
|---|---|
| `name` | Name of the attribute to use for the saved fn. |
| `fn` | Function to register. Must be a callable taking the input_spec as its single parameter. |
| `input_spec` | A nest of `tf.TypeSpec` representing the time_steps. Provided by the user. |
| `outer_dims` | The outer dimensions the saved fn will process at a time. By default, a batch dimension is added to the input_spec. |
save

```python
save(
    export_dir: Text, options: Optional[tf.saved_model.SaveOptions] = None
)
```

Saves the policy to the given `export_dir`.

| Args | |
|---|---|
| `export_dir` | Directory to save the policy to. |
| `options` | Optional `tf.saved_model.SaveOptions` object. |
save_checkpoint

```python
save_checkpoint(
    export_dir: Text, options: Optional[tf.train.CheckpointOptions] = None
)
```

Saves the policy as a checkpoint to the given `export_dir`.

This will only work with checkpoints generated in TF2.x.

For the checkpoint to be useful, users should first call `save` to generate a saved_model of the policy. Checkpoints can then be used to update the policy without having to reload the saved_model or save multiple copies of the `saved_model.pb` file.

The checkpoint is always created in the sub-directory `variables/`, and the checkpoint file prefix used is `variables`. The checkpoint files are as follows:

- `export_dir/variables/variables.index`
- `export_dir/variables/variables.data-xxxxx-of-xxxxx`

This makes the files compatible with the checkpoint part of full saved models, which enables you to load a saved model made up of the graph part of a full saved model and the variables part of a checkpoint.
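As a sketch of that layout compatibility (plain file operations on placeholder directories, not part of the PolicySaver API): the graph exported by one `save()` call can be combined with the `variables/` sub-directory produced by a later `save_checkpoint()`:

```python
import pathlib
import shutil
import tempfile

# Placeholder directories standing in for real save()/save_checkpoint() output.
root = pathlib.Path(tempfile.mkdtemp())
saved_model_dir = root / 'policy_0'      # would come from saver.save(...)
checkpoint_dir = root / 'checkpoint_1'   # would come from saver.save_checkpoint(...)

(saved_model_dir / 'variables').mkdir(parents=True)
(saved_model_dir / 'saved_model.pb').write_bytes(b'graph')
(checkpoint_dir / 'variables').mkdir(parents=True)
(checkpoint_dir / 'variables' / 'variables.index').write_bytes(b'index')

# Combined model: graph from the saved model, weights from the checkpoint.
combined_dir = root / 'policy_combined'
combined_dir.mkdir()
shutil.copy(saved_model_dir / 'saved_model.pb', combined_dir / 'saved_model.pb')
shutil.copytree(checkpoint_dir / 'variables', combined_dir / 'variables')

print(sorted(p.name for p in combined_dir.rglob('*')))
# -> ['saved_model.pb', 'variables', 'variables.index']
shutil.rmtree(root)
```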
| Args | |
|---|---|
| `export_dir` | Directory to save the checkpoint to. |
| `options` | Optional `tf.train.CheckpointOptions` object. |