Copyright 2023 The TF-Agents Authors.
Introduction
A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, during evaluation, and when generating a video of the agent.
While this is relatively straightforward to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while_loop, tf.cond and tf.control_dependencies. Therefore we abstract this notion of a run loop into a class called driver, and provide well-tested implementations both in Python and TensorFlow.
Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the type of the current and the next step, etc.
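An observer can be anything callable that accepts a Trajectory; replay buffers and metrics are just two common examples. As a minimal sketch (the rewards_log list and log_reward function below are purely illustrative and not part of the TF-Agents API), a custom observer could simply record the rewards it sees:
rewards_log = []

def log_reward(traj):
  # `traj` is a tf_agents.trajectories.Trajectory; keep only the reward it carries.
  rewards_log.append(traj.reward)

# Such a callable can be passed to a driver alongside other observers,
# e.g. observers=[replay_buffer.append, metric, log_reward].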
Setup
If you haven't installed tf-agents or gym yet, run:
pip install tf-agents
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver
Python Drivers
The PyDriver class takes a Python environment, a Python policy and a list of observers to update at each step. The main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps, or the number of episodes reaches max_episodes.
The implementation is roughly as follows:
import numpy as np

from tf_agents.trajectories import trajectory


class PyDriver(object):

  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:

      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)

      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)

      # Package information into a trajectory
      traj = trajectory.Trajectory(
          time_step.step_type,
          time_step.observation,
          action_step.action,
          action_step.info,
          next_time_step.step_type,
          next_time_step.reward,
          next_time_step.discount)

      for observer in self._observers:
        observer(traj)

      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state
Now, let us walk through an example of running a random policy on the CartPole environment, saving the results to a replay buffer and computing some metrics.
env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(),
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)

initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)

print('Replay Buffer:')
for traj in replay_buffer:
  print(traj)
print('Average Return: ', metric.result())
Replay Buffer:
Trajectory( {'step_type': array(0, dtype=int32), 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00317699, 0.16732468, -0.02837953, -0.3210437 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00597585, -0.22198795, -0.03554928, 0.24405919], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([ 0.00153609, -0.41658458, -0.0306681 , 0.5253204 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.0067956 , -0.61126184, -0.02016169, 0.80818397], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.01902084, -0.8061018 , -0.00399801, 1.0944574 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.03514287, -0.6109274 , 0.01789114, 0.8005227 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.04736142, -0.8062901 , 0.03390159, 1.0987796 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.06348722, -0.61163044, 0.05587719, 0.816923 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.07571983, -0.41731614, 0.07221565, 0.54232585], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.08406615, -0.61337477, 0.08306216, 0.8568603 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.09633365, -0.8095243 , 0.10019937, 1.1744623 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.11252414, -0.6158369 , 0.12368862, 0.91479784], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.12484087, -0.8123951 , 0.14198457, 1.2436544 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.14108877, -0.61935145, 0.16685766, 0.9986062 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.1534758 , -0.42680538, 0.18682979, 0.7626272 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(1, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(1., dtype=float32)})
Trajectory( {'step_type': array(1, dtype=int32), 'observation': array([-0.1620119 , -0.23468053, 0.20208232, 0.5340639 ], dtype=float32), 'action': array(0), 'policy_info': (), 'next_step_type': array(2, dtype=int32), 'reward': array(1., dtype=float32), 'discount': array(0., dtype=float32)})
Trajectory( {'step_type': array(2, dtype=int32), 'observation': array([-0.16670552, -0.43198496, 0.21276361, 0.8830067 ], dtype=float32), 'action': array(1), 'policy_info': (), 'next_step_type': array(0, dtype=int32), 'reward': array(0., dtype=float32), 'discount': array(1., dtype=float32)})
Average Return: 18.0
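The same PyDriver can also be capped by episodes rather than steps, which is the usual pattern for evaluation. Below is a minimal sketch reusing the environment and policy defined above; the limits of 1000 steps and 5 episodes are arbitrary choices for illustration:
eval_metric = py_metrics.AverageReturnMetric()
eval_driver = py_driver.PyDriver(
    env, policy, [eval_metric], max_steps=1000, max_episodes=5)

# Run until 5 episodes finish (or 1000 steps, whichever comes first).
eval_driver.run(env.reset())
print('Average Return over completed episodes: ', eval_metric.result())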
TensorFlow Drivers
We also have drivers in TensorFlow which are functionally similar to Python drivers, but use TF environments, TF policies, TF observers etc. We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let us look at an example of the DynamicEpisodeDriver in action.
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())

num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)
# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=array([[-0.0367443 , 0.00652178, 0.04001181, -0.00376746]], dtype=float32)>})
Number of Steps: 34
Number of Episodes: 2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=array([[-0.04702466, -0.04836502, 0.01751254, -0.00393545]], dtype=float32)>})
Number of Steps: 63
Number of Episodes: 4
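For completeness, here is a minimal sketch of the DynamicStepDriver mentioned above, reusing tf_env and tf_policy from the previous cells; the choice of 100 steps is arbitrary:
from tf_agents.drivers import dynamic_step_driver

step_metric = tf_metrics.EnvironmentSteps()
step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env, tf_policy, observers=[step_metric], num_steps=100)

# Since no time_step is passed, run() resets the environment and then steps
# until 100 (valid) environment steps have been taken.
final_time_step, policy_state = step_driver.run()
print('Number of Steps: ', step_metric.result().numpy())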