此页面由 Cloud Translation API 翻译。
Switch to English

tf.Module

TensorFlow 1版 GitHub上查看源代码

基础神经网络模块类。

一个模块是一个名为容器tf.Variable S,其他tf.Module S和适用于用户输入的功能。例如,在神经网络的致密层可以实现为一个tf.Module

 class Dense(tf.Module):
   def __init__(self, in_features, out_features, name=None):
     super(Dense, self).__init__(name=name)
     self.w = tf.Variable(
       tf.random.normal([in_features, out_features]), name='w')
     self.b = tf.Variable(tf.zeros([out_features]), name='b')
   def __call__(self, x):
     y = tf.matmul(x, self.w) + self.b
     return tf.nn.relu(y)
  

您可以使用致密层,你会期望:

d = Dense(in_features=3, out_features=2)
d(tf.ones([1, 3]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>

通过继承tf.Module代替object任何tf.Variabletf.Module分配给对象的属性的实例可以使用被收集variablestrainable_variablessubmodules属性:

d.variables
    (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
    dtype=float32)>,
    <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)

的子类tf.Module还可以采取的优点_flatten可用于实现任何其它类型的跟踪方法。

所有tf.Module类具有关联tf.name_scope它可用于在TensorBoard组业务,并创建变量名可以与调试帮助层次。创建嵌套子模块/参数时,我们建议使用的名称范围或正向方法,其图形您可能希望在TensorBoard检查。你可以明确地使用输入名称范围with self.name_scope:或者你可以注释方法(除了__init__用) @tf.Module.with_name_scope

 class MLP(tf.Module):
  def __init__(self, input_size, sizes, name=None):
    super(MLP, self).__init__(name=name)
    self.layers = []
    with self.name_scope:
      for size in sizes:
        self.layers.append(Dense(input_size=input_size, output_size=size))
        input_size = size

  @tf.Module.with_name_scope
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x
 

name 返回通过或在构造函数确定该模块的名称。

name_scope 返回tf.name_scope此类实例。
submodules 所有子模块的顺序。

子模块是模块,其是该模块的属性,或者发现作为此模块(等)的特性的模块性能。

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
list(a.submodules) == [b, c]
True
list(b.submodules) == [c]
True
list(c.submodules) == []
True

trainable_variables 可训练变量序列拥有此模块及其子模块。

variables 通过此模块及其子模块所拥有的变量序列。

方法

with_name_scope

查看源代码

装饰自动输入模块的名字范围。

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

使用上面的模块会产生tf.Variable S和tf.Tensor (胡)的名字包括模块名称:

mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>

ARGS
method 该方法包。

返回
最初的方法缠绕,使得它进入模块的名称范围。