A `PolicySaver` allows you to save a `tf_policy.Policy` to a SavedModel.
tf_agents.policies.policy_saver.PolicySaver(
    policy: tf_agents.policies.tf_policy.TFPolicy,
    batch_size: Optional[int] = None,
    use_nest_path_signatures: bool = True,
    seed: Optional[types.Seed] = None,
    train_step: Optional[tf.Variable] = None,
    input_fn_and_spec: Optional[tf_agents.policies.policy_saver.InputFnAndSpecType] = None,
    metadata: Optional[Dict[Text, tf.Variable]] = None
)
The `save()` method exports a saved model to the requested export location. The exported SavedModel can be loaded via `tf.compat.v2.saved_model.load` (or `tf.saved_model.load` in TF2). It will have the following available signatures (concrete functions): `action`, `get_initial_state`, `get_train_step`.

The attribute `model_variables` is also available when the saved_model is loaded, giving access to the model variables so they can be updated if needed.
Usage:

my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
  agent.train(...)
  if i % 100 == 0:
    saver.save('policy_%d' % global_step)
To load and use the saved policy directly:

saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  time_step = f(policy_step.action)
  ...
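The way policy state threads through that loop can be illustrated without TensorFlow. This is a minimal sketch using a hypothetical `PolicyStep` namedtuple and a toy counting policy; the names mimic the TF-Agents structures but are stand-ins, not the real API:

```python
from collections import namedtuple

# Illustrative stand-in for the TF-Agents PolicyStep structure.
PolicyStep = namedtuple('PolicyStep', ['action', 'state', 'info'])

def toy_action(time_step, policy_state):
    """A toy stateful policy: the action is the running count of calls."""
    new_state = policy_state + 1
    return PolicyStep(action=new_state, state=new_state, info=())

policy_state = 0          # analogue of get_initial_state(batch_size=...)
time_step = 'reset'
actions = []
for _ in range(3):        # bounded here; the docs use `while True`
    policy_step = toy_action(time_step, policy_state)
    policy_state = policy_step.state         # thread the state forward
    time_step = 'obs_%d' % policy_step.action  # environment-step analogue
    actions.append(policy_step.action)

print(actions)  # [1, 2, 3]
```

The key point is that `policy_step.state` must be fed back into the next `action` call; dropping it resets a recurrent policy every step.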
or to use the distributional form, e.g.:

batch_size = 3
saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=batch_size)
time_step = ...
while True:
  policy_step = saved_policy.distribution(time_step, policy_state)
  policy_state = policy_step.state
  time_step = f(policy_step.action.sample(batch_size))
  ...
If using the flattened (signature) version, you will be limited to using dicts keyed by the specs' name fields.

saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
while True:
  time_step_state = dict(time_step_dict)
  time_step_state.update(policy_state_dict)
  policy_step_dict = action_fn(time_step_state)
  policy_state_dict = extract_policy_state_fields(policy_step_dict)
  action_dict = extract_action_fields(policy_step_dict)
  time_step_dict = f(action_dict)
  ...
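The `extract_policy_state_fields` and `extract_action_fields` helpers above are pseudocode. One TF-free way to realize the pattern is to filter the flat output dict by key prefix; this sketch assumes state fields carry a `policy_state/` prefix and action fields an `action/` prefix, which is an assumption for illustration, not something the exporter guarantees:

```python
def extract_policy_state_fields(policy_step_dict):
    # Keep only entries whose key marks them as recurrent policy state.
    return {k: v for k, v in policy_step_dict.items()
            if k.startswith('policy_state/')}

def extract_action_fields(policy_step_dict):
    # Keep only entries whose key marks them as action outputs.
    return {k: v for k, v in policy_step_dict.items()
            if k.startswith('action/')}

# Merge time_step fields with policy_state fields, as in the loop above.
time_step_dict = {'observation': [0.1, 0.2], 'step_type': 0}
policy_state_dict = {'policy_state/lstm_c': [0.0], 'policy_state/lstm_h': [0.0]}
time_step_state = dict(time_step_dict)
time_step_state.update(policy_state_dict)

# Pretend this flat dict came back from action_fn(**time_step_state):
policy_step_dict = {'action/move': 1,
                    'policy_state/lstm_c': [0.3],
                    'policy_state/lstm_h': [0.4]}

print(extract_action_fields(policy_step_dict))  # {'action/move': 1}
```

In practice the actual key names come from the specs' `name` fields, which is exactly why the flattened form requires unique (or nest-path-derived) spec names.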
Args | |
---|---|
`policy` | A TF Policy. |
`batch_size` | The number of batch entries the policy will process at a time. This must be either `None` (unknown batch size) or a python integer. |
`use_nest_path_signatures` | SavedModel spec signatures will be created based on the structure of the specs. Otherwise all specs must have unique names. |
`seed` | Random seed for the `policy.action` call, if any (this should usually be `None`, except for testing). |
`train_step` | Variable holding the train step for the policy. The value saved will be set at the time `saver.save` is called. If not provided, `train_step` defaults to -1. Note that since the train step must be a variable, it is not safe to create it directly in TF1, so in that case this is a required parameter. |
`input_fn_and_spec` | An `(input_fn, tensor_spec)` tuple where `input_fn` is a function that takes inputs according to `tensor_spec` and converts them to the `(time_step, policy_state)` tuple that is used as the input to the action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input for the action signature. When `input_fn_and_spec` is `None`, the action signature takes as input `(time_step, policy_state)`. |
`metadata` | A dictionary of `tf.Variable`s to be saved along with the policy. |
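To make the `input_fn_and_spec` contract concrete, here is a TF-free sketch: an `input_fn` that adapts a raw observation into the `(time_step, policy_state)` pair the action function expects. The field names and the toy `action_fn` are illustrative assumptions, not the real signatures:

```python
def input_fn(raw_observation):
    """Adapt a raw input into the (time_step, policy_state) pair."""
    time_step = {'step_type': 0, 'observation': raw_observation}
    policy_state = ()  # stateless policy in this sketch
    return time_step, policy_state

def action_fn(time_step, policy_state):
    # Toy action: echo the first observation element.
    return {'action': time_step['observation'][0]}

time_step, policy_state = input_fn([7, 8, 9])
print(action_fn(time_step, policy_state))  # {'action': 7}
```

With `input_fn_and_spec` set, callers of the exported action signature supply only the raw input described by `tensor_spec`; the adapter performs the conversion inside the SavedModel.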
Raises | |
---|---|
`TypeError` | If `policy` is not an instance of TFPolicy. |
`TypeError` | If `metadata` is not a dictionary of `tf.Variable`s. |
`ValueError` | If `use_nest_path_signatures` is not used and any of the following policy specs are missing names, or the names collide: `policy.time_step_spec`, `policy.action_spec`, `policy.policy_state_spec`, `policy.info_spec`. |
`ValueError` | If `batch_size` is not either `None` or a python integer > 0. |
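The `batch_size` condition above can be sketched as a small validator. This is a hypothetical helper mirroring the documented `ValueError`, not code from the library:

```python
def validate_batch_size(batch_size):
    """Accept None (unknown batch size) or a positive python integer."""
    if batch_size is None:
        return None
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(batch_size, int) and not isinstance(batch_size, bool) \
            and batch_size > 0:
        return batch_size
    raise ValueError(
        'batch_size must be None or a python integer > 0, got %r' % batch_size)

print(validate_batch_size(None))  # None
print(validate_batch_size(3))     # 3
```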
Attributes | |
---|---|
`action_input_spec` | Tuple `(time_step_spec, policy_state_spec)` for feeding `action`. This describes the input of `action` in the SavedModel. This may differ from the original policy if `use_nest_path_signatures` was enabled. |
`policy_state_spec` | Spec that describes the output of `get_initial_state` in the SavedModel. This may differ from the original policy if `use_nest_path_signatures` was enabled. |
`policy_step_spec` | Spec that describes the output of `action` in the SavedModel. This may differ from the original policy if `use_nest_path_signatures` was enabled. |
`signatures` | Get the (flat) signatures used when exporting the SavedModel. |
Methods

get_metadata

get_metadata() -> Dict[Text, tf.Variable]

Returns the metadata of the policy.

Returns | |
---|---|
A dictionary of `tf.Variable`s. |
get_train_step
get_train_step() -> tf_agents.typing.types.Int
Returns the train step of the policy.
Returns | |
---|---|
An integer. |
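As noted in the constructor args, the train step defaults to `-1` when no `train_step` variable was provided. A minimal TF-free sketch of that behavior, with a hypothetical helper and `tf.Variable` replaced by a plain int:

```python
def resolve_train_step(train_step=None):
    """Return the provided train step, or -1 when none was given."""
    return -1 if train_step is None else train_step

print(resolve_train_step())     # -1
print(resolve_train_step(500))  # 500
```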
save

save(
    export_dir: Text,
    options: Optional[tf.saved_model.SaveOptions] = None
)

Save the policy to the given `export_dir`.

Args | |
---|---|
`export_dir` | Directory to save the policy to. |
`options` | Optional `tf.saved_model.SaveOptions` object. |
save_checkpoint

save_checkpoint(
    export_dir: Text,
    options: Optional[tf.train.CheckpointOptions] = None
)

Saves the policy as a checkpoint to the given `export_dir`.

This will only work with checkpoints generated in TF2.x.

For the checkpoint to be useful, users should first call `save` to generate a saved_model of the policy. Checkpoints can then be used to update the policy without having to reload the saved_model or save multiple copies of the `saved_model.pb` file.

The checkpoint is always created in the sub-directory 'variables/' and the checkpoint file prefix used is 'variables'. The checkpoint files are as follows:
- export_dir/variables/variables.index
- export_dir/variables/variables.data-xxxxx-of-xxxxx

This makes the files compatible with the checkpoint part of full saved models, which enables you to load a saved model made up of the graph part of a full saved model and the variables part of a checkpoint.
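The file layout described above can be computed with standard path joins. This sketch (TF-free; the helper name is hypothetical, the directory and prefix names are from the text) shows where the checkpoint prefix and index file land under `export_dir`:

```python
import os

def checkpoint_prefix(export_dir):
    """Checkpoints go under 'variables/' with file prefix 'variables'."""
    return os.path.join(export_dir, 'variables', 'variables')

prefix = checkpoint_prefix('policy_0')
print(prefix)             # e.g. 'policy_0/variables/variables' on POSIX
print(prefix + '.index')  # the checkpoint index file path
```

Because this matches the `variables/` layout of a full SavedModel, restoring such a checkpoint into an exported policy only swaps variable values, leaving the graph in `saved_model.pb` untouched.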
Args | |
---|---|
`export_dir` | Directory to save the checkpoint to. |
`options` | Optional `tf.train.CheckpointOptions` object. |