# tf_agents.drivers.py_driver.PyDriver

[View source on GitHub](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/drivers/py_driver.py#L33-L146)
A driver that runs a python policy in a python environment.
Inherits From: `Driver`
    tf_agents.drivers.py_driver.PyDriver(
        env: tf_agents.environments.PyEnvironment,
        policy: tf_agents.policies.py_policy.PyPolicy,
        observers: Sequence[Callable[[trajectory.Trajectory], Any]],
        transition_observers: Optional[Sequence[Callable[[trajectory.Transition], Any]]] = None,
        info_observers: Optional[Sequence[Callable[[Any], Any]]] = None,
        max_steps: Optional[types.Int] = None,
        max_episodes: Optional[types.Int] = None,
        end_episode_on_boundary: bool = True
    )
### Used in the notebooks

- [Train a Deep Q Network with TF-Agents](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)
- [Drivers](https://www.tensorflow.org/agents/tutorials/4_drivers_tutorial)
- [REINFORCE agent](https://www.tensorflow.org/agents/tutorials/6_reinforce_tutorial)
| Args | |
|---|---|
| `env` | A `py_environment.Base` environment. |
| `policy` | A `py_policy.PyPolicy` policy. |
| `observers` | A list of observers that are notified after every step in the environment. Each observer is a `callable(trajectory.Trajectory)`; see the sketch after this table. |
| `transition_observers` | A list of observers that are updated after every step in the environment. Each observer is a `callable((TimeStep, PolicyStep, NextTimeStep))`. The transition is shaped just as trajectories are for regular observers. |
| `info_observers` | A list of observers that are notified after every step in the environment. Each observer is a `callable(info)`. |
| `max_steps` | Optional maximum number of steps for each `run()` call. For batched or parallel environments, this is the maximum total number of steps summed across all environments. See also `max_episodes`. Default: `None`. |
| `max_episodes` | Optional maximum number of episodes for each `run()` call. For batched or parallel environments, this is the maximum total number of episodes summed across all environments. At least one of `max_steps` or `max_episodes` must be provided; if both are set, `run()` terminates when either condition is satisfied. Default: `None`. |
| `end_episode_on_boundary` | Should be `False` when using transition observers and `True` when using trajectory observers. |
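For concreteness, a minimal sketch of what the two observer signatures look like. The function names here are illustrative, not part of the API:

    from tf_agents.trajectories import trajectory as trajectory_lib

    # A trajectory observer: any callable taking a single Trajectory.
    def log_reward(traj: trajectory_lib.Trajectory) -> None:
      print('reward:', traj.reward)

    # A transition observer: any callable taking a (TimeStep, PolicyStep,
    # NextTimeStep) tuple. Pass end_episode_on_boundary=False when using these.
    def log_action(transition) -> None:
      time_step, policy_step, next_time_step = transition
      print('action:', policy_step.action)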
| Raises | |
|---|---|
| `ValueError` | If both `max_steps` and `max_episodes` are `None`. |
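A minimal end-to-end construction sketch, following the pattern in the [Drivers tutorial](https://www.tensorflow.org/agents/tutorials/4_drivers_tutorial) and assuming a Gym CartPole environment with a random Python policy:

    from tf_agents.drivers import py_driver
    from tf_agents.environments import suite_gym
    from tf_agents.metrics import py_metrics
    from tf_agents.policies import random_py_policy

    env = suite_gym.load('CartPole-v0')
    policy = random_py_policy.RandomPyPolicy(
        time_step_spec=env.time_step_spec(), action_spec=env.action_spec())

    # Trajectory observers: a plain list collecting trajectories, plus a metric.
    replay_buffer = []
    metric = py_metrics.AverageReturnMetric()

    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer.append, metric],
        max_steps=20,     # at least one of max_steps/max_episodes is required
        max_episodes=1)   # run() stops when either limit is hit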
| Attributes | |
|---|---|
| `env` | |
| `info_observers` | |
| `observers` | |
| `policy` | |
| `transition_observers` | |
Methods
-------

### `run`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/drivers/py_driver.py#L100-L146)

    run(
        time_step: tf_agents.trajectories.TimeStep,
        policy_state: tf_agents.typing.types.NestedArray = ()
    ) -> Tuple[tf_agents.trajectories.TimeStep, tf_agents.typing.types.NestedArray]
Run policy in environment given initial time_step and policy_state.
| Args | |
|---|---|
| `time_step` | The initial `time_step`. |
| `policy_state` | The initial `policy_state`. |
| Returns |
|---|
| A tuple `(final time_step, final policy_state)`. |
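Continuing the construction sketch above: reset the environment to get the initial `time_step`, then thread the returned values into subsequent calls to resume collection where the previous `run()` stopped.

    initial_time_step = env.reset()
    final_time_step, policy_state = driver.run(initial_time_step)

    # Resume from where the previous run() call stopped.
    final_time_step, policy_state = driver.run(final_time_step, policy_state)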