Random number generation

TensorFlow provides a set of pseudo-random number generators (RNGs) in the tf.random module. This document describes how you can control the random number generators and how these generators interact with other TensorFlow subsystems.

TensorFlow provides two approaches for controlling the random number generation process:

  1. Through the explicit use of tf.random.Generator objects. Each such object maintains a state (in a tf.Variable) that changes after each number generation.

  2. Through the purely-functional stateless random functions like tf.random.stateless_uniform. Calling these functions with the same arguments (which include the seed) and on the same device will always produce the same results.

Setup

import tensorflow as tf

# Creates 2 virtual devices cpu:0 and cpu:1 for using distribution strategy
physical_devices = tf.config.experimental.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration(
    physical_devices[0], [
        tf.config.experimental.VirtualDeviceConfiguration(),
        tf.config.experimental.VirtualDeviceConfiguration()
    ])

The tf.random.Generator class

The tf.random.Generator class is used in cases where you want each RNG call to produce different results. It maintains an internal state (managed by a tf.Variable object) which will be updated every time random numbers are generated. Because the state is managed by tf.Variable, it enjoys all facilities provided by tf.Variable such as easy checkpointing, automatic control-dependency and thread safety.
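The checkpointing facility mentioned above can be sketched as follows. This is a minimal sketch, assuming that tf.train.Checkpoint tracks the generator's state variable when the generator is attached to it; the attribute name `generator` and the temporary directory are illustrative choices, and from_seed is introduced in the next paragraphs.

```python
import os
import tempfile

import tensorflow as tf

g = tf.random.Generator.from_seed(1)
# Attach the generator to a checkpoint; "generator" is an arbitrary name.
checkpoint = tf.train.Checkpoint(generator=g)

# Advance the generator once, then save its current state.
g.normal([])
prefix = os.path.join(tempfile.mkdtemp(), "rng_checkpoint")
checkpoint.write(prefix)

# Draw numbers after the save point.
after_save = g.normal([])
g.normal([])  # advance the state further

# Restoring rewinds the state to the save point, so the stream repeats.
checkpoint.restore(prefix)
after_restore = g.normal([])
print(after_save.numpy() == after_restore.numpy())
```

Saving the generator alongside model weights in this way lets a resumed job continue the same random stream instead of restarting it.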

You can get a tf.random.Generator by manually creating an object of the class or by calling tf.random.get_global_generator() to get the default global generator:

g1 = tf.random.Generator.from_seed(1)
print(g1.normal(shape=[2, 3]))
g2 = tf.random.get_global_generator()
print(g2.normal(shape=[2, 3]))
tf.Tensor(
[[ 0.43842274 -0.53439844 -0.07710262]
 [ 1.5658046  -0.1012345  -0.2744976 ]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[-0.09818622  0.4978863   0.48663497]
 [ 0.12625825 -2.7745416  -1.709163  ]], shape=(2, 3), dtype=float32)

There are multiple ways to create a generator object. The easiest is Generator.from_seed, as shown above, which creates a generator from a seed. A seed is any non-negative integer. from_seed also takes an optional argument alg, which is the RNG algorithm that will be used by this generator:

g1 = tf.random.Generator.from_seed(1, alg='philox')
print(g1.normal(shape=[2, 3]))
tf.Tensor(
[[ 0.43842274 -0.53439844 -0.07710262]
 [ 1.5658046  -0.1012345  -0.2744976 ]], shape=(2, 3), dtype=float32)

See the Algorithms section below for more information about the available algorithms.

Another way to create a generator is with Generator.from_non_deterministic_state. A generator created this way will start from a non-deterministic state, depending on e.g. time and OS.

g = tf.random.Generator.from_non_deterministic_state()
print(g.normal(shape=[2, 3]))
tf.Tensor(
[[ 1.7364616  -0.3180744   1.4490141 ]
 [-0.90085554  0.03752621 -0.88233536]], shape=(2, 3), dtype=float32)

There are yet other ways to create generators, such as from explicit states, which are not covered by this guide.

When using tf.random.get_global_generator to get the global generator, you need to be careful about device placement. The global generator is created (from a non-deterministic state) the first time tf.random.get_global_generator is called, and it is placed on the default device at that call. So, for example, if the first place you call tf.random.get_global_generator is within a tf.device("gpu") scope, the global generator will be placed on the GPU, and using the global generator later on from the CPU will incur a GPU-to-CPU copy.

There is also a function, tf.random.set_global_generator, for replacing the global generator with another generator object. This function should be used with caution, though, because the old global generator may have been captured by a tf.function (as a weak reference), and replacing it will cause it to be garbage collected, breaking the tf.function. A better way to reset the global generator is to use one of the "reset" functions such as Generator.reset_from_seed, which won't create new generator objects.

g = tf.random.Generator.from_seed(1)
print(g.normal([]))
print(g.normal([]))
g.reset_from_seed(1)
print(g.normal([]))
tf.Tensor(0.43842274, shape=(), dtype=float32)
tf.Tensor(1.6272374, shape=(), dtype=float32)
tf.Tensor(0.43842274, shape=(), dtype=float32)
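The same reset can be applied to the global generator itself, as recommended above in place of tf.random.set_global_generator. This is a minimal sketch; the variable names are illustrative, and note that it changes the state of the process-wide global generator.

```python
import tensorflow as tf

global_gen = tf.random.get_global_generator()

# Reset the global generator in place and draw one number.
global_gen.reset_from_seed(1)
first = global_gen.normal([])

# Resetting again restarts the same stream without creating a new object,
# so any tf.function that captured the generator keeps working.
global_gen.reset_from_seed(1)
second = global_gen.normal([])

same_object = tf.random.get_global_generator() is global_gen
print(same_object, first.numpy() == second.numpy())
```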

Creating independent random-number streams

In many applications one needs multiple independent random-number streams, independent in the sense that they won't overlap and won't have any statistically detectable correlations. This is achieved by using Generator.split to create multiple generators that are guaranteed to be independent of each other (i.e. generating independent streams).

g = tf.random.Generator.from_seed(1)
print(g.normal([]))
new_gs = g.split(3)
for new_g in new_gs:
  print(new_g.normal([]))
print(g.normal([]))
tf.Tensor(0.43842274, shape=(), dtype=float32)
tf.Tensor(2.536413, shape=(), dtype=float32)
tf.Tensor(0.33186463, shape=(), dtype=float32)
tf.Tensor(-0.07144657, shape=(), dtype=float32)
tf.Tensor(-0.79253083, shape=(), dtype=float32)

split will change the state of the generator on which it is called (g in the above example), similar to an RNG method such as normal. In addition to being independent of each other, the new generators (new_gs) are also guaranteed to be independent of the old one (g).

Spawning new generators is also useful when you want to make sure the generator you use is on the same device as other computations, to avoid the overhead of cross-device copy. For example:

with tf.device("cpu"):  # change "cpu" to the device you want
  g = tf.random.get_global_generator().split(1)[0]  
  print(g.normal([]))  # use of g won't cause cross-device copy, unlike the global generator
tf.Tensor(-0.18633148, shape=(), dtype=float32)

You can split recursively, calling split on generators that were themselves produced by split. There are no limits (barring integer overflow) on the recursion depth.
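Recursive splitting might look like the following. This is a minimal sketch with two levels of splitting; the names root, children and grandchildren are illustrative only.

```python
import tensorflow as tf

root = tf.random.Generator.from_seed(1)

# First level: split the root generator into two children.
children = root.split(2)

# Second level: split one of the children again.
grandchildren = children[0].split(2)

# Every generator in the tree draws from its own stream.
samples = [gen.normal([]).numpy() for gen in grandchildren + children]
print(samples)
```

A tree like this is convenient when a parent component hands generators to sub-components, which in turn hand generators to their own workers.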

Interaction with tf.function

tf.random.Generator obeys the same rules as tf.Variable when used with tf.function. This includes three aspects.

Creating generators outside tf.function

tf.function can use a generator created outside of it.

g = tf.random.Generator.from_seed(1)
@tf.function
def foo():
  return g.normal([])
print(foo())
tf.Tensor(0.43842274, shape=(), dtype=float32)

The user needs to make sure that the generator object is still alive (not garbage-collected) when the function is called.

Creating generators inside tf.function

Creation of generators inside a tf.function can only happen during the first run of the function.

g = None
@tf.function
def foo():
  global g
  if g is None:
    g = tf.random.Generator.from_seed(1)
  return g.normal([])
print(foo())
print(foo())
tf.Tensor(0.43842274, shape=(), dtype=float32)
tf.Tensor(1.6272374, shape=(), dtype=float32)

Passing generators as arguments to tf.function

When used as an argument to a tf.function, different generator objects with the same state size (state size is determined by the RNG algorithm) won't cause retracing of the tf.function, while those with different state sizes will.

num_traces = 0
@tf.function
def foo(g):
  global num_traces
  num_traces += 1
  return g.normal([])
foo(tf.random.Generator.from_seed(1))
foo(tf.random.Generator.from_seed(2))
print(num_traces)
1

Interaction with distribution strategies

There are three ways in which Generator interacts with distribution strategies.

Creating generators outside distribution strategies

If a generator is created outside strategy scopes, all replicas’ access to the generator will be serialized, and hence the replicas will get different random numbers.

g = tf.random.Generator.from_seed(1)
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
  def f():
    print(g.normal([]))
  results = strat.run(f)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
tf.Tensor(0.43842274, shape=(), dtype=float32)
tf.Tensor(1.6272374, shape=(), dtype=float32)

Note that this usage may have performance issues because the generator's device is different from the replicas' devices.

Creating generators inside distribution strategies

Creating generators inside strategy scopes is disallowed, because there is ambiguity on how to replicate a generator (e.g. should it be copied so that each replica gets the same random numbers, or 'split' so that each replica gets different random numbers).

strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
  try:
    tf.random.Generator.from_seed(1)
  except ValueError as e:
    print("ValueError:", e)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')
ValueError: Creating a generator within a strategy scope is disallowed, because there is ambiguity on how to replicate a generator (e.g. should it be copied so that each replica gets the same random numbers, or 'split' so that each replica gets different random numbers).

Note that Strategy.run will run its argument function in a strategy scope implicitly:

strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
def f():
  tf.random.Generator.from_seed(1)
try:
  strat.run(f)
except ValueError as e:
  print("ValueError:", e)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
INFO:tensorflow:Error reported to Coordinator: Creating a generator within a strategy scope is disallowed, because there is ambiguity on how to replicate a generator (e.g. should it be copied so that each replica gets the same random numbers, or 'split' so that each replica gets different random numbers).
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 998, in run
    self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 282, in wrapper
    return func(*args, **kwargs)
  File "<ipython-input-14-2cd7806456bd>", line 3, in f
    tf.random.Generator.from_seed(1)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py", line 444, in from_seed
    return cls(state=state, alg=alg)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py", line 386, in __init__
    trainable=False)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/stateful_random_ops.py", line 272, in _create_variable
    "Creating a generator within a strategy scope is disallowed, because "
ValueError: Creating a generator within a strategy scope is disallowed, because there is ambiguity on how to replicate a generator (e.g. should it be copied so that each replica gets the same random numbers, or 'split' so that each replica gets different random numbers).
ValueError: Creating a generator within a strategy scope is disallowed, because there is ambiguity on how to replicate a generator (e.g. should it be copied so that each replica gets the same random numbers, or 'split' so that each replica gets different random numbers).

Passing generators as arguments to Strategy.run

If you want each replica to use its own generator, you need to make n generators (either by copying or splitting), where n is the number of replicas, and then pass them as arguments to Strategy.run.

strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
gs = tf.random.get_global_generator().split(2)
# to_args is a workaround for the absence of APIs to create arguments for 
# run. It will be replaced when such APIs are available.
def to_args(gs):  
  with strat.scope():
    def f():
      return [gs[tf.distribute.get_replica_context().replica_id_in_sync_group]]
    return strat.run(f)
args = to_args(gs)
def f(g):
  print(g.normal([]))
results = strat.run(f, args=args)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
tf.Tensor(1.2711695, shape=(), dtype=float32)
tf.Tensor(-0.57208955, shape=(), dtype=float32)

Stateless RNGs

Usage of stateless RNGs is simple. Since they are just pure functions, there is no state or side effect involved.

print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
tf.Tensor(
[[ 0.5441101   0.20738031  0.07356433]
 [ 0.04643455 -1.3015898  -0.95385665]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[ 0.5441101   0.20738031  0.07356433]
 [ 0.04643455 -1.3015898  -0.95385665]], shape=(2, 3), dtype=float32)

Every stateless RNG requires a seed argument, which needs to be an integer Tensor of shape [2]. The results of the op are fully determined by this seed.
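The role of the seed can be checked directly. This is a minimal sketch using tf.random.stateless_uniform; the specific seed values are arbitrary.

```python
import tensorflow as tf

seed = tf.constant([1, 2], dtype=tf.int32)  # integer Tensor of shape [2]

a = tf.random.stateless_uniform(shape=[3], seed=seed)
b = tf.random.stateless_uniform(shape=[3], seed=seed)
# Changing either element of the seed selects a different result.
c = tf.random.stateless_uniform(shape=[3], seed=[1, 3])

same_seed_matches = bool(tf.reduce_all(a == b))
different_seed_matches = bool(tf.reduce_all(a == c))
print(same_seed_matches, different_seed_matches)
```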

Algorithms

General

Both the tf.random.Generator class and the stateless functions support the Philox algorithm (written as "philox" or tf.random.Algorithm.PHILOX) on all devices.
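The two spellings of the algorithm can be compared as follows. This is a minimal sketch, assuming that from_seed accepts both the string and the tf.random.Algorithm enum value for alg.

```python
import tensorflow as tf

# The same algorithm, named once by string and once by enum value.
g_str = tf.random.Generator.from_seed(1, alg="philox")
g_enum = tf.random.Generator.from_seed(1, alg=tf.random.Algorithm.PHILOX)

x = g_str.normal([3])
y = g_enum.normal([3])

# Same seed and same algorithm, so both generators produce the same numbers.
spellings_agree = bool(tf.reduce_all(x == y))
print(spellings_agree)
```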

Different devices will generate the same integer numbers if they use the same algorithm and start from the same state. They will also generate "almost the same" floating-point numbers, though there may be small numerical discrepancies caused by the different ways the devices carry out the floating-point computation (e.g. reduction order).
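The integer claim can be checked with a short script. This is a sketch under the assumption that comparing the CPU with the first GPU is representative; on a CPU-only machine the loop covers a single device and the check is trivially satisfied. The seed value 7 and the sample shape are arbitrary.

```python
import tensorflow as tf

# Compare the CPU with the first GPU when one is present.
devices = ["/cpu:0"]
if tf.config.list_physical_devices("GPU"):
  devices.append("/gpu:0")

int_samples = []
for device in devices:
  with tf.device(device):
    # Same algorithm and same seed, hence the same starting state.
    g = tf.random.Generator.from_seed(7, alg="philox")
    ints = g.uniform([4], minval=0, maxval=100, dtype=tf.int64)
  int_samples.append(ints.numpy())

# Integer outputs are expected to match exactly on every device.
all_equal = all((s == int_samples[0]).all() for s in int_samples)
print(all_equal)
```

For floating-point samples, a comparison with a small tolerance would be more appropriate than exact equality, for the reasons given above.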

XLA devices

On XLA-driven devices (such as TPU, and also CPU/GPU when XLA is enabled) the ThreeFry algorithm (written as "threefry" or tf.random.Algorithm.THREEFRY) is also supported. This algorithm is fast on TPU but slow on CPU/GPU compared to Philox.

See the paper 'Parallel Random Numbers: As Easy as 1, 2, 3' for more details about these algorithms.