Exposes a Python policy as an in-graph TensorFlow policy.
Inherits From: TFPolicy
tf_agents.policies.tf_py_policy.TFPyPolicy(
policy: tf_agents.policies.py_policy.PyPolicy,
py_policy_is_batched: bool = False,
name: Optional[Text] = None
)
converting between TF and Py policies.
Args |
policy
|
Python policy implementing py_policy.PyPolicy.
|
py_policy_is_batched
|
If False, time_steps will be unbatched before
passing to py_policy.action(), and a batch dimension will be added to
the returned action. This will only work with time_steps that have a
batch dimension of 1. If True, the time_step (input) and action (output)
are passed exactly as is from/to the py_policy.
|
name
|
The name of this policy. All variables in this module will fall
under that name. Defaults to the class name.
|
Raises |
TypeError
|
If a non-Python policy is passed to the constructor.
|
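The py_policy_is_batched behavior described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: nested lists stand in for tensors, py_action_fn and argmax_policy are hypothetical stand-ins for py_policy.action(), and raising ValueError for a batch size other than 1 is an assumption made for illustration.

```python
def call_unbatched_policy(py_action_fn, batched_observation):
    """Mimics py_policy_is_batched=False for a single observation leaf.

    `py_action_fn` is a hypothetical stand-in for py_policy.action(); it
    takes one unbatched observation and returns one unbatched action.
    """
    # Only a batch dimension of exactly 1 can be stripped safely.
    # (Raising ValueError here is an assumption of this sketch.)
    if len(batched_observation) != 1:
        raise ValueError(
            "Expected batch size 1, got %d" % len(batched_observation))
    observation = batched_observation[0]   # remove the batch dimension
    action = py_action_fn(observation)     # plain Python policy call
    return [action]                        # add the batch dimension back


def call_batched_policy(py_action_fn, batched_observation):
    """Mimics py_policy_is_batched=True: inputs and outputs pass through."""
    return py_action_fn(batched_observation)


def argmax_policy(observation):
    """Hypothetical policy: index of the largest observation entry."""
    return max(range(len(observation)), key=lambda i: observation[i])


batched_obs = [[0.1, 0.7, 0.2]]            # shape [1, 3]
result = call_unbatched_policy(argmax_policy, batched_obs)
```

In this sketch the wrapper owns the batch dimension when py_policy_is_batched is False, so the Python policy never sees it. When it is True, the policy itself is responsible for handling batched inputs and producing batched outputs.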
Attributes |
action_spec
|
Describes the TensorSpecs of the Tensors expected by step(action).
action can be a single Tensor, or a nested dict, list or tuple of
Tensors.
|
collect_data_spec
|
Describes the Tensors written when using this policy with an environment.
|
emit_log_probability
|
Whether this policy instance emits log probabilities or not.
|
info_spec
|
Describes the Tensors emitted as info by action and distribution.
info can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
|
observation_and_action_constraint_splitter
|
|
policy_state_spec
|
Describes the Tensors expected by step(_, policy_state).
policy_state can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
|
policy_step_spec
|
Describes the output of action().
|
time_step_spec
|
Describes the TimeStep tensors returned by step().
|
trajectory_spec
|
Describes the Tensors written when using this policy with an environment.
|
validate_args
|
Whether action & distribution validate input and output args.
|
Methods
action
action(
time_step: tf_agents.trajectories.TimeStep,
policy_state: tf_agents.typing.types.NestedTensor = (),
seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
| Args |
time_step
|
A TimeStep tuple corresponding to time_step_spec().
|
policy_state
|
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
|
seed
|
Seed to use if action performs sampling (optional).
|
| Returns |
A PolicyStep named tuple containing:
action: An action Tensor matching the action_spec.
state: A policy state tensor to be fed into the next call to action.
info: Optional side information such as action log probabilities.
|
| Raises |
RuntimeError
|
If subclass __init__ didn't call super().__init__.
|
ValueError or TypeError
|
If validate_args is True and inputs or
outputs do not match time_step_spec, policy_state_spec,
or policy_step_spec.
|
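The way the returned state is fed into the next call to action can be sketched without TensorFlow. The namedtuple below only mirrors the three PolicyStep field names listed above, and counting_action is a hypothetical policy invented for illustration, not part of TF-Agents.

```python
import collections

# Stand-in with the same three field names as PolicyStep.
PolicyStep = collections.namedtuple("PolicyStep", ["action", "state", "info"])


def counting_action(time_step, policy_state=()):
    """Hypothetical policy: alternates actions 0 and 1 using a step counter."""
    # An empty tuple means "no previous state", matching the default argument.
    count = policy_state if policy_state != () else 0
    return PolicyStep(action=count % 2, state=count + 1, info=())


policy_state = ()                    # default empty state for the first call
actions = []
for observation in ["obs0", "obs1", "obs2"]:
    step = counting_action(observation, policy_state)
    actions.append(step.action)
    policy_state = step.state        # feed the returned state forward
```

The caller, not the policy, carries the state between calls. This is why a stateful policy needs a matching policy_state_spec.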
distribution
distribution(
time_step: tf_agents.trajectories.TimeStep,
policy_state: tf_agents.typing.types.NestedTensor = ()
) -> tf_agents.trajectories.PolicyStep
Generates the distribution over next actions given the time_step.
| Args |
time_step
|
A TimeStep tuple corresponding to time_step_spec().
|
policy_state
|
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
|
| Returns |
A PolicyStep named tuple containing:
action: A distribution object (for example a tfp.distributions.Distribution) capturing the distribution of next actions.
state: A policy state tensor for the next call to distribution.
info: Optional side information such as action log probabilities.
|
| Raises |
ValueError or TypeError
|
If validate_args is True and inputs or
outputs do not match time_step_spec, policy_state_spec,
or policy_step_spec.
|
get_initial_state
get_initial_state(
batch_size: Optional[types.Int]
) -> tf_agents.typing.types.NestedTensor
Returns an initial state usable by the policy.
| Args |
batch_size
|
Tensor or constant: size of the batch dimension. Can be None,
in which case no batch dimension is added.
|
| Returns |
A nested object of type policy_state containing properly
initialized Tensors.
|
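The effect of batch_size on the shape of the returned state can be sketched as follows. This is a minimal illustration under stated assumptions: nested lists stand in for Tensors, initial_state and zeros are hypothetical helpers, and zero-filling is assumed here rather than taken from the text above.

```python
def zeros(shape):
    """Builds a nested list of zeros with the given shape (a tuple of ints)."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def initial_state(state_shape, batch_size=None):
    """Hypothetical helper: zero state with an optional leading batch dim."""
    if batch_size is None:
        return zeros(tuple(state_shape))              # no batch dimension
    return zeros((batch_size,) + tuple(state_shape))  # prepend batch dimension


unbatched = initial_state((2,))                # shape [2]
batched = initial_state((2,), batch_size=3)    # shape [3, 2]
```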
update
update(
policy,
tau: float = 1.0,
tau_non_trainable: Optional[float] = None,
sort_variables_by_name: bool = False
) -> tf.Operation
Updates the current policy with another policy.
This includes copying the variables from the other policy.
| Args |
policy
|
Another policy it can update from.
|
tau
|
A float scalar in [0, 1] applied to trainable variables. When tau is 1.0
(the default), a hard update is performed.
|
tau_non_trainable
|
A float scalar in [0, 1] applied to non-trainable variables.
If None, the value of tau is used.
|
sort_variables_by_name
|
A bool. When True, the variables are sorted by name
before the update.
|
| Returns |
A TF op that performs the update.
|
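The role of tau can be illustrated with the conventional soft-update rule, target = (1 - tau) * target + tau * source. The sketch below assumes that rule and uses lists of (name, value) pairs in place of TF variables. soft_update is a hypothetical helper, not the library function, and tau_non_trainable is not modeled.

```python
def soft_update(source_vars, target_vars, tau=1.0,
                sort_variables_by_name=False):
    """Blends source into target: target <- (1 - tau) * target + tau * source.

    Both arguments are lists of (name, value) pairs standing in for variables.
    Returns the updated target list; the inputs are left unmodified.
    """
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    if sort_variables_by_name:
        # Sorting aligns variables that were created in different orders.
        source_vars = sorted(source_vars, key=lambda nv: nv[0])
        target_vars = sorted(target_vars, key=lambda nv: nv[0])
    return [
        (t_name, (1.0 - tau) * t_val + tau * s_val)
        for (_, s_val), (t_name, t_val) in zip(source_vars, target_vars)
    ]


source = [("w", 1.0), ("b", 3.0)]
target = [("b", 0.0), ("w", 0.0)]

hard = soft_update(source, target, tau=1.0, sort_variables_by_name=True)
soft = soft_update(source, target, tau=0.25, sort_variables_by_name=True)
```

With tau=1.0 the target simply takes the source values (a hard update). Smaller values move the target only part of the way toward the source on each call.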