在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
TensorFlow 在 tf.random
模块中提供了一组伪随机数生成器 (RNG)。本文介绍如何控制随机数生成器,以及这些生成器如何与其他 Tensorflow 子系统交互。
注:不保证随机数在不同 TensorFlow 版本间一致。请参阅:版本兼容性
TensorFlow provides two approaches for controlling the random number generation process:
通过明确使用
tf.random.Generator
对象。每个此类对象都会在tf.Variable
中维护一个状态,该状态在每次生成随机数后都会发生改变。通过使用纯函数式无状态随机函数,如
tf.random.stateless_uniform
。在同一设备上调用具有相同参数(包括种子)的这些函数会产生相同的结果。
警告:目前尚未弃用 TF 1.x 中的旧版 RNG(如 tf.random.uniform
和 tf.random.normal
),但强烈建议不要使用。
设置
import tensorflow as tf
# Creates some virtual devices (cpu:0, cpu:1, etc.) for using distribution strategy
physical_devices = tf.config.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration(
physical_devices[0], [
tf.config.experimental.VirtualDeviceConfiguration(),
tf.config.experimental.VirtualDeviceConfiguration(),
tf.config.experimental.VirtualDeviceConfiguration()
])
2022-12-14 22:35:47.032543: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:35:47.032638: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:35:47.032649: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
tf.random.Generator
类
当您希望每次调用 RNG 都产生不同的结果时,可以使用 tf.random.Generator
类。它会维护一个内部状态(由 tf.Variable
对象管理),该状态在每次生成随机数时都会更新。由于该状态由 tf.Variable
管理,因此,它可以利用 tf.Variable
提供的所有功能,如简单的检查点、自动控制依赖项和线程安全性。
通过手动创建 tf.random.Generator
类的一个对象,您可以获得该生成器,或者通过调用 tf.random.get_global_generator()
,您可以获得默认全局生成器:
g1 = tf.random.Generator.from_seed(1)
print(g1.normal(shape=[2, 3]))
g2 = tf.random.get_global_generator()
print(g2.normal(shape=[2, 3]))
tf.Tensor( [[ 0.43842277 -0.53439844 -0.07710262] [ 1.5658046 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32) tf.Tensor( [[-0.69198644 1.0939602 0.46467507] [ 0.72095203 0.6924698 -0.5659851 ]], shape=(2, 3), dtype=float32)
有多种方法可以创建生成器对象。最简单的方法是使用 Generator.from_seed
(代码如上),从种子创建生成器。种子可以是任何非负整数,from_seed
还有一个可选参数 alg
,这是该生成器将使用的 RNG 算法。
g1 = tf.random.Generator.from_seed(1, alg='philox')
print(g1.normal(shape=[2, 3]))
tf.Tensor( [[ 0.43842277 -0.53439844 -0.07710262] [ 1.5658046 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32)
有关详细信息,请参阅后文中的算法部分。
创建生成器的另一种方法是使用 Generator.from_non_deterministic_state
。以这种方式创建的生成器首先会处于非确定状态,具体取决于时间和操作系统等因素。
g = tf.random.Generator.from_non_deterministic_state()
print(g.normal(shape=[2, 3]))
tf.Tensor( [[-0.89688414 -1.27604 0.69840294] [-0.5044483 -0.09191426 -0.11111203]], shape=(2, 3), dtype=float32)
还有其他方法可以创建生成器,比如说通过显式状态创建,本指南不作赘述。
当使用 tf.random.get_global_generator
来获取全局生成器时,需要注意设备放置。第一次调用 tf.random.get_global_generator
时就会创建全局生成器(从非确定状态),并将其放置在该调用的作用域内的默认设备上。举个例子,如果第一次调用 tf.random.get_global_generator
的位置在 tf.device("gpu")
作用域内,则会将全局生成器放置在 GPU 上,如果稍后要从 CPU 使用全局生成器,则会将其从 GPU 复制到 CPU。
There is also a function tf.random.set_global_generator
for replacing the global generator with another generator object. This function should be used with caution though, because the old global generator may have been captured by a tf.function
(as a weak reference), and replacing it will cause it to be garbage collected, breaking the tf.function
. A better way to reset the global generator is to use one of the "reset" functions such as Generator.reset_from_seed
, which won't create new generator objects.
g = tf.random.Generator.from_seed(1)
print(g.normal([]))
print(g.normal([]))
g.reset_from_seed(1)
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(0.43842277, shape=(), dtype=float32)
创建独立的随机数流
许多应用都需要多个独立的随机数流,所谓独立,就是指不能相互重叠,也不能有统计学上可检测到的相关性。通过使用 Generator.split
创建多个一定相互独立的生成器即可实现此目的(即生成独立流)。
g = tf.random.Generator.from_seed(1)
print(g.normal([]))
new_gs = g.split(3)
for new_g in new_gs:
print(new_g.normal([]))
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(2.536413, shape=(), dtype=float32) tf.Tensor(0.33186463, shape=(), dtype=float32) tf.Tensor(-0.07144657, shape=(), dtype=float32) tf.Tensor(-0.79253083, shape=(), dtype=float32)
与 normal
之类的 RNG 方法类似,split
会改变调用它的生成器的状态(上例中为 g
)。除相互之间保持独立外,新生成器 (new_gs
) 还一定独立于旧生成器 (g
)。
当您想要确保使用的生成器位于与其他计算相同的设备上,从而避免跨设备复制的开销时,生成新生成器也很有用。例如:
with tf.device("cpu"): # change "cpu" to the device you want
g = tf.random.get_global_generator().split(1)[0]
print(g.normal([])) # use of g won't cause cross-device copy, unlike the global generator
tf.Tensor(-0.9215919, shape=(), dtype=float32)
注:在理论上,此处可以使用 from_seed
(而不是 split
)之类的构造函数获取新生成器,但这样做无法保证新生成器与全局生成器相互独立。同时也有使用同一种子或导致产生重叠随机数流的种子意外创建两个生成器的风险。
您可以在拆分生成器上调用 split
,执行递归拆分。递归深度没有限制(除非发生整数溢出)。
与 tf.function
交互
与 tf.function
一起使用时,tf.random.Generator
遵循与 tf.Variable
相同的原则。这包括三个方面:
在 tf.function
的外部创建生成器
tf.function
可以使用在其外部创建的生成器。
g = tf.random.Generator.from_seed(1)
@tf.function
def foo():
return g.normal([])
print(foo())
tf.Tensor(0.43842277, shape=(), dtype=float32)
调用该函数时,用户需要确保生成器对象仍处于活动状态(没有被回收)。
在 tf.function
的内部创建生成器
只有 tf.function
第一次运行时,才可以在其内部创建生成器。
g = None
@tf.function
def foo():
global g
if g is None:
g = tf.random.Generator.from_seed(1)
return g.normal([])
print(foo())
print(foo())
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32)
将生成器作为参数传递给 tf.function
当用作 tf.function
的参数时,不同的生成器对象将导致 tf.function
的回溯。
num_traces = 0
@tf.function
def foo(g):
global num_traces
num_traces += 1
return g.normal([])
foo(tf.random.Generator.from_seed(1))
foo(tf.random.Generator.from_seed(2))
print(num_traces)
2
请注意,此回溯行为与 tf.Variable
一致:
num_traces = 0
@tf.function
def foo(v):
global num_traces
num_traces += 1
return v.read_value()
foo(tf.Variable(1))
foo(tf.Variable(2))
print(num_traces)
1
与分布策略交互
Generator
与分布策略有两种交互方式。
在分布策略的外部创建生成器
如果是在策略作用域的外部创建的生成器,则会序列化访问此生成器的所有副本,因此,每一个副本都会得到不同的随机数。
g = tf.random.Generator.from_seed(1)
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
def f():
print(g.normal([]))
results = strat.run(f)
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. tf.Tensor(0.43842274, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32)
请注意,这种使用方法可能产生性能问题,因为生成器的设备与副本不同。
在分布策略的内部创建生成器
如果在策略作用域内创建生成器,则每个副本将获得不同且独立的随机数流。
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
g = tf.random.Generator.from_seed(1)
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) } WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) }
注:目前 tf.random.Generator
没有提供让不同副本获得相同(而非不同)流的选项(这在技术上并不难)。如果您有此功能的用例,请告知 TensorFlow 开发者。
如果生成器已植入种子(例如,由 Generator.from_seed
创建),则随机数由种子决定,即使不同的副本获得不同且不相关的数字也是如此。可以将在副本上生成的随机数看作副本 ID 的散列和对所有副本通用的“主要”随机数。因此,整个系统仍然是确定性的。
还可以在 Strategy.run
内创建 tf.random.Generator
:
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
def f():
g = tf.random.Generator.from_seed(1)
a = g.normal([])
b = g.normal([])
return tf.stack([a, b])
print(strat.run(f))
print(strat.run(f))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32), 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32) } WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32), 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32) }
我们不再建议将 tf.random.Generator
作为参数传递给 Strategy.run
,因为 Strategy.run
通常要求参数是张量,而不是生成器。
保存生成器
通常,为了保存或序列化,您可以按照处理 tf.Variable
或 tf.Module
(或其子类)的方式来处理 tf.random.Generator
。在 TF 中有两种序列化机制:检查点和 SavedModel。
检查点
可以使用 tf.train.Checkpoint
自由保存和恢复生成器。来自恢复点的随机数流将与来自保存点的随机数流相同。
filename = "./checkpoint"
g = tf.random.Generator.from_seed(1)
cp = tf.train.Checkpoint(generator=g)
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32)
cp.write(filename)
print("RNG stream from saving point:")
print(g.normal([]))
print(g.normal([]))
RNG stream from saving point: tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(1.6307176, shape=(), dtype=float32)
cp.restore(filename)
print("RNG stream from restoring point:")
print(g.normal([]))
print(g.normal([]))
RNG stream from restoring point: tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(1.6307176, shape=(), dtype=float32)
您还可以在分发策略中保存和恢复:
filename = "./checkpoint"
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
g = tf.random.Generator.from_seed(1)
cp = tf.train.Checkpoint(my_generator=g)
print(strat.run(lambda: g.normal([])))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) }
with strat.scope():
cp.write(filename)
print("RNG stream from saving point:")
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
with strat.scope():
cp.restore(filename)
print("RNG stream from restoring point:")
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
RNG stream from restoring point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
在保存之前,应确保副本在其 RNG 调用历史记录中不会出现差异(例如,一个副本发出一个 RNG 调用,而另一个副本发出两个 RNG 调用)。否则,它们的内部 RNG 状态将会不同,tf.train.Checkpoint
(仅保存第一个副本的状态)将无法正确恢复所有副本。
您还可以使用不同数量的副本将保存的检查点恢复到不同的分发策略。由于在同一策略中创建的 tf.random.Generator
对象只能在同一策略中使用,因此要恢复到不同的策略,需要在目标策略中新建一个 tf.random.Generator
,并为其创建一个新的 tf.train.Checkpoint
,如下例所示:
filename = "./checkpoint"
strat1 = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat1.scope():
g1 = tf.random.Generator.from_seed(1)
cp1 = tf.train.Checkpoint(my_generator=g1)
print(strat1.run(lambda: g1.normal([])))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) }
with strat1.scope():
cp1.write(filename)
print("RNG stream from saving point:")
print(strat1.run(lambda: g1.normal([])))
print(strat1.run(lambda: g1.normal([])))
RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
strat2 = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1", "cpu:2"])
with strat2.scope():
g2 = tf.random.Generator.from_seed(1)
cp2 = tf.train.Checkpoint(my_generator=g2)
cp2.restore(filename)
print("RNG stream from restoring point:")
print(strat2.run(lambda: g2.normal([])))
print(strat2.run(lambda: g2.normal([])))
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1', '/job:localhost/replica:0/task:0/device:CPU:2') RNG stream from restoring point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32), 2: tf.Tensor(0.6851049, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32), 2: tf.Tensor(-0.58519536, shape=(), dtype=float32) }
虽然 g1
和 cp1
是与 g2
和 cp2
不同的对象,但它们通过公共检查点文件 filename
和对象名称 my_generator
进行了链接。策略之间重叠的副本(例如上面的 cpu:0
和 cpu:1
)将像前面的示例一样正确地恢复其 RNG 流。此保证不包括将生成器保存在策略作用域内并恢复到任何策略作用域之外的情况,反之亦然,因为策略之外的设备会被视为不同于策略中的任何副本。
SavedModel
可以将 tf.random.Generator
保存到 SavedModel。可以在策略作用域内创建生成器。保存也可以在策略作用域内进行。
filename = "./saved_model"
class MyModule(tf.Module):
def __init__(self):
super(MyModule, self).__init__()
self.g = tf.random.Generator.from_seed(0)
@tf.function
def __call__(self):
return self.g.normal([])
@tf.function
def state(self):
return self.g.state
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
m = MyModule()
print(strat.run(m))
print("state:", m.state())
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce. INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-1.4154755, shape=(), dtype=float32), 1: tf.Tensor(-0.11388441, shape=(), dtype=float32) } state: tf.Tensor([256 0 0], shape=(3,), dtype=int64)
with strat.scope():
tf.saved_model.save(m, filename)
print("RNG stream from saving point:")
print(strat.run(m))
print("state:", m.state())
print(strat.run(m))
print("state:", m.state())
INFO:tensorflow:Assets written to: ./saved_model/assets RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-0.68758255, shape=(), dtype=float32), 1: tf.Tensor(0.8084062, shape=(), dtype=float32) } state: tf.Tensor([512 0 0], shape=(3,), dtype=int64) PerReplica:{ 0: tf.Tensor(-0.27342677, shape=(), dtype=float32), 1: tf.Tensor(-0.53093255, shape=(), dtype=float32) } state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)
imported = tf.saved_model.load(filename)
print("RNG stream from loading point:")
print("state:", imported.state())
print(imported())
print("state:", imported.state())
print(imported())
print("state:", imported.state())
RNG stream from loading point: state: tf.Tensor([256 0 0], shape=(3,), dtype=int64) tf.Tensor(-1.0359411, shape=(), dtype=float32) state: tf.Tensor([512 0 0], shape=(3,), dtype=int64) tf.Tensor(-0.06425078, shape=(), dtype=float32) state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)
不建议将包含 tf.random.Generator
的 SavedModel 加载到分发策略中,因为所有副本都将生成相同的随机数流(因为副本 ID 会在 SavedModel 的计算图中冻结)。
和上面的示例一样,将分布式 tf.random.Generator
(在分布策略中创建的生成器)加载到分布策略环境中也有一个注意事项。RNG 状态将被正确恢复,但生成的随机数将与其策略中的原始生成器不同(同样是因为策略之外的设备被视不同于策略中的任何副本)。
无状态 RNG
无状态 RNG 的使用方法非常简单。因为它们是纯函数,不涉及状态或副作用。
print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
tf.Tensor( [[ 0.5441101 0.20738031 0.07356433] [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32) tf.Tensor( [[ 0.5441101 0.20738031 0.07356433] [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32)
每个无状态 RNG 都需要一个 seed
参数,该参数必须是形状为 [2]
的整数张量。该运算的结果完全由种子确定。
无状态 RNG 使用的 RNG 算法依赖于设备,这意味着在不同设备上运行的相同运算可能会产生不同的输出。
算法
基本信息
tf.random.Generator
类和 stateless
函数在所有设备上都支持 Philox 算法(写作 "philox"
或 tf.random.Algorithm.PHILOX
)。
如果使用相同的算法且从相同的状态开始,则不同的设备会生成相同的整数。它们还可以生成“几乎相同”的浮点数,虽然由于设备执行浮点计算的方式不同(如降阶),数值可能存在微小的差异。
XLA 设备
在 XLA 驱动的设备(如 TPU 以及启用 XLA 时的 CPU/GPU)上,还支持 ThreeFry 算法(写作 "threefry"
或 tf.random.Algorithm.THREEFRY
)。与 Philox 算法相比,该算法在 TPU 上执行速度较快,而在 CPU/GPU 上执行速度较慢。
有关这些算法的更多详细信息,请参阅论文“Parallel Random Numbers: As Easy as 1, 2, 3”。