tf_agents.policies.py_policy.PyPolicy

Abstract base class for Python Policies.

The action(time_step, policy_state) method returns a PolicyStep named tuple containing the following nested arrays:

action: The action to be applied on the environment.
state: The state of the policy (e.g. RNN state) to be fed into the next call to action.
info: Optional side information such as action log probabilities.

For stateful policies, e.g. those containing RNNs, an initial policy state can be obtained through a call to get_initial_state().

Example of simple use in Python:

```python
py_env = PyEnvironment()
policy = PyPolicy()

time_step = py_env.reset()
policy_state = policy.get_initial_state()

acc_reward = 0
while not time_step.is_last():
  action_step = policy.action(time_step, policy_state)
  policy_state = action_step.state
  time_step = py_env.step(action_step.action)
  acc_reward += time_step.reward
```

Args

time_step_spec A TimeStep ArraySpec of the expected time_steps. Usually provided by the user to the subclass.
action_spec A nest of BoundedArraySpec representing the actions. Usually provided by the user to the subclass.
policy_state_spec A nest of ArraySpec representing the policy state. Provided by the subclass, not directly by the user.
info_spec A nest of ArraySpec representing the policy info. Provided by the subclass, not directly by the user.
observation_and_action_constraint_splitter A function used to process observations with action constraints. These constraints can indicate, for example, a mask of valid/invalid actions for a given state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the network and 2) the constraint. An example could be as simple as: def observation_and_action_constraint_splitter(observation): return observation['network_input'], observation['constraint']. Note: when using observation_and_action_constraint_splitter, make sure the provided q_network is compatible with the network-specific half of the splitter's output. In particular, observation_and_action_constraint_splitter will be called on the observation before it is passed to the network. If observation_and_action_constraint_splitter is None, action constraints are not applied.
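A minimal sketch of such a splitter in use, assuming a dict observation whose 'network_input' and 'constraint' keys follow the example above (the concrete values are illustrative):

```python
import numpy as np

# Sketch only: assumes the environment packs the network input and a
# per-action validity mask into a dict observation with these keys.
def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint']

observation = {
    'network_input': np.array([0.1, 0.2, 0.3], dtype=np.float32),
    # 1 = valid action, 0 = invalid action; here actions 0 and 2 are valid.
    'constraint': np.array([1, 0, 1], dtype=np.int32),
}

network_input, action_mask = observation_and_action_constraint_splitter(
    observation)
```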

Attributes

action_spec Describes the ArraySpecs of the np.ndarray returned by action(). The action can be a single np.ndarray, or a nested dict, list or tuple of np.ndarray.
collect_data_spec Describes the data collected when using this policy with an environment.
info_spec Describes the Arrays emitted as info by action().
observation_and_action_constraint_splitter The function used to split observations containing action constraints, as passed to the constructor; None if action constraints are not applied.
policy_state_spec Describes the arrays expected by functions with policy_state as input.
policy_step_spec Describes the output of action().
time_step_spec Describes the TimeStep np.ndarrays expected by action(time_step).
trajectory_spec Describes the data collected when using this policy with an environment.
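
Since PyPolicy is abstract, these specs come together in a subclass. Below is a minimal sketch assuming the usual TF-Agents pattern of overriding the private _action() hook; the UniformRandomPyPolicy class and its spec values are hypothetical, for illustration only:

```python
import numpy as np

from tf_agents.policies import py_policy
from tf_agents.specs import array_spec
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts

# Hypothetical subclass: samples uniformly from a bounded integer action
# spec. The class name and spec values below are assumptions, not part of
# the tf_agents API.
class UniformRandomPyPolicy(py_policy.PyPolicy):

  def _action(self, time_step, policy_state):
    # Sample within the bounds declared by action_spec; a stateless policy
    # passes policy_state through unchanged and emits empty info.
    action = np.random.randint(
        self.action_spec.minimum,
        self.action_spec.maximum + 1,
        size=self.action_spec.shape,
        dtype=self.action_spec.dtype)
    return policy_step.PolicyStep(action=action, state=policy_state, info=())


observation_spec = array_spec.ArraySpec(shape=(4,), dtype=np.float32)
action_spec = array_spec.BoundedArraySpec(
    shape=(), dtype=np.int32, minimum=0, maximum=3)

policy = UniformRandomPyPolicy(
    time_step_spec=ts.time_step_spec(observation_spec),
    action_spec=action_spec)
```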

Methods

action

Generates the next action given the time_step and policy_state.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state An optional previous policy_state.
seed Seed to use if action uses sampling (optional).

Returns
A PolicyStep named tuple containing:

action: A nest of action Arrays matching the action_spec().
state: A nest of policy states to be fed into the next call to action.
info: Optional side information such as action log probabilities.
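
A usage sketch, continuing the example at the top of this page (whether seed has any effect depends on how the subclass samples):

```python
# One decision step; stateless policies can pass the default policy_state.
action_step = policy.action(time_step, policy_state, seed=42)
policy_state = action_step.state
time_step = py_env.step(action_step.action)
```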

get_initial_state

Returns an initial state usable by the policy.

Args
batch_size An optional batch size.

Returns
An initial policy state.
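
For example, a sketch reusing the policy from earlier (the batch size is arbitrary):

```python
# Un-batched state for stepping a single environment.
policy_state = policy.get_initial_state()

# Batched state for stepping several environments in parallel; the returned
# nest is expected to carry an outer dimension of size batch_size.
batched_policy_state = policy.get_initial_state(batch_size=8)
```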