tf_agents.policies.py_tf_policy.PyTFPolicy

Exposes a Python policy as a wrapper over a TF Policy.

Inherits From: PyPolicy, SessionUser

Args
policy A TF Policy implementing tf_policy.TFPolicy.
batch_size (deprecated)
seed Seed to use if policy performs random actions (optional).

Attributes
action_spec Describes the ArraySpecs of the np.Array returned by action(). The action can be a single np.Array, or a nested dict, list or tuple of np.Array.
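To make the nested case concrete, the following sketch flattens a nest of dicts, lists and tuples into its leaves, visiting dict keys in sorted order as tf.nest does. It illustrates the nest structure only and is not the tf.nest implementation; the helper flatten_nest is hypothetical, and plain numbers stand in for np.Array leaves so the example runs without NumPy.

```python
from collections.abc import Mapping


def flatten_nest(nest):
    """Hypothetical helper: flatten a nest of dicts, lists and tuples.

    Dict entries are visited in sorted key order; anything that is not a
    dict, list or tuple is treated as a leaf.
    """
    if isinstance(nest, Mapping):
        leaves = []
        for key in sorted(nest):
            leaves.extend(flatten_nest(nest[key]))
        return leaves
    if isinstance(nest, (list, tuple)):
        leaves = []
        for item in nest:
            leaves.extend(flatten_nest(item))
        return leaves
    return [nest]  # a single leaf (an np.Array in the real API)


# A nested action: the "buttons" entry is visited before "steer".
print(flatten_nest({"steer": 0.1, "buttons": (1, 0)}))
```

A single (non-nested) action flattens to a one-element list, so code written against the flattened form handles both cases uniformly.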

collect_data_spec Describes the data collected when using this policy with an environment.
info_spec Describes the Arrays emitted as info by action().
observation_and_action_constraint_splitter

policy_state_spec Describes the arrays expected by functions with policy_state as input.
policy_step_spec Describes the output of action().
session Returns the TensorFlow session-like object used by this object.
time_step_spec Describes the TimeStep np.Arrays expected by action(time_step).
trajectory_spec Describes the data collected when using this policy with an environment.

Methods

action

Generates next action given the time_step and policy_state.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state An optional previous policy_state.
seed Seed to use if action uses sampling (optional).

Returns
A PolicyStep named tuple containing:
action: A nest of action Arrays matching the action_spec().
state: A nest of policy states to be fed into the next call to action().
info: Optional side information such as action log probabilities.
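The shape of the returned value can be illustrated without TensorFlow. The sketch below is a hypothetical, dependency-free stand-in: PolicyStep here is a plain collections.namedtuple carrying the three documented fields, and ToyPolicy (not part of tf_agents) samples a uniform random discrete action, uses a call counter as its state, and reports the action log probability in info. Plain Python values stand in for np.Array, and the time_step argument is accepted but ignored.

```python
import collections
import math
import random

# Stand-in with the three fields documented for action()'s return value.
PolicyStep = collections.namedtuple("PolicyStep", ("action", "state", "info"))


class ToyPolicy:
    """Hypothetical policy: uniform random discrete actions, call counter as state."""

    def __init__(self, num_actions, seed=None):
        self._num_actions = num_actions
        self._rng = random.Random(seed)  # seeded for reproducible sampling

    def action(self, time_step, policy_state=None):
        # time_step is ignored by this toy policy.
        action = self._rng.randrange(self._num_actions)
        # State counts calls; a missing previous state counts as zero calls.
        calls = (policy_state or 0) + 1
        # Under a uniform policy every action has probability 1 / num_actions.
        info = {"log_probability": -math.log(self._num_actions)}
        return PolicyStep(action=action, state=calls, info=info)


policy = ToyPolicy(num_actions=4, seed=0)
first = policy.action(time_step=None)
second = policy.action(time_step=None, policy_state=first.state)
print(first.action, first.state, second.state)
```

Because the returned state is passed back in explicitly, the policy object itself only needs to hold its random generator; the per-episode memory travels with the PolicyStep.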

get_initial_state

Returns an initial state usable by the policy.

Args
batch_size An optional batch size.

Returns
An initial policy state.
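A typical use of get_initial_state() is to start a loop in which each call to action() receives the state returned by the previous call. The sketch assumes a hypothetical RunningSumPolicy whose state is a running sum of observations per batch entry; lists of floats stand in for batched arrays, and a bare list of observations stands in for a full TimeStep. It is not the tf_agents implementation, only the state-threading pattern.

```python
import collections

PolicyStep = collections.namedtuple("PolicyStep", ("action", "state", "info"))


class RunningSumPolicy:
    """Hypothetical stateful policy; state is a running sum per batch entry."""

    def get_initial_state(self, batch_size=None):
        # One zero state per batch entry; an unbatched call is treated as batch size 1.
        return [0.0] * (batch_size or 1)

    def action(self, observations, policy_state):
        new_state = [s + o for s, o in zip(policy_state, observations)]
        # Act 1 where the running sum is positive, else 0.
        actions = [int(s > 0) for s in new_state]
        return PolicyStep(action=actions, state=new_state, info=())


def run_episode(policy, observation_batches, batch_size):
    """Thread each returned state into the next call to action()."""
    state = policy.get_initial_state(batch_size)
    actions = []
    for observations in observation_batches:
        step = policy.action(observations, state)
        actions.append(step.action)
        state = step.state  # fed into the next call
    return actions, state


print(run_episode(RunningSumPolicy(), [[1.0, -1.0], [-2.0, 3.0]], batch_size=2))
```

Forgetting to replace state with step.state would silently reset the policy's memory on every step, which is the main bug this pattern guards against.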

initialize

restore

Restores the policy from the checkpoint.

Args
policy_dir Directory with the checkpoint.
graph A graph inside which the policy is restored (optional).
assert_consumed If True, the contents of the checkpoint are checked for a match against the graph variables.

Returns
step Global step associated with the restored policy checkpoint.

Raises
RuntimeError if the policy is not initialized.
AssertionError if the checkpoint contains variables which do not have matching names in the graph, and assert_consumed is set to True.
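The contract described above (RuntimeError before initialization, AssertionError on unmatched checkpoint variables when assert_consumed is True, and the global step as the return value) can be modeled with a hypothetical in-memory stub. RestorableStub is not the tf_agents implementation: it takes a dict in place of policy_dir and compares variable names directly instead of reading a TensorFlow checkpoint.

```python
class RestorableStub:
    """Hypothetical stand-in that mimics the documented restore() contract."""

    def __init__(self, variable_names):
        self._variables = {name: None for name in variable_names}
        self._initialized = False

    def initialize(self):
        self._initialized = True

    def restore(self, checkpoint, assert_consumed=True):
        # checkpoint is a dict here: {"variables": {...}, "global_step": int}.
        if not self._initialized:
            raise RuntimeError("Policy must be initialized before restore().")
        saved = checkpoint["variables"]
        # Checkpoint variables whose names have no match among the graph variables.
        unmatched = sorted(set(saved) - set(self._variables))
        if assert_consumed and unmatched:
            raise AssertionError(
                "Unmatched checkpoint variables: " + ", ".join(unmatched))
        for name, value in saved.items():
            if name in self._variables:
                self._variables[name] = value
        return checkpoint["global_step"]


stub = RestorableStub(["w", "b"])
stub.initialize()
print(stub.restore({"variables": {"w": 1.0, "b": 0.5}, "global_step": 42}))
```

With assert_consumed=False the stub restores whatever matches and ignores the extra entries, which mirrors the optional nature of the check.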

save
