Base class for py_policy instances of TF policies in Eager mode.
Inherits From: PyPolicy
tf_agents.policies.py_tf_eager_policy.PyTFEagerPolicyBase(
    policy: tf_agents.policies.TFPolicy,
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedArraySpec,
    policy_state_spec: tf_agents.typing.types.NestedArraySpec,
    info_spec: tf_agents.typing.types.NestedArraySpec,
    use_tf_function: bool = False,
    batch_time_steps=True
)
Handles adding and removing batch dimensions from the actions and time_steps.
Note: if you have a tf_policy, use the PyTFEagerPolicy class directly instead of
this base class.
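A minimal usage sketch of that recommendation (RandomTFPolicy and the example specs
below are placeholders for your own TF policy and specs):

import tensorflow as tf
from tf_agents.policies import py_tf_eager_policy, random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Example specs; substitute the specs of your own environment/policy.
observation_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=1)
time_step_spec = ts.time_step_spec(observation_spec)

tf_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)

# Expose the TF policy as a py_policy. use_tf_function wraps policy.action
# in a tf.function, which can speed up repeated calls.
eager_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_policy, use_tf_function=True)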
Args

policy: tf_policy.TFPolicy instance to wrap and expose as a py_policy.
time_step_spec: A TimeStep ArraySpec of the expected time_steps. Usually provided by the user to the subclass.
action_spec: A nest of BoundedArraySpec representing the actions. Usually provided by the user to the subclass.
policy_state_spec: A nest of ArraySpec representing the policy state. Provided by the subclass, not directly by the user.
info_spec: A nest of ArraySpec representing the policy info. Provided by the subclass, not directly by the user.
use_tf_function: Wraps the use of policy.action in a tf.function call, which can help speed up execution.
batch_time_steps: Whether time_steps should be batched before being passed to the wrapped policy. Leave as True unless you are dealing with a batched environment, in which case you want to skip the batching as that dimension will already be present.
Attributes

action_spec: Describes the ArraySpecs of the np.Array returned by action(). action can be a single np.Array, or a nested dict, list or tuple of np.Array.
collect_data_spec: Describes the data collected when using this policy with an environment.
info_spec: Describes the Arrays emitted as info by action().
observation_and_action_constraint_splitter
policy_state_spec: Describes the arrays expected by functions with policy_state as input.
policy_step_spec: Describes the output of action().
time_step_spec: Describes the TimeStep np.Arrays expected by action(time_step).
trajectory_spec: Describes the data collected when using this policy with an environment.
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
Args

time_step: A TimeStep tuple corresponding to time_step_spec().
policy_state: An optional previous policy_state.
seed: Seed to use if action uses sampling (optional).
Returns

A PolicyStep named tuple containing:
  action: A nest of action Arrays matching the action_spec().
  state: A nest of policy states to be fed into the next call to action.
  info: Optional side information such as action log probabilities.
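A minimal sketch (assuming the eager_py_policy, specs, and ts import from the earlier
sketch): with batch_time_steps=True (the default), the batch dimension is added to the
unbatched time_step and stripped from the result for you.

import numpy as np

observation = np.zeros((4,), dtype=np.float32)
time_step = ts.restart(observation)                  # unbatched TimeStep
policy_state = eager_py_policy.get_initial_state()

policy_step = eager_py_policy.action(time_step, policy_state)
print(policy_step.action)                            # numpy action matching action_spec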
get_initial_state
get_initial_state(
    batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
Returns an initial state usable by the policy.
Args

batch_size: An optional batch size.

Returns

An initial policy state.
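A minimal sketch (assuming the eager_py_policy from the earlier sketch): for a policy
with a non-empty policy_state_spec (e.g. an RNN-based policy), the returned state is
what you feed into the first action() call.

policy_state = eager_py_policy.get_initial_state()
# For a batched environment, request a batched state instead, e.g.:
# policy_state = eager_py_policy.get_initial_state(batch_size=8)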
variables
variables()
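A minimal sketch (assuming the eager_py_policy from the earlier sketch; the exact
return structure is not documented here): variables() exposes the wrapped TF policy's
variables, for example for inspection or checkpointing.

policy_variables = eager_py_policy.variables()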