Base class for py_policy instances of TF policies in Eager mode.
Inherits From: PyPolicy
```python
tf_agents.policies.py_tf_eager_policy.PyTFEagerPolicyBase(
    policy: tf_agents.policies.TFPolicy,
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedArraySpec,
    policy_state_spec: tf_agents.typing.types.NestedArraySpec,
    info_spec: tf_agents.typing.types.NestedArraySpec,
    use_tf_function: bool = False,
    batch_time_steps=True
)
```
Handles adding and removing batch dimensions from the actions and time_steps. Note: if you have a tf_policy, use the `PyTFEagerPolicy` class directly instead of this base class.
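For example, here is a minimal sketch of the typical usage, assuming the Gym `CartPole-v0` environment is available via `suite_gym` and using `RandomTFPolicy` as a stand-in for a trained policy:

```python
from tf_agents.environments import suite_gym
from tf_agents.policies import py_tf_eager_policy, random_tf_policy
from tf_agents.specs import tensor_spec

py_env = suite_gym.load('CartPole-v0')

# Convert the environment's array specs to tensor specs for the TF policy.
time_step_spec = tensor_spec.from_spec(py_env.time_step_spec())
action_spec = tensor_spec.from_spec(py_env.action_spec())

# Any TFPolicy works here; RandomTFPolicy stands in for a trained policy.
tf_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)

# Use the concrete PyTFEagerPolicy subclass, not this base class directly.
eager_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_policy, use_tf_function=True)

time_step = py_env.reset()
policy_step = eager_py_policy.action(time_step)  # numpy action, unbatched
```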
| Args | |
|---|---|
| `policy` | A `tf_policy.TFPolicy` instance to wrap and expose as a py_policy. |
| `time_step_spec` | A `TimeStep` ArraySpec of the expected time_steps. Usually provided by the user to the subclass. |
| `action_spec` | A nest of `BoundedArraySpec` representing the actions. Usually provided by the user to the subclass. |
| `policy_state_spec` | A nest of `ArraySpec` representing the policy state. Provided by the subclass, not directly by the user. |
| `info_spec` | A nest of `ArraySpec` representing the policy info. Provided by the subclass, not directly by the user. |
| `use_tf_function` | Whether to wrap the call to `policy.action` in a `tf.function`, which can help speed up execution. |
| `batch_time_steps` | Whether time_steps should be batched before being passed to the wrapped policy. Leave as `True` unless you are dealing with a batched environment, in which case you want to skip the batching as that dimension will already be present (see the sketch after this table). |
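To illustrate `batch_time_steps=False`, a hedged sketch with a batched Python environment, again assuming `CartPole-v0` via `suite_gym`:

```python
from tf_agents.environments import batched_py_environment, suite_gym
from tf_agents.policies import py_tf_eager_policy, random_tf_policy
from tf_agents.specs import tensor_spec

# Four copies of the environment, stepped together with a leading batch dim.
envs = [suite_gym.load('CartPole-v0') for _ in range(4)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs)

# Convert the (unbatched) array specs to tensor specs for the TF policy.
time_step_spec = tensor_spec.from_spec(batched_env.time_step_spec())
action_spec = tensor_spec.from_spec(batched_env.action_spec())
tf_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)

eager_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_policy,
    use_tf_function=True,
    batch_time_steps=False)  # the batch dim is already present

time_step = batched_env.reset()               # batched: leading dim of 4
policy_step = eager_policy.action(time_step)  # batched actions come back
time_step = batched_env.step(policy_step.action)
```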
Methods
action
```python
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
```
Generates next action given the time_step and policy_state.
| Args | |
|---|---|
| `time_step` | A `TimeStep` tuple corresponding to `time_step_spec()`. |
| `policy_state` | An optional previous policy_state. |
| `seed` | Seed to use if the action uses sampling (optional). |
| Returns | |
|---|---|
| A `PolicyStep` named tuple containing: `action`: a nest of action Arrays matching the `action_spec()`; `state`: a nest of policy states to be fed into the next call to `action`; `info`: optional side information such as action log probabilities. |
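A hedged sketch of an episode loop built on `action()`, reusing `eager_py_policy` and `py_env` from the first sketch above:

```python
# Run one episode, threading policy_state through successive action() calls.
time_step = py_env.reset()
policy_state = eager_py_policy.get_initial_state()

while not time_step.is_last():
    policy_step = eager_py_policy.action(time_step, policy_state)
    policy_state = policy_step.state          # carry state to the next call
    time_step = py_env.step(policy_step.action)
```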
get_initial_state
```python
get_initial_state(
    batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
```
Returns an initial state usable by the policy.
| Args | |
|---|---|
| `batch_size` | An optional batch size. |

| Returns | |
|---|---|
| An initial policy state. |
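For a stateless policy the initial state is an empty tuple; a stateful (e.g. RNN-based) policy typically returns zero-filled arrays matching `policy_state_spec`. A quick sketch, reusing `eager_py_policy` from above:

```python
initial_state = eager_py_policy.get_initial_state(batch_size=1)
print(initial_state)  # () for stateless policies; zero arrays otherwise
```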
variables
```python
variables()
```
Returns the variables of the wrapped policy.