Exposes a Python policy as a wrapper over a TF Policy.
Inherits From: PyPolicy, SessionUser
tf_agents.policies.py_tf_policy.PyTFPolicy(
    policy: tf_agents.policies.TFPolicy,
    batch_size: Optional[int] = None,
    seed: Optional[types.Seed] = None
)
Args:
  policy: A TF Policy implementing tf_policy.TFPolicy.
  batch_size: (deprecated)
  seed: Seed to use if the policy performs random actions (optional).
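The following is a framework-free toy sketch of the wrapper idea. It is not the TF-Agents implementation. ToyTFPolicy and ToyPyWrapper are hypothetical names. A seeded random.Random stands in for TensorFlow's random ops, and initialize() stands in for the one-time setup the real class performs. The sketch shows two things: a Python-facing wrapper that delegates to an inner policy, and a fixed seed that makes random actions repeatable. Raising RuntimeError when action() runs before initialize() is a choice made for this toy.

```python
import random
from typing import List, Optional


class ToyTFPolicy:
    """Hypothetical stand-in for a TF policy that acts on a batch of observations."""

    def __init__(self, num_actions: int):
        self.num_actions = num_actions

    def action(self, observations: List[float], rng: random.Random) -> List[int]:
        # One uniformly random discrete action per observation in the batch.
        return [rng.randrange(self.num_actions) for _ in observations]


class ToyPyWrapper:
    """Hypothetical Python-facing wrapper that delegates to a ToyTFPolicy."""

    def __init__(self, policy: ToyTFPolicy, seed: Optional[int] = None):
        self._policy = policy
        self._seed = seed
        self._rng: Optional[random.Random] = None

    def initialize(self) -> None:
        # Stand-in for one-time setup; here it just creates the seeded RNG.
        self._rng = random.Random(self._seed)

    def action(self, observations: List[float]) -> List[int]:
        if self._rng is None:
            raise RuntimeError("Call initialize() before action().")
        return self._policy.action(list(observations), self._rng)


# Two wrappers built with the same seed produce the same action sequence.
first = ToyPyWrapper(ToyTFPolicy(num_actions=4), seed=7)
second = ToyPyWrapper(ToyTFPolicy(num_actions=4), seed=7)
first.initialize()
second.initialize()
actions_first = [first.action([0.1, 0.2]) for _ in range(3)]
actions_second = [second.action([0.1, 0.2]) for _ in range(3)]
```

With the real class, whether a given seed yields identical action sequences may also depend on TensorFlow's graph-level and op-level seeding rules.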
Attributes:
  action_spec: Describes the ArraySpecs of the np.Array returned by action(). The action can be a single np.Array, or a nested dict, list, or tuple of np.Array.
  collect_data_spec: Describes the data collected when using this policy with an environment.
  info_spec: Describes the Arrays emitted as info by action().
  observation_and_action_constraint_splitter
  policy_state_spec: Describes the arrays expected by functions with policy_state as input.
  policy_step_spec: Describes the output of action().
  session: Returns the TensorFlow session-like object used by this object.
  time_step_spec: Describes the TimeStep np.Arrays expected by action(time_step).
  trajectory_spec: Describes the data collected when using this policy with an environment.
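Because action_spec may be a single array or a nested dict, list, or tuple, an action is only valid if its nesting matches the spec. The hypothetical helper conforms() below sketches that check in plain Python. It uses ToyArraySpec, also hypothetical, as a scalar leaf spec that records only a Python type. For real nests, TensorFlow's tf.nest.assert_same_structure performs the structural part of this check.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyArraySpec:
    """Hypothetical leaf spec for a scalar action: records only its Python type."""
    dtype: type


def conforms(spec: Any, value: Any) -> bool:
    """Return True if `value` has the same dict/list/tuple nesting as `spec`
    and every leaf value is an instance of its leaf spec's dtype."""
    if isinstance(spec, dict):
        return (
            isinstance(value, dict)
            and set(spec) == set(value)
            and all(conforms(spec[key], value[key]) for key in spec)
        )
    if isinstance(spec, (list, tuple)):
        # Lists and tuples must match exactly in container type and length.
        return (
            type(value) is type(spec)
            and len(value) == len(spec)
            and all(conforms(s, v) for s, v in zip(spec, value))
        )
    # Leaf: the value must be a scalar of the declared type.
    return isinstance(value, spec.dtype)


# A nested action spec: a discrete "move" plus a 2-tuple of continuous "aim" values.
action_spec = {
    "move": ToyArraySpec(int),
    "aim": (ToyArraySpec(float), ToyArraySpec(float)),
}
```

For example, {"move": 2, "aim": (0.5, -0.25)} conforms to this spec. Passing a list where the spec has a tuple does not.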
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates the next action given the time_step and policy_state.

Args:
  time_step: A TimeStep tuple corresponding to time_step_spec().
  policy_state: An optional previous policy_state.
  seed: Seed to use if action uses sampling (optional).

Returns:
  A PolicyStep named tuple containing:
    action: A nest of action Arrays matching the action_spec().
    state: A nest of policy states to be fed into the next call to action.
    info: Optional side information such as action log probabilities.
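The state field is what makes recurrent policies work: the caller passes it from one action() call into the next. The toy loop below sketches that contract in plain Python. PolicyStep here is a local NamedTuple that only mirrors the field names listed above. ToyRecurrentPolicy is hypothetical. A real PyTFPolicy returns tf_agents.trajectories.PolicyStep values computed by the wrapped TF policy.

```python
from typing import Any, NamedTuple


class PolicyStep(NamedTuple):
    """Local stand-in mirroring the (action, state, info) fields."""
    action: Any
    state: Any
    info: Any = ()


class ToyRecurrentPolicy:
    """Hypothetical policy whose state counts how many actions it has taken."""

    def get_initial_state(self) -> int:
        return 0

    def action(self, time_step: Any, policy_state: int) -> PolicyStep:
        # Alternate between actions 0 and 1 based on the carried-over state.
        return PolicyStep(
            action=policy_state % 2,
            state=policy_state + 1,
            info={"steps_so_far": policy_state},
        )


policy = ToyRecurrentPolicy()
state = policy.get_initial_state()
actions = []
for _ in range(4):
    step = policy.action(time_step=None, policy_state=state)
    actions.append(step.action)
    state = step.state  # feed the new state into the next call
```

If the loop reused the initial state on every call instead of step.state, this policy would return action 0 every time.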
get_initial_state
get_initial_state(
batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
Returns an initial state usable by the policy.

Args:
  batch_size: An optional batch size.

Returns:
  An initial policy state.
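A common default for the initial state is a zero-filled nest shaped like policy_state_spec, with an extra leading dimension when batch_size is given. The sketch below assumes that zero-initialization default. It uses a flat dict of shapes as a hypothetical stand-in for policy_state_spec and nested Python lists in place of arrays. The wrapped policy ultimately decides what its initial state contains.

```python
from typing import Dict, List, Optional, Tuple, Union

# Nested lists of floats stand in for arrays.
Nested = Union[float, List["Nested"]]


def zeros(shape: Tuple[int, ...]) -> Nested:
    """Build a nested list of 0.0 values with the given shape."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def initial_state(
    state_spec: Dict[str, Tuple[int, ...]], batch_size: Optional[int] = None
) -> Dict[str, Nested]:
    """Zero state for each entry of a hypothetical flat policy_state_spec.

    With batch_size=None there is no leading batch dimension; otherwise each
    entry has shape (batch_size,) + spec_shape.
    """
    outer = () if batch_size is None else (batch_size,)
    return {name: zeros(outer + shape) for name, shape in state_spec.items()}


spec = {"lstm_h": (3,), "lstm_c": (3,)}
unbatched = initial_state(spec)               # {"lstm_h": [0.0, 0.0, 0.0], "lstm_c": [0.0, 0.0, 0.0]}
batched = initial_state(spec, batch_size=2)   # each entry is a 2 x 3 list of zeros
```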
initialize
initialize(
batch_size: Optional[int], graph: Optional[tf.Graph] = None
)
restore
restore(
policy_dir: Text,
graph: Optional[tf.Graph] = None,
assert_consumed: bool = True
)
Restores the policy from the checkpoint.

Args:
  policy_dir: Directory with the checkpoint.
  graph: A graph inside which the policy is restored (optional).
  assert_consumed: If True, the contents of the checkpoint are checked for a match against the graph variables.

Returns:
  step: Global step associated with the restored policy checkpoint.

Raises:
  RuntimeError: If the policy is not initialized.
  AssertionError: If the checkpoint contains variables which do not have matching names in the graph, and assert_consumed is set to True.
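A toy in which a checkpoint is a plain dict can illustrate the assert_consumed flag and the two documented exceptions. ToyRestorablePolicy, the checkpoint layout, and its "global_step" key are all hypothetical. The sketch mirrors only the documented behavior: a RuntimeError before initialization, an AssertionError for unmatched checkpoint variables when assert_consumed is True, and a returned global step. It does not read real TensorFlow checkpoints.

```python
from typing import Dict, Iterable


class ToyRestorablePolicy:
    """Hypothetical policy illustrating the documented restore() contract."""

    def __init__(self, variable_names: Iterable[str]):
        # Toy "graph variables": name -> value, all starting at 0.0.
        self.variables: Dict[str, float] = {name: 0.0 for name in variable_names}
        self.initialized = False

    def initialize(self) -> None:
        self.initialized = True

    def restore(self, checkpoint: Dict, assert_consumed: bool = True) -> int:
        if not self.initialized:
            raise RuntimeError("Policy is not initialized.")
        saved: Dict[str, float] = checkpoint["variables"]
        # Checkpoint entries with no same-named variable in the toy "graph".
        unmatched = sorted(set(saved) - set(self.variables))
        if assert_consumed and unmatched:
            raise AssertionError(f"Unmatched checkpoint variables: {unmatched}")
        for name, value in saved.items():
            if name in self.variables:
                self.variables[name] = value
        return checkpoint["global_step"]


ckpt = {"global_step": 1200, "variables": {"dense/kernel": 0.5, "extra/bias": 1.0}}
policy = ToyRestorablePolicy(["dense/kernel"])
policy.initialize()

# Lenient restore: "extra/bias" has no match and is skipped.
step = policy.restore(ckpt, assert_consumed=False)

# Strict restore (the default) rejects the unmatched "extra/bias" entry.
try:
    policy.restore(ckpt)
    strict_raised = False
except AssertionError:
    strict_raised = True
```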
save
save(
policy_dir: Optional[Text] = None, graph: Optional[tf.Graph] = None
)