Exposes a numpy API for saved_model policies in Eager mode.
Inherits From: PyTFEagerPolicyBase, PyPolicy
tf_agents.policies.SavedModelPyTFEagerPolicy(
model_path: Text,
time_step_spec: Optional[tf_agents.trajectories.TimeStep] = None,
action_spec: Optional[tf_agents.typing.types.DistributionSpecV2] = None,
policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
info_spec: tf_agents.typing.types.NestedTensorSpec = (),
load_specs_from_pbtxt: bool = False,
use_tf_function: bool = False,
batch_time_steps=True
)
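As a minimal sketch of how the constructor might be called, assuming a policy directory that was previously exported with PolicySaver: the helper name `load_eager_policy` and the example path are hypothetical, and the import is deferred so the snippet can be loaded without TensorFlow or TF-Agents installed.

```python
def load_eager_policy(model_path, use_tf_function=False):
    """Loads a saved_model policy behind a NumPy-facing wrapper (sketch).

    `load_eager_policy` is a hypothetical helper, not part of TF-Agents.
    """
    # Deferred import: TensorFlow and TF-Agents are only needed when called.
    from tf_agents.policies import SavedModelPyTFEagerPolicy

    return SavedModelPyTFEagerPolicy(
        model_path,
        # Assumption: a specs file written by PolicySaver sits next to the
        # saved model, so the specs do not have to be passed explicitly.
        load_specs_from_pbtxt=True,
        use_tf_function=use_tf_function,
    )


# Hypothetical usage; the path is a placeholder for a real export directory:
# policy = load_eager_policy("/tmp/exported_policy")
```

If the specs file is not available, the signature above suggests passing `time_step_spec` and `action_spec` explicitly instead.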
Methods
action
action(
time_step: tf_agents.trajectories.TimeStep,
policy_state: tf_agents.typing.types.NestedArray = (),
seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
| Args | |
|---|---|
| time_step | A TimeStep tuple corresponding to time_step_spec(). |
| policy_state | An optional previous policy_state. |
| seed | Seed to use if action uses sampling (optional). |
| Returns | |
|---|---|
| A PolicyStep named tuple containing: action: a nest of action arrays matching the action_spec(); state: a nest of policy states to be fed into the next call to action; info: optional side information such as action log probabilities. |
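The way the returned state feeds the next call to action can be sketched as a simple loop. This is an illustration only: `CountingPolicy` and `collect_actions` are hypothetical names, and the `PolicyStep` defined here is a plain stand-in with the same three fields, used so the snippet runs without a saved model. A real policy object exposing `get_initial_state` and `action` would presumably slot into `collect_actions` the same way.

```python
import collections

# Stand-in with the same three fields the Returns table lists for PolicyStep.
PolicyStep = collections.namedtuple("PolicyStep", ["action", "state", "info"])


class CountingPolicy:
    """Hypothetical stand-in policy: emits its internal counter as the action."""

    def get_initial_state(self, batch_size=None):
        return 0

    def action(self, time_step, policy_state=(), seed=None):
        return PolicyStep(action=policy_state, state=policy_state + 1, info=())


def collect_actions(policy, time_steps):
    """Calls policy.action on each time step, threading the policy state."""
    policy_state = policy.get_initial_state()
    actions = []
    for time_step in time_steps:
        step = policy.action(time_step, policy_state)
        actions.append(step.action)
        # The returned state is the input to the next call to action.
        policy_state = step.state
    return actions


print(collect_actions(CountingPolicy(), ["t0", "t1", "t2"]))  # prints [0, 1, 2]
```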
get_initial_state
get_initial_state(
batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
Returns an initial state usable by the policy.
| Args | |
|---|---|
| batch_size | An optional batch size. |
| Returns | |
|---|---|
| An initial policy state. |
get_metadata
get_metadata()
Returns the metadata of the saved model.
get_train_step
get_train_step() -> tf_agents.typing.types.Int
Returns the training global step of the saved model.
get_train_step_from_last_restored_checkpoint_path
get_train_step_from_last_restored_checkpoint_path() -> Optional[int]
Returns the training step of the restored checkpoint.
update_from_checkpoint
update_from_checkpoint(
checkpoint_path: Text
)
Allows users to update saved_model variables directly from a checkpoint.
checkpoint_path is a path that was passed to either PolicySaver.save()
or PolicySaver.save_checkpoint(). The policy looks for a set of checkpoint
files with the file prefix `
| Args | |
|---|---|
| checkpoint_path | Path to the checkpoint to restore and use to update this policy. |
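One possible pattern is to read the train step before and after restoring, to confirm that the checkpoint actually changed the policy. The sketch below assumes only the `get_train_step` and `update_from_checkpoint` methods documented on this page; `refresh_from_checkpoint` and `FakePolicy` are hypothetical names, and the fake exists solely so the snippet runs without a saved model.

```python
def refresh_from_checkpoint(policy, checkpoint_path):
    """Restores variables from a checkpoint; returns the train step before and after."""
    step_before = int(policy.get_train_step())
    policy.update_from_checkpoint(checkpoint_path)
    step_after = int(policy.get_train_step())
    return step_before, step_after


class FakePolicy:
    """Hypothetical stand-in so the sketch runs without a saved model."""

    def __init__(self):
        self._train_step = 100

    def get_train_step(self):
        return self._train_step

    def update_from_checkpoint(self, checkpoint_path):
        # Pretend the checkpoint was written 50 training steps later.
        self._train_step += 50


print(refresh_from_checkpoint(FakePolicy(), "hypothetical/checkpoint/path"))  # prints (100, 150)
```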
variables
variables()