Replay Buffers



Reinforcement learning algorithms use replay buffers to store trajectories of experience when executing a policy in an environment. During training, replay buffers are queried for a subset of the trajectories (either a sequential subset or a sample) to "replay" the agent's experience.
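The two ways of replaying experience can be illustrated with a minimal pure-Python sketch. `ToyReplayBuffer` is a hypothetical class for illustration only, not part of TF-Agents. It keeps items in a list, drops the oldest item when full, and returns either a contiguous slice or a uniform random sample.

```python
import random


class ToyReplayBuffer:
  """Hypothetical list-backed buffer illustrating two ways to replay data."""

  def __init__(self, capacity):
    self.capacity = capacity
    self._items = []

  def add(self, item):
    # Once the buffer is full, evict the oldest item (FIFO).
    if len(self._items) == self.capacity:
      self._items.pop(0)
    self._items.append(item)

  def sequential(self, start, length):
    # A contiguous run of stored experience.
    return self._items[start:start + length]

  def sample(self, n, rng=random):
    # Uniform random sample (with replacement) of stored items.
    return [rng.choice(self._items) for _ in range(n)]


buffer = ToyReplayBuffer(capacity=3)
for step in range(5):
  buffer.add(step)

print(buffer.sequential(0, 2))  # [2, 3]
print(len(buffer.sample(4)))    # 4
```

Because the capacity is 3, steps 0 and 1 were evicted, so the buffer holds [2, 3, 4].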

In this colab, we explore two types of replay buffers, python-backed and tensorflow-backed, which share a common API. In the following sections, we describe the API, each of the buffer implementations, and how to use them during data collection and training.

Setup

Install tf-agents if you haven't already.

pip install tf-agents
pip install tf-keras

# __future__ imports must come first in a Python module.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
# Set this before importing TensorFlow or TF-Agents.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import tensorflow as tf
import numpy as np

from tf_agents import specs
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step

Replay Buffer API

The Replay Buffer class has the following definition and methods:

class ReplayBuffer(tf.Module):
  """Abstract base class for TF-Agents replay buffer."""

  def __init__(self, data_spec, capacity):
    """Initializes the replay buffer.

    Args:
      data_spec: A spec or a list/tuple/nest of specs describing
        a single item that can be stored in this buffer
      capacity: number of elements that the replay buffer can hold.
    """

  @property
  def data_spec(self):
    """Returns the spec for items in the replay buffer."""

  @property
  def capacity(self):
    """Returns the capacity of the replay buffer."""

  def add_batch(self, items):
    """Adds a batch of items to the replay buffer."""

  def get_next(self,
               sample_batch_size=None,
               num_steps=None,
               time_stacked=True):
    """Returns an item or batch of items from the buffer."""

  def as_dataset(self,
                 sample_batch_size=None,
                 num_steps=None,
                 num_parallel_calls=None):
    """Creates and returns a dataset that returns entries from the buffer."""

  def gather_all(self):
    """Returns all the items in buffer."""
    return self._gather_all()

  def clear(self):
    """Resets the contents of replay buffer"""

Note that when the replay buffer object is initialized, it requires the data_spec of the elements that it will store. This spec corresponds to the TensorSpec of trajectory elements that will be added to the buffer. This spec is usually acquired by looking at an agent's agent.collect_data_spec which defines the shapes, types, and structures expected by the agent when training (more on that later).
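To make the role of the data spec concrete, the sketch below checks whether a single item matches a nested spec. It is an approximation under stated assumptions: `Spec` is a hypothetical (shape, dtype) stand-in for tf.TensorSpec, and `conforms` is a hypothetical helper, not the validation code TF-Agents actually runs.

```python
import collections

import numpy as np

# Hypothetical stand-in for tf.TensorSpec: just a shape and a dtype.
Spec = collections.namedtuple('Spec', ['shape', 'dtype'])


def conforms(item, spec):
  """Returns True if `item` has the same nesting, shapes and dtypes as `spec`."""
  if isinstance(spec, Spec):  # Leaf spec (checked first: Spec is a tuple).
    item = np.asarray(item)
    return item.shape == tuple(spec.shape) and item.dtype == spec.dtype
  # Nested structure: same length, and every child must conform.
  return (isinstance(item, (tuple, list)) and len(item) == len(spec) and
          all(conforms(i, s) for i, s in zip(item, spec)))


data_spec = (
    Spec([3], np.float32),                              # action
    (Spec([5], np.float32), Spec([3, 2], np.float32)),  # (lidar, camera)
)

item = (np.ones(3, np.float32),
        (np.ones(5, np.float32), np.ones((3, 2), np.float32)))
print(conforms(item, data_spec))                                # True
print(conforms((np.ones(4, np.float32), item[1]), data_spec))   # False
```

The second call fails because the action has shape (4,) instead of (3,).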

TFUniformReplayBuffer

TFUniformReplayBuffer is the most commonly used replay buffer in TF-Agents, so we use it throughout this tutorial. In TFUniformReplayBuffer, the backing storage is made of TensorFlow variables and is therefore part of the compute graph.

The buffer stores batches of elements and has a maximum capacity of max_length elements per batch segment, so the total buffer capacity is batch_size x max_length elements. All elements stored in the buffer must match the same data spec. When the replay buffer is used for data collection, this spec is the agent's collect data spec.
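The capacity arithmetic and per-segment storage can be pictured with a small NumPy model. This is a simplified sketch, not TF-Agents' implementation (which uses TensorFlow variables). `SegmentedRingBuffer` is hypothetical, and the model assumes each batch segment behaves as a ring buffer that overwrites its oldest element after max_length writes.

```python
import numpy as np


class SegmentedRingBuffer:
  """Hypothetical model: batch_size segments, each holding max_length items."""

  def __init__(self, item_shape, batch_size, max_length):
    self.batch_size = batch_size
    self.max_length = max_length
    self.storage = np.zeros((batch_size, max_length) + tuple(item_shape),
                            dtype=np.float32)
    self.num_adds = 0  # Number of add_batch calls so far.

  @property
  def capacity(self):
    return self.batch_size * self.max_length

  def add_batch(self, batch):
    batch = np.asarray(batch, dtype=np.float32)
    assert batch.shape[0] == self.batch_size, 'outer dim must equal batch_size'
    # All segments write to the same slot; after max_length adds the
    # oldest slot is overwritten.
    self.storage[:, self.num_adds % self.max_length] = batch
    self.num_adds += 1

  def num_frames(self):
    # Number of items currently stored across all segments.
    return self.batch_size * min(self.num_adds, self.max_length)


buf = SegmentedRingBuffer(item_shape=[3], batch_size=32, max_length=1000)
print(buf.capacity)  # 32000

# A tiny buffer to show wraparound: 4 writes into 3 slots per segment.
small = SegmentedRingBuffer(item_shape=[1], batch_size=2, max_length=3)
for step in range(4):
  small.add_batch(np.full((2, 1), step))
print(small.storage[0, :, 0])  # [3. 1. 2.]  (slot 0 was overwritten by step 3)
print(small.num_frames())      # 6
```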

Creating the buffer:

To create a TFUniformReplayBuffer we pass in:

  1. the spec of the data elements that the buffer will store
  2. the batch size of the buffer, i.e. the outer dimension of every batch of items that will be added to it
  3. the max_length number of elements per batch segment

Here is an example of creating a TFUniformReplayBuffer with sample data specs, batch_size 32 and max_length 1000.

data_spec = (
    tf.TensorSpec([3], tf.float32, 'action'),
    (
        tf.TensorSpec([5], tf.float32, 'lidar'),
        tf.TensorSpec([3, 2], tf.float32, 'camera')
    )
)

batch_size = 32
max_length = 1000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec,
    batch_size=batch_size,
    max_length=max_length)

Writing to the buffer:

To add elements to the replay buffer, we use the add_batch(items) method, where items is a list/tuple/nest of tensors representing the batch of items to be added to the buffer. Each element of items must have an outer dimension equal to batch_size, and the remaining dimensions must match the data spec of the item (the same data spec passed to the replay buffer constructor).

Here's an example of adding a batch of items:

action = tf.constant(1 * np.ones(
    data_spec[0].shape.as_list(), dtype=np.float32))
lidar = tf.constant(
    2 * np.ones(data_spec[1][0].shape.as_list(), dtype=np.float32))
camera = tf.constant(
    3 * np.ones(data_spec[1][1].shape.as_list(), dtype=np.float32))

values = (action, (lidar, camera))
# Stack each item batch_size times to give it the required outer dimension.
values_batched = tf.nest.map_structure(lambda t: tf.stack([t] * batch_size),
                                       values)

replay_buffer.add_batch(values_batched)
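The shape rule for add_batch can be checked without TensorFlow. The following NumPy sketch mirrors the example above: np.stack plays the role of tf.stack, and `map_nest` and `check_batch` are hypothetical helpers (a tiny stand-in for tf.nest.map_structure and an illustrative version of the outer-dimension rule, respectively).

```python
import numpy as np

batch_size = 32

# Shapes from the example data spec: action, (lidar, camera).
spec_shapes = ((3,), ((5,), (3, 2)))
values = (np.ones((3,), np.float32),
          (2 * np.ones((5,), np.float32), 3 * np.ones((3, 2), np.float32)))


def map_nest(fn, nest):
  """Tiny stand-in for tf.nest.map_structure over nested tuples."""
  if isinstance(nest, tuple):
    return tuple(map_nest(fn, n) for n in nest)
  return fn(nest)


def check_batch(batch, shapes, batch_size):
  """Hypothetical check: outer dim is batch_size, the rest matches the spec."""
  if isinstance(batch, np.ndarray):
    return batch.shape == (batch_size,) + tuple(shapes)
  return all(check_batch(b, s, batch_size) for b, s in zip(batch, shapes))


values_batched = map_nest(lambda t: np.stack([t] * batch_size), values)
print(map_nest(lambda t: t.shape, values_batched))
# ((32, 3), ((32, 5), (32, 3, 2)))
print(check_batch(values_batched, spec_shapes, batch_size))  # True
print(check_batch(values, spec_shapes, batch_size))          # False (unbatched)
```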

Reading from the buffer

There are three ways to read data from the TFUniformReplayBuffer:

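  1. get_next() - returns one sample from the buffer. The sample batch size and number of timesteps returned can be specified via arguments to this method. This method is deprecated; use as_dataset(..., single_deterministic_pass=False) instead.
  2. as_dataset() - returns the replay buffer as a tf.data.Dataset. One can then create a dataset iterator and iterate through the samples of the items in the buffer.
  3. gather_all() - returns all the items in the buffer as a nest of Tensors, each with shape [batch, time] followed by the shape of its data spec. This method is deprecated; use as_dataset(..., single_deterministic_pass=True) instead.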

Below are examples of how to read from the replay buffer using each of these methods:

# Add more items to the buffer before reading.
for _ in range(5):
  replay_buffer.add_batch(values_batched)

# Get one sample from the replay buffer with batch size 10 and 1 timestep:
sample = replay_buffer.get_next(sample_batch_size=10, num_steps=1)

# Convert the replay buffer to a tf.data.Dataset and iterate through it,
# reading batches of 4 elements, each with 2 timesteps:
dataset = replay_buffer.as_dataset(
    sample_batch_size=4,
    num_steps=2)

iterator = iter(dataset)
print("Iterator trajectories:")
trajectories = []
for _ in range(3):
  t, _ = next(iterator)
  trajectories.append(t)

print(tf.nest.map_structure(lambda t: t.shape, trajectories))

# Read all elements in the replay buffer:
trajectories = replay_buffer.gather_all()

print("Trajectories from gather all:")
print(tf.nest.map_structure(lambda t: t.shape, trajectories))
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_27300/ ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.
Iterator trajectories:
[(TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2])))]
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_27300/ ReplayBuffer.gather_all (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=True)` instead.
Trajectories from gather all:
(TensorShape([32, 6, 3]), (TensorShape([32, 6, 5]), TensorShape([32, 6, 3, 2])))
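The printed shapes follow from simple bookkeeping: each of the 32 segments holds 6 items (one add_batch call plus five more), and each dataset element contains 4 sequences of 2 consecutive timesteps. The NumPy sketch below reproduces those shapes for the lidar component. `sample_uniform` is a hypothetical sampler that only approximates TF-Agents' uniform sampling (for example, it ignores ring-buffer wraparound).

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, num_items = 32, 6  # 1 + 5 add_batch calls, as in the example.
# Storage for the 'lidar' component: [batch, time] + spec shape [5].
lidar = rng.normal(size=(batch_size, num_items, 5)).astype(np.float32)


def sample_uniform(storage, sample_batch_size, num_steps, rng):
  """Hypothetical sampler: random segment, random start, num_steps in a row."""
  batch, time = storage.shape[:2]
  rows = rng.integers(0, batch, size=sample_batch_size)
  starts = rng.integers(0, time - num_steps + 1, size=sample_batch_size)
  # [sample_batch_size, num_steps] grid of consecutive time indices.
  cols = starts[:, None] + np.arange(num_steps)[None, :]
  return storage[rows[:, None], cols]


sample = sample_uniform(lidar, sample_batch_size=4, num_steps=2, rng=rng)
print(sample.shape)  # (4, 2, 5), like TensorShape([4, 2, 5]) above
print(lidar.shape)   # (32, 6, 5), like the gather_all result above
```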

PyUniformReplayBuffer

PyUniformReplayBuffer has the same functionality as TFUniformReplayBuffer, but its data is stored in numpy arrays instead of TensorFlow variables. This buffer can be used for out-of-graph data collection. Having the backing storage in numpy may make it easier for some applications to manipulate data (for example, indexing to update priorities) without using TensorFlow variables. However, this implementation does not benefit from TensorFlow's graph optimizations.
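As an illustration of the index-based manipulation that numpy storage allows, here is a hypothetical sketch of updating per-item priorities in place and sampling in proportion to them. PyUniformReplayBuffer itself samples uniformly and does not provide this; the arrays and the made-up TD errors below exist only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
capacity = 8

# Hypothetical numpy-backed storage: one observation and one priority per slot.
observations = np.arange(capacity, dtype=np.float32)
priorities = np.ones(capacity, dtype=np.float64)

# Sample indices in proportion to priority (uniform while all are equal).
probs = priorities / priorities.sum()
idx = rng.choice(capacity, size=4, p=probs)
batch = observations[idx]

# After computing e.g. TD errors for the sampled items, update their
# priorities in place with plain fancy indexing; no tf.Variable ops needed.
# (If idx contains duplicates, one of the assigned values ends up stored.)
td_errors = np.array([0.5, 2.0, 0.1, 1.0])
priorities[idx] = np.abs(td_errors) + 1e-6
```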

Below is an example of instantiating a PyUniformReplayBuffer from the data spec defined above, converted to numpy array specs:

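replay_buffer_capacity = 1000*32 # same capacity as the TFUniformReplayBuffer

py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
    capacity=replay_buffer_capacity,
    # Convert the TensorSpecs into numpy ArraySpecs for the numpy storage.
    data_spec=tensor_spec.to_nest_array_spec(data_spec))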

Using replay buffers during training

Now that we know how to create a replay buffer, write items to it and read from it, we can use it to store trajectories during training of our agents.

Data collection

First, let's look at how to use the replay buffer during data collection.

In TF-Agents we use a Driver (see the Driver tutorial for more details) to collect experience in an environment. To use a Driver, we specify an Observer that is a function for the Driver to execute when it receives a trajectory.

Thus, to add trajectory elements to the replay buffer, we add an observer that calls add_batch(items) to add a batch of items to the replay buffer.

Below is an example of this with TFUniformReplayBuffer. We first create an environment, a network and an agent. Then we create a TFUniformReplayBuffer. Note that the specs of the trajectory elements in the replay buffer are equal to the agent's collect data spec. We then set its add_batch method as the observer for the driver that will do the data collection during our training:

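env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

q_net = q_network.QNetwork(
    tf_env.time_step_spec().observation,
    tf_env.action_spec(),
    fc_layer_params=(100,))

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

replay_buffer_capacity = 1000

# The buffer stores items matching the agent's collect data spec;
# its batch size matches the number of parallel environments.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_capacity)

# Add an observer that adds to the replay buffer:
replay_observer = [replay_buffer.add_batch]

collect_steps_per_iteration = 10
collect_op = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=replay_observer,
    num_steps=collect_steps_per_iteration).run()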
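Conceptually, the driver calls every observer with each trajectory it collects. The pure-Python sketch below shows that observer pattern with a hypothetical `toy_driver` and a fake step function that returns a dict; it is not DynamicStepDriver's implementation, and `stored.append` merely stands in for replay_buffer.add_batch.

```python
def toy_driver(step_fn, observers, num_steps):
  """Hypothetical driver: collect num_steps trajectories, notify observers."""
  for step in range(num_steps):
    trajectory = step_fn(step)
    for observer in observers:
      observer(trajectory)


stored = []
replay_observer = [stored.append]  # Stands in for replay_buffer.add_batch.
toy_driver(lambda step: {'step': step, 'reward': 1.0}, replay_observer,
           num_steps=10)
print(len(stored))  # 10
```

Any other callable, such as a metric that counts steps, can be added to the observer list in the same way.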

Reading data for a train step

After adding trajectory elements to the replay buffer, we can read batches of trajectories from the replay buffer to use as input data for a train step.

Here is an example of how to train on trajectories from the replay buffer in a training loop:

# Read the replay buffer as a Dataset,
# read batches of 4 elements, each with 2 timesteps:
dataset = replay_buffer.as_dataset(
    sample_batch_size=4,
    num_steps=2)

iterator = iter(dataset)

num_train_steps = 10

for _ in range(num_train_steps):
  trajectories, _ = next(iterator)
  loss = agent.train(experience=trajectories)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/ calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
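Each dataset element is a pair whose second entry carries sampling metadata (such as item ids and probabilities), which is why the training loop unpacks `trajectories, _`. The sketch below mimics that consumption pattern without TF-Agents. `toy_dataset` is a hypothetical generator and `train` a dummy function standing in for agent.train; the "loss" it returns is meaningless and only demonstrates the loop.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)


def toy_dataset(sample_batch_size=4, num_steps=2):
  """Hypothetical endless stream of (experience, info) pairs."""
  for sample_id in itertools.count():
    experience = rng.normal(size=(sample_batch_size, num_steps, 4))
    info = {'ids': np.arange(sample_batch_size) + sample_id}
    yield experience, info


def train(experience):
  """Dummy train step returning a fake scalar 'loss'."""
  return float(np.mean(experience ** 2))


iterator = iter(toy_dataset())
losses = []
for _ in range(10):
  trajectories, _ = next(iterator)  # Keep the experience, drop the metadata.
  losses.append(train(experience=trajectories))
print(len(losses))  # 10
```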