在 TF2 工作流中使用 TF1.x 模型

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

本指南提供了建模代码 shim 的概述和示例,您可以使用这些模型在 TF2 工作流(例如 Eager Execution、tf.function 和分布策略)中使用现有 TF1.x 模型,只需对建模代码进行少量的更改。

使用范围

本指南中介绍的 shim 是为 TF1.x 模型设计的,它依赖于:

  1. tf.compat.v1.get_variabletf.compat.v1.variable_scope 来控制变量的创建和重用,以及
  2. tf.compat.v1.global_variables()tf.compat.v1.trainable_variablestf.compat.v1.losses.get_regularization_losses()tf.compat.v1.get_collection() 等基于计算图集合的 API 来跟踪权重和正则化损失

这包括大多数在 tf.compat.v1.layertf.contrib.layers API 和 TensorFlow-Slim 上构建的模型。

以下 TF1.x 模型需要 shim:

  1. 已经分别通过 model.trainable_weightsmodel.losses 跟踪所有可训练权重和正则化损失的独立 Keras 模型。
  2. 已经通过 module.trainable_variables 跟踪其所有可训练权重,并且仅在尚未创建时才创建权重的 tf.Module

这些模型很可能在 TF2 中使用 Eager Execution 和开箱即用的 tf.function

安装

导入 TensorFlow 和其他依赖项。

pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8

pip install -q tf-nightly
import tensorflow as tf
import tensorflow.compat.v1 as v1
import sys
import numpy as np

from contextlib import contextmanager
2022-12-14 20:38:40.633093: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay

track_tf1_style_variables 装饰器

本指南中介绍的关键 shim 是 tf.compat.v1.keras.utils.track_tf1_style_variables,这是一个装饰器,您可以在属于 tf.keras.layers.Layertf.Module 的方法中利用它来跟踪 TF1.x 样式的权重和捕获正则化损失。

使用 tf.compat.v1.keras.utils.track_tf1_style_variables 装饰 tf.keras.layers.Layertf.Module 的调用方法允许通过 tf.compat.v1.get_variable(以及扩展程序 tf.compat.v1.layers)在装饰方法内部正常工作,而不是总是在每次调用时创建一个新变量。此外,它还将导致层或模块隐式跟踪通过装饰方法内部的 get_variable 创建或访问的任何权重。

除了在标准 layer.variable/module.variable/ 等属性下跟踪权重本身外,如果该方法属于 tf.keras.layers.Layer,则通过 get_variabletf.compat.v1.layers 正则化器参数指定的任何正则化损失都将由标准 layer.losses 属性下的层跟踪。

即使启用了 TF2 行为,这种跟踪机制也允许在 TF2 中的 Keras 层或 tf.Module 内使用大量 TF1.x 样式的模型前向传递代码。

用法示例

下面的用法示例演示了用于装饰 tf.keras.layers.Layer 方法的建模 shim,但除了它们与 Keras 功能特别交互的情况外,它们在装饰 tf.Module 方法时也适用。

使用 tf.compat.v1.get_variable 构建的层

想象一下,您有一个直接在 tf.compat.v1.get_variable 上实现的层,代码如下所示:

def dense(self, inputs, units):
  out = inputs
  with tf.compat.v1.variable_scope("dense"):
    # The weights are created with a `regularizer`,
    kernel = tf.compat.v1.get_variable(
        shape=[out.shape[-1], units],
        regularizer=tf.keras.regularizers.L2(),
        initializer=tf.compat.v1.initializers.glorot_normal,
        name="kernel")
    bias = tf.compat.v1.get_variable(
        shape=[units,],
        initializer=tf.compat.v1.initializers.zeros,
        name="bias")
    out = tf.linalg.matmul(out, kernel)
    out = tf.compat.v1.nn.bias_add(out, bias)
  return out

使用 shim 将其转换成一个层并在输入上调用它。

class DenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

layer = DenseLayer(10)
x = tf.random.normal(shape=(8, 20))
layer(x)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_39827/795621215.py:7: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[-1.5605565 , -0.6923042 , -0.12178189, -0.9827505 ,  0.40099502,
         2.047494  ,  0.10356927,  0.8067764 ,  1.9125116 , -0.1240418 ],
       [-0.5351726 ,  0.1451138 , -2.6096716 ,  2.517344  , -0.39882976,
         1.0957274 ,  3.0046465 ,  1.1085964 ,  2.1165054 , -1.5790572 ],
       [ 1.6604784 , -0.39988056,  3.2205338 , -1.100543  , -0.29846603,
        -0.84717214,  0.16870266, -0.62517273,  1.0294147 , -0.6185304 ],
       [-1.2233133 ,  1.4100406 ,  1.1022731 ,  1.5311328 , -1.2576067 ,
        -0.2567054 ,  0.6769518 , -0.8835863 , -1.2494072 , -0.98312885],
       [ 0.4366729 ,  0.41180784, -0.5664675 ,  1.0509853 ,  1.3133334 ,
        -1.2714801 , -0.1099126 , -0.2339085 , -0.49433595, -2.5287502 ],
       [ 0.72479874,  0.55833936, -0.30833197,  0.89063543, -0.5979869 ,
        -1.6587179 ,  1.9928468 ,  0.16428874,  0.6147038 , -0.08243904],
       [-0.16928086, -0.0798564 , -0.87984407, -0.10892534,  1.1293844 ,
        -0.07888462, -0.15026242, -0.02026159, -0.49940336,  0.72387445],
       [ 0.40993834,  0.06033725, -1.5480766 ,  0.42014968, -1.4231853 ,
        -0.2830826 , -0.21677351, -1.3919728 , -0.37258464,  0.34734738]],
      dtype=float32)>

像标准 Keras 层一样访问跟踪的变量和捕获的正则化损失。

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.14751506>]

为了确保权重在每次调用该层时都得到重用,请将所有权重设置为零,然后再次调用该层。

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

# Note: layer.losses is not a live view and
# will get reset only at each layer call
print("layer.losses:", layer.losses)
print("calling layer again.")
out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['dense/bias:0', 'dense/kernel:0']
layer.losses: [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
calling layer again.
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

您也可以直接在 Keras 函数式模型构造中使用转换后的层。

inputs = tf.keras.Input(shape=(20))
outputs = DenseLayer(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 20))
model(x)

# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.1256916>]

使用 tf.compat.v1.layers 构建的模型

想象一下,您有一个直接在 tf.compat.v1.layers 上实现的层或模型,代码如下所示:

def model(self, inputs, units):
  with tf.compat.v1.variable_scope('model'):
    out = tf.compat.v1.layers.conv2d(
        inputs, 3, 3,
        kernel_regularizer="l2")
    out = tf.compat.v1.layers.flatten(out)
    out = tf.compat.v1.layers.dense(
        out, units,
        kernel_regularizer="l2")
    return out

使用 shim 将其转换成一个层并在输入上调用它。

class CompatV1LayerModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

layer = CompatV1LayerModel(10)
x = tf.random.normal(shape=(8, 5, 5, 5))
layer(x)
/tmpfs/tmp/ipykernel_39827/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_39827/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_39827/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[ 0.8099225 , -0.4621194 ,  1.2551047 , -0.5378463 ,  0.8310043 ,
        -1.2975701 ,  0.48249668, -3.1575112 ,  0.33687246,  0.3826899 ],
       [-1.2930417 ,  0.81548697,  0.9335297 ,  0.6842202 , -1.4681178 ,
        -1.0968316 ,  0.5655617 ,  1.7122681 ,  1.8381269 , -0.63145494],
       [-0.52436864,  2.779965  , -1.4695735 , -0.44274795, -1.3342011 ,
        -0.3335649 , -1.0413288 ,  1.0533664 , -0.27863216,  0.6775204 ],
       [ 0.5044536 ,  2.9041598 , -0.9781048 , -0.94305193, -1.2111641 ,
        -1.1327515 , -0.10927999, -0.22824892,  0.13607529,  0.89048004],
       [-0.10465914, -0.3905803 ,  0.049198  ,  1.8102224 ,  1.1085366 ,
         1.0383061 , -0.02206957,  0.85261726, -1.4558244 ,  2.2443192 ],
       [-0.9532844 ,  0.35003638, -2.4970584 ,  2.1875196 , -1.4162731 ,
        -2.1784813 ,  1.4075698 , -0.08591622,  0.27371824, -2.1132326 ],
       [ 0.6021228 , -0.07931077,  0.32574403, -0.7015223 , -0.7512918 ,
        -0.8090282 , -0.7811756 ,  0.30745095,  1.2039661 ,  1.4830457 ],
       [-0.28259408, -0.31954885, -0.31986874,  1.2672653 ,  1.2030227 ,
        -0.14623904, -0.80953586, -1.9146562 , -1.2992917 ,  2.1242828 ]],
      dtype=float32)>

警告:出于安全原因,请确保将所有 tf.compat.v1.layers 都置于非空字符串 variable_scope 内。这是因为具有自动生成名称的 tf.compat.v1.layers 将始终在任何变量范围之外使名称自动递增。这意味着每次调用层/模块时请求的变量名称都会不匹配。因此,它不会重用已经创建的权重,而是会在每次调用时创建一组新的变量。

像标准 Keras 层一样访问跟踪的变量和捕获的正则化损失。

layer.trainable_variables
layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.040086407>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.14862967>]

为了确保权重在每次调用该层时都得到重用,请将所有权重设置为零,然后再次调用该层。

print("Resetting variables to zero:", [var.name for var in layer.trainable_variables])

for var in layer.trainable_variables:
  var.assign(var * 0.0)

out = layer(x)
print("layer.losses: ", layer.losses)
out
Resetting variables to zero: ['model/conv2d/bias:0', 'model/conv2d/kernel:0', 'model/dense/bias:0', 'model/dense/kernel:0']
layer.losses:  [<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>]
/tmpfs/tmp/ipykernel_39827/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_39827/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_39827/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

您也可以直接在 Keras 函数式模型构造中使用转换后的层。

inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1LayerModel(10)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

x = tf.random.normal(shape=(8, 5, 5, 5))
model(x)
/tmpfs/tmp/ipykernel_39827/2388460905.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
/tmpfs/tmp/ipykernel_39827/2388460905.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_39827/2388460905.py:14: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
<tf.Tensor: shape=(8, 10), dtype=float32, numpy=
array([[-0.5349148 ,  1.3527144 , -1.5674714 ,  0.3739456 ,  0.70822996,
        -2.57449   , -0.91658115,  0.01930225, -0.79617566,  3.2843132 ],
       [ 0.4140386 , -0.66665673,  2.861754  ,  1.9320178 , -1.3856443 ,
         2.9440308 ,  1.3543379 , -0.25877574, -1.0080137 , -4.304592  ],
       [ 0.03068304, -1.148419  , -1.584249  , -0.24597901, -0.1047827 ,
         0.7593119 ,  1.6566473 ,  1.9456444 ,  0.71495557,  0.5902455 ],
       [-1.2089777 ,  0.49164614,  0.6760783 ,  1.4208341 ,  0.20729658,
         0.08904815, -0.3644415 , -0.4699819 ,  0.5669506 ,  0.08861631],
       [ 0.25653136, -0.651753  ,  1.045647  , -0.5686461 ,  0.47781086,
         2.6567998 , -1.5720563 , -0.9050565 , -0.31101042, -1.1717727 ],
       [-1.2310287 , -0.67586654, -1.889077  ,  1.3088472 , -0.9956445 ,
        -0.45156097, -1.9318616 ,  1.1491737 ,  0.16780543,  1.8678865 ],
       [ 0.00938642,  0.40236259, -0.41107404, -1.899723  , -0.9135215 ,
        -2.081637  , -0.71667486,  0.34614545,  0.25766808, -0.3802871 ],
       [-0.5589206 ,  0.6831752 , -0.37517905, -1.0286998 ,  0.06483352,
         3.6955824 ,  1.5183566 , -0.71009344,  2.2557702 , -4.1322536 ]],
      dtype=float32)>
# Access the model variables and regularization losses
model.weights
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.04052762>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.15711266>]

捕获批次归一化更新和模型 training 参数

在 TF1.x 中,您可以按如下方式执行批次归一化:

  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)
  train_op = optimizer.minimize(loss)
  train_op = tf.group([train_op, update_ops])

注意:

  1. 批次归一化移动平均值更新由与层分开调用的 get_collection 跟踪
  2. tf.compat.v1.layers.batch_normalization 需要一个 training 参数(使用 TF-Slim 批次归一化层时一般称为 is_training

在 TF2 中,由于 Eager Execution 和自动控制依赖项,批次归一化移动平均值更新将立即执行。无需从更新集合中单独收集它们并将它们添加为显式控制依赖项。

此外,如果您为 tf.keras.layers.Layer 的前向传递方法提供一个 training 参数,Keras 能够将当前训练阶段和任何嵌套层传递给它,就像它对任何其他层所做的那样。有关 Keras 如何处理 training 参数的更多信息,请参阅 tf.keras.Model 的 API 文档。

如果您正在装饰 tf.Module 方法,则需要确保根据需要手动传递所有 training 参数。但是,批次归一化移动平均值更新仍将自动应用,无需显式控制依赖项。

以下代码段演示了如何在 shim 中嵌入批次归一化层以及如何在 Keras 模型中使用它(适用于 tf.keras.layers.Layer)。

class CompatV1BatchNorm(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    print("Forward pass called with `training` =", training)
    with v1.variable_scope('batch_norm_layer'):
      return v1.layers.batch_normalization(x, training=training)
print("Constructing model")
inputs = tf.keras.Input(shape=(5, 5, 5))
outputs = CompatV1BatchNorm()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

print("Calling model in inference mode")
x = tf.random.normal(shape=(8, 5, 5, 5))
model(x, training=False)

print("Moving average variables before training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})

# Notice that when running TF2 and eager execution, the batchnorm layer directly
# updates the moving averages while training without needing any extra control
# dependencies
print("calling model in training mode")
model(x, training=True)

print("Moving average variables after training: ",
      {var.name: var.read_value() for var in model.non_trainable_variables})
Constructing model
Forward pass called with `training` = None
Calling model in inference mode
/tmpfs/tmp/ipykernel_39827/3053504896.py:7: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  return v1.layers.batch_normalization(x, training=training)
Forward pass called with `training` = False
Moving average variables before training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([0., 0., 0., 0., 0.], dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([1., 1., 1., 1., 1.], dtype=float32)>}
calling model in training mode
Forward pass called with `training` = True
Moving average variables after training:  {'batch_norm_layer/batch_normalization/moving_mean:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([-0.00018271, -0.00025929, -0.00081811, -0.00030542,  0.00078999],
      dtype=float32)>, 'batch_norm_layer/batch_normalization/moving_variance:0': <tf.Tensor: shape=(5,), dtype=float32, numpy=
array([1.0005931 , 1.001202  , 1.0005841 , 0.99828464, 1.0006655 ],
      dtype=float32)>}

基于变量范围的变量重用

在基于 get_variable 的前向传递中创建的任何变量都将保持与 TF1.x 中变量作用域相同的变量命名和重用语义。只要任何具有自动生成名称的 tf.compat.v1.layers 至少有一个非空的外部范围,情况就如上面所述。

注:命名和重用的范围将限定在单个层/模块实例内。在一个 shim 装饰的层或模块内调用 get_variable 将无法引用在层或模块内创建的变量。如果需要,您可以通过直接使用 Python 对其他变量的引用来解决此问题,而不是通过 get_variable 访问变量。

Eager Execution 和 tf.function

如上所示,tf.keras.layers.Layertf.Module 的装饰方法在 Eager Execution 内部运行,并且也与 tf.function 兼容。这意味着您可以使用 pdb 和其他交互式工具在前向传递运行时单步执行。

警告:尽管从 tf.function 内部调用 shim 装饰的层/模块方法是完全安全的,但如果这些 tf.functions 包含 get_variable 调用,则将 tf.function 置于 shim 装饰的方法中是不安全的。进入 tf.function 会重置 variable_scope,这意味着 shim 模仿的 TF1.x 样式基于变量范围的变量重用将在此设置中失效。

分布策略

@track_tf1_style_variables 装饰的层或模块方法中调用 get_variable 会在底层使用标准 tf.Variable 变量创建。这意味着您可以将它们与 MirroredStrategyTPUStrategytf.distribute 提供的各种分发策略一起使用。

在装饰调用中嵌套 tf.Variabletf.Moduletf.keras.layerstf.keras.models

tf.compat.v1.keras.utils.track_tf1_style_variables 中装饰您的层调用只会添加对通过 tf.compat.v1.get_variable 创建(和重用)的变量的自动隐式跟踪。它不会捕获由 tf.Variable 调用直接创建的权重,例如典型的 Keras 层和大多数 tf.Module 使用的权重。本部分介绍如何处理这些嵌套情况。

(预先存在的用法)tf.keras.layerstf.keras.models

对于嵌套 Keras 层和模型的预先存在的用法,请使用 tf.compat.v1.keras.utils.get_or_create_layer。这仅建议用于简化现有 TF1.x 嵌套 Keras 用法的迁移;新代码应当使用如下所述的 tf.Variables 和 tf.Modules 的显式特性设置。

要使用 tf.compat.v1.keras.utils.get_or_create_layer,请将构造嵌套模型的代码封装到一个方法内,并将其传递给该方法。示例如下:

class NestedModel(tf.keras.Model):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Get or create a nested model without assigning it as an explicit property
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

layer = NestedModel(10)
layer(tf.ones(shape=(5,5)))
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]], dtype=float32)>

这种方法可确保这些嵌套层被 TensorFlow 正确重用和跟踪。请注意,在适当的方法上仍然需要 @track_tf1_style_variables 装饰器。传递给 get_or_create_layer 的模型构建器方法(在本例中为 self.build_model)不应带参数。

跟踪权重:

assert len(layer.weights) == 2
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"dense/bias:0", "dense/kernel:0"}

layer.weights
[<tf.Variable 'dense/kernel:0' shape=(5, 10) dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>,
 <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]

以及正则化损失:

tf.add_n(layer.losses)
<tf.Tensor: shape=(), dtype=float32, numpy=0.5>

增量迁移:tf.Variablestf.Modules

如果您需要在修饰方法中嵌入 tf.Variable 调用或 tf.Module(例如,如果您遵循本指南后面介绍的向非传统 TF2 API 的增量迁移),您仍然需要根据下面的要求显式跟踪它们:

  • 显式确保变量/模块/层只创建一次
  • 就像定义典型模块或层时一样,将它们显式附加为实例特性
  • 在后续调用中显式重用已创建的对象

这确保了每次调用都不会创建新的权重并且可以正确地重用权重。此外,这还可以确保跟踪现有的权重和正则化损失。

下面是一个展现外观的示例:

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    out = inputs
    with tf.compat.v1.variable_scope("inner_dense"):
      # The weights are created with a `regularizer`,
      # so the layer should track their regularization losses
      kernel = tf.compat.v1.get_variable(
          shape=[out.shape[-1], self.units],
          regularizer=tf.keras.regularizers.L2(),
          initializer=tf.compat.v1.initializers.glorot_normal,
          name="kernel")
      bias = tf.compat.v1.get_variable(
          shape=[self.units,],
          initializer=tf.compat.v1.initializers.zeros,
          name="bias")
      out = tf.linalg.matmul(out, kernel)
      out = tf.compat.v1.nn.bias_add(out, bias)
    return out

class WrappedDenseLayer(tf.keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    # Only create the nested tf.variable/module/layer/model
    # once, and then reuse it each time!
    self._dense_layer = NestedLayer(self.units)

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('outer'):
      outputs = tf.compat.v1.layers.dense(inputs, 3)
      outputs = tf.compat.v1.layers.dense(inputs, 4)
      return self._dense_layer(outputs)

layer = WrappedDenseLayer(10)

layer(tf.ones(shape=(5, 5)))
/tmpfs/tmp/ipykernel_39827/2765428776.py:38: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 3)
/tmpfs/tmp/ipykernel_39827/2765428776.py:39: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  outputs = tf.compat.v1.layers.dense(inputs, 4)
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[-0.2482834 ,  0.3641441 , -0.9629273 , -0.75598675,  1.2324747 ,
         0.588566  ,  0.09184126, -1.1626451 ,  0.29253644,  0.2593758 ],
       [-0.2482834 ,  0.3641441 , -0.9629273 , -0.75598675,  1.2324747 ,
         0.588566  ,  0.09184126, -1.1626451 ,  0.29253644,  0.2593758 ],
       [-0.2482834 ,  0.3641441 , -0.9629273 , -0.75598675,  1.2324747 ,
         0.588566  ,  0.09184126, -1.1626451 ,  0.29253644,  0.2593758 ],
       [-0.2482834 ,  0.3641441 , -0.9629273 , -0.75598675,  1.2324747 ,
         0.588566  ,  0.09184126, -1.1626451 ,  0.29253644,  0.2593758 ],
       [-0.2482834 ,  0.3641441 , -0.9629273 , -0.75598675,  1.2324747 ,
         0.588566  ,  0.09184126, -1.1626451 ,  0.29253644,  0.2593758 ]],
      dtype=float32)>

请注意,即使使用 track_tf1_style_variables 装饰器装饰嵌套模块,也需要显式跟踪它。这是因为带有修饰方法的每个模块/层都有自己的变量存储与之关联。

正确跟踪权重:

assert len(layer.weights) == 6
weights = {x.name: x for x in layer.variables}

assert set(weights.keys()) == {"outer/inner_dense/bias:0",
                               "outer/inner_dense/kernel:0",
                               "outer/dense/bias:0",
                               "outer/dense/kernel:0",
                               "outer/dense_1/bias:0",
                               "outer/dense_1/kernel:0"}

layer.trainable_weights
[<tf.Variable 'outer/inner_dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/inner_dense/kernel:0' shape=(4, 10) dtype=float32, numpy=
 array([[ 0.09730561,  0.17783208,  0.11119158,  0.45986557, -0.1447215 ,
          0.16842431,  0.28586477, -0.20580281, -0.27836105,  0.17098187],
        [-0.4046081 , -0.2816257 , -0.29024273, -0.16747962, -0.10818168,
          0.17109518,  0.253033  , -0.24617101,  0.13494436,  0.36550042],
        [ 0.6310229 ,  0.33682218, -0.6431416 , -0.11476022,  0.5231503 ,
          0.08794093, -0.81403106,  0.2009578 , -0.64648336, -0.08593968],
        [ 0.6341598 ,  0.1530097 ,  0.09208723,  0.4000943 , -0.2221118 ,
         -0.20686378, -0.5717175 ,  0.6668123 , -0.64983076, -0.21107097]],
       dtype=float32)>,
 <tf.Variable 'outer/dense/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense/kernel:0' shape=(5, 3) dtype=float32, numpy=
 array([[ 0.8620253 ,  0.46871334, -0.70384526],
        [ 0.8651132 ,  0.4123873 , -0.81018174],
        [-0.5762091 , -0.41468212,  0.6904549 ],
        [ 0.01172829, -0.34276652, -0.01905262],
        [ 0.8372217 , -0.0877108 ,  0.7009117 ]], dtype=float32)>,
 <tf.Variable 'outer/dense_1/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'outer/dense_1/kernel:0' shape=(5, 4) dtype=float32, numpy=
 array([[ 0.03267503, -0.5882085 ,  0.41363847, -0.50928485],
        [ 0.17197746,  0.20136333, -0.623744  , -0.3406625 ],
        [-0.05342215, -0.602049  ,  0.5387504 , -0.0299927 ],
        [ 0.57918274,  0.5877584 ,  0.7761904 , -0.5445967 ],
        [-0.2553717 , -0.0379287 ,  0.3576846 , -0.77528757]],
       dtype=float32)>]

以及正则化损失:

layer.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.054905508>]

请注意,如果 NestedLayer 是非 Keras tf.Module,仍会跟踪变量,但不会自动跟踪正则化损失,因此您必须单独显式跟踪它们。

变量名称指南

显式 tf.Variable 调用和 Keras 层使用不同于 get_variablevariable_scopes 组合中的层名/变量名自动生成机制。尽管即使从 TF1.x 计算图转到 TF2 Eager Execution 和 tf.function,shim 也会使您的变量名称与 get_variable 创建的变量匹配,但它无法保证为 tf.Variable 调用和 Keras 层生成的变量名称与您嵌入到方法装饰器中的变量名称相同。多个变量甚至可以在 TF2 Eager Execution 和 tf.function 中共享相同的名称。

在本指南后面有关验证正确性和映射 TF1.x 检查点的部分中,您应该特别注意这一点。

在装饰方法中使用 tf.compat.v1.make_template

强烈建议您直接使用 tf.compat.v1.keras.utils.track_tf1_style_variables 而不是使用 tf.compat.v1.make_template,因为它是 TF2 上一个较薄的层

请按照本部分中的指南获取已经依赖于 tf.compat.v1.make_template 的先前 TF1.x 代码。

由于 tf.compat.v1.make_template 封装使用 get_variable 的代码,track_tf1_style_variables 装饰器允许您在层调用中使用这些模板并成功跟踪权重和正则化损失。

但是,请确保只调用一次 make_template,然后在每个层调用中重用相同的模板。否则,每次调用层时都会创建一个新模板以及一组新变量。

例如:

class CompatV1TemplateScaleByY(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    def my_op(x, scalar_name):
      var1 = tf.compat.v1.get_variable(scalar_name,
                            shape=[],
                            regularizer=tf.compat.v1.keras.regularizers.L2(),
                            initializer=tf.compat.v1.constant_initializer(1.5))
      return x * var1
    self.scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('layer'):
      # Using a scope ensures the `scale_by_y` name will not be incremented
      # for each instantiation of the layer.
      return self.scale_by_y(inputs)

layer = CompatV1TemplateScaleByY()

out = layer(tf.ones(shape=(2, 3)))
print("weights:", layer.weights)
print("regularization loss:", layer.losses)
print("output:", out)
weights: [<tf.Variable 'layer/scale_by_y/y:0' shape=() dtype=float32, numpy=1.5>]
regularization loss: [<tf.Tensor: shape=(), dtype=float32, numpy=0.022499999>]
output: tf.Tensor(
[[1.5 1.5 1.5]
 [1.5 1.5 1.5]], shape=(2, 3), dtype=float32)

警告:避免在多个层实例之间共享 make_template 创建的相同模板,因为它可能会破坏 shim 装饰器的变量和正则化损失跟踪机制。此外,如果您计划在多个层实例中使用相同的 make_template 名称,那么您应当将所创建模板的用法嵌套在 variable_scope 内。如果不这么做,则为模板的 variable_scope 生成的名称将随着层的每个新实例而递增。这可能会以意想不到的方式改变权重名称。

到原生 TF2 的增量迁移

如前文所述,track_tf1_style_variables 允许您将 TF2 样式的面向对象的 tf.Variable/tf.keras.layers.Layer/tf.Module 用法与传统的 tf.compat.v1.get_variable/tf.compat.v1.layers 样式用法混合在同一个装饰模块/层内。

这意味着在使 TF1.x 模型与 TF2 完全兼容后,您可以使用原生(非 tf.compat.v1)TF2 API 编写所有新模型组件,并让它们与旧代码互操作。

但是,如果您继续修改旧模型组件,您也可以选择将传统样式的 tf.compat.v1 用法逐步切换到推荐用于新编写的 TF2 代码的纯原生面向对象 API。

tf.compat.v1.get_variable 用法可以替换为 self.add_weight 调用(如果您正在装饰 Keras 层/模型),或者 tf.Variable 调用(如果您正在装饰 Keras 对象或 tf.Module)。

函数式和面向对象的 tf.compat.v1.layers 通常都可以替换为等效的 tf.keras.layers 层,无需更改任何参数。

在逐步迁移到本身可能使用 track_tf1_style_variables 的纯原生 API 的过程中,您还可以考虑将模型的一部分或常见模式拆分单独的层/模块。

关于 Slim 和 contrib.layers 的注意事项

大量早期 TF 1.x 代码使用 Slim 库,该库与 TF 1.x 一起打包为 tf.contrib.layers。使用 Slim 将代码转换为原生 TF 2 比转换 v1.layers 更复杂。事实上,先将您的 Slim 代码转换为 v1.layers,然后再转换为 Keras 可能是有意义的。下面是一些转换 Slim 代码的一般指南。

  • 确保所有参数都是显式的。如果可能,移除 arg_scopes。如果您仍然需要使用它们,请将 normalizer_fnactivation_fn 拆分为它们自己的层。
  • 可分离的卷积层映射至一个或多个不同的 Keras 层(深度、逐点和可分离 Keras 层)
  • Slim 与 v1.layers 有不同的参数名和默认值。
  • 请注意,某些参数具有不同的比例。

在忽略检查点兼容性的情况下迁移到原生 TF2

以下代码示例演示了在不考虑检查点兼容性的情况下将模型逐步迁移到纯原生 API。

class CompatModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

接下来,以分段方式将 compat.v1 API 替换为其原生的面向对象的对应项。首先,将卷积层切换为在层构造函数中创建的 Keras 对象。

class PartiallyMigratedModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = tf.compat.v1.layers.flatten(out)
      out = tf.compat.v1.layers.dropout(out, training=training)
      out = tf.compat.v1.layers.dense(
          out, self.units,
          kernel_regularizer="l2")
      return out

使用 v1.keras.utils.DeterministicRandomTestTool 类来验证这种增量更改是否使模型具有与以前相同的行为。

random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = CompatModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  original_output = layer(inputs)

  # Grab the regularization loss as well
  original_regularization_loss = tf.math.add_n(layer.losses)

print(original_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_39827/355611412.py:10: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_39827/355611412.py:13: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_39827/355611412.py:14: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_39827/355611412.py:15: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = PartiallyMigratedModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
/tmpfs/tmp/ipykernel_39827/3237389364.py:14: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  out = tf.compat.v1.layers.flatten(out)
/tmpfs/tmp/ipykernel_39827/3237389364.py:15: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  out = tf.compat.v1.layers.dropout(out, training=training)
/tmpfs/tmp/ipykernel_39827/3237389364.py:16: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  out = tf.compat.v1.layers.dense(
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

您现在已经用原生 Keras 层替换了所有单独的 compat.v1.layers

class NearlyFullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_layer(inputs)
      out = self.flatten_layer(out)
      out = self.dense_layer(out)
      return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = NearlyFullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

最后,移除任何其余的(不再需要的)variable_scope 用法和 track_tf1_style_variables 装饰器本身。

您现在得到了一个完全使用原生 API 的模型版本。

class FullyNativeModel(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units
    self.conv_layer = tf.keras.layers.Conv2D(
      3, 3,
      kernel_regularizer="l2")
    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_layer = tf.keras.layers.Dense(
      self.units,
      kernel_regularizer="l2")

  def call(self, inputs):
    out = self.conv_layer(inputs)
    out = self.flatten_layer(out)
    out = self.dense_layer(out)
    return out
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  tf.keras.utils.set_random_seed(42)
  layer = FullyNativeModel(10)

  inputs = tf.random.normal(shape=(10, 5, 5, 5))
  migrated_output = layer(inputs)

  # Grab the regularization loss as well
  migrated_regularization_loss = tf.math.add_n(layer.losses)

print(migrated_regularization_loss)
tf.Tensor(0.1824967, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
np.testing.assert_allclose(original_regularization_loss.numpy(), migrated_regularization_loss.numpy())
np.testing.assert_allclose(original_output.numpy(), migrated_output.numpy())

在迁移到原生 TF2 期间保持检查点兼容性

上述向原生 TF2 API 的迁移过程更改了变量名称(因为 Keras API 产生了极为不同的权重名称),以及指向模型中不同权重的面向对象路径。这些更改的影响是它们将破坏任何现有 TF1 样式的基于名称的检查点或 TF2 样式的面向对象的检查点。

但是,在某些情况下,您可以使用原始的基于名称的检查点,并使用重用 TF1.x 检查点指南中详述的方式找到变量与其新名称的映射。

使这种方法可行的一些技巧如下:

  • 变量仍然具有可以设置的 name 参数。
  • Keras 模型还采用 name 参数,并将其设置为变量的前缀。
  • v1.name_scope 函数可用于设置变量名前缀,这与 tf.variable_scope 截然不同。它只影响名称,而不跟踪变量和重用。

考虑到以上几点,以下示例代码演示了一个工作流,您可以调整您的代码以增量更新模型的一部分,同时更新检查点。

注:由于使用 Keras 层命名变量的复杂性,这不能保证适用于所有用例。

  1. 首先,将函数式 tf.compat.v1.layers 切换到面向对象的版本。
class FunctionalStyleCompatModel(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = tf.compat.v1.layers.conv2d(
          inputs, 3, 3,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 4, 4,
          kernel_regularizer="l2")
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = FunctionalStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_39827/1716504801.py:6: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_39827/1716504801.py:9: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
/tmpfs/tmp/ipykernel_39827/1716504801.py:12: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/bias:0',
 'model/conv2d/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. 接下来,将 compat.v1.layer 对象和由 compat.v1.get_variable 创建的任何变量分配为 tf.keras.layers.Layer/tf.Module 对象的属性,其方法用 track_tf1_style_variables 装饰(注意,任何面向对象的 TF2 样式检查点现在会同时保存按变量名的路径和新的面向对象的路径)。
class OOStyleCompatModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.compat.v1.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.compat.v1.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = OOStyleCompatModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_39827/1693875107.py:17: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['model/conv2d/kernel:0',
 'model/conv2d/bias:0',
 'model/conv2d_1/kernel:0',
 'model/conv2d_1/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']
  1. 此时重新保存加载的检查点,以按照变量名称(对于 compat.v1.layers)或面向对象的对象计算图保存路径。
weights = {v.name: v for v in layer.weights}
assert weights['model/conv2d/kernel:0'] is layer.conv_1.kernel
assert weights['model/conv2d_1/bias:0'] is layer.conv_2.bias
  1. 您现在可以将面向对象的 compat.v1.layers 替换为原生 Keras 层,同时仍然能够加载最近保存的检查点。通过继续记录被替换层的自动生成的 variable_scopes,确保为其余的 compat.v1.layers 保留变量名称。这些切换的层/变量现在将仅使用检查点中变量的对象特性路径,而不是变量名称路径。

通常,您可以通过以下方式替换附加到属性的变量中的 compat.v1.get_variable 用法:

  • 将它们切换为使用 tf.Variable或者
  • 使用 tf.keras.layers.Layer.add_weight 更新它们。请注意,如果您没有一次性切换所有层,这可能会更改缺少 name 参数的其余 compat.v1.layers 的自动生成层/变量命名。如果是这种情况,您必须通过手动打开和关闭与已移除的 compat.v1.layer 生成的作用域名称相对应的 variable_scope 来保持其余 compat.v1.layers 的变量名称相同。否则,来自现有检查点的路径可能会发生冲突,并且检查点加载的行为不正确。
def record_scope(scope_name):
  """Record a variable_scope to make sure future ones get incremented."""
  with tf.compat.v1.variable_scope(scope_name):
    pass

class PartiallyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      record_scope('conv2d') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = self.conv_2(out)
      record_scope('conv2d_1') # Only needed if follow-on compat.v1.layers do not pass a `name` arg
      out = tf.compat.v1.layers.conv2d(
          out, 5, 5,
          kernel_regularizer="l2")
      return out

layer = PartiallyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
/tmpfs/tmp/ipykernel_39827/3143218429.py:24: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  out = tf.compat.v1.layers.conv2d(
['partially_native_keras_layers_model/model/conv2d_13/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_13/bias:0',
 'partially_native_keras_layers_model/model/conv2d_14/kernel:0',
 'partially_native_keras_layers_model/model/conv2d_14/bias:0',
 'model/conv2d_2/bias:0',
 'model/conv2d_2/kernel:0']

在构造变量后,在这一步保存检查点将使其包含当前可用的对象路径。

确保记录已移除的 compat.v1.layers 的作用域,以便为其余的 compat.v1.layers 保留自动生成的权重名称。

weights = set(v.name for v in layer.weights)
assert 'model/conv2d_2/kernel:0' in weights
assert 'model/conv2d_2/bias:0' in weights
  1. 重复上述步骤,直到您将模型中的所有 compat.v1.layerscompat.v1.get_variable 替换为完全原生的对应项。
class FullyNativeKerasLayersModel(tf.keras.layers.Layer):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv_1 = tf.keras.layers.Conv2D(
          3, 3,
          kernel_regularizer="l2")
    self.conv_2 = tf.keras.layers.Conv2D(
          4, 4,
          kernel_regularizer="l2")
    self.conv_3 = tf.keras.layers.Conv2D(
          5, 5,
          kernel_regularizer="l2")


  def call(self, inputs, training=None):
    with tf.compat.v1.variable_scope('model'):
      out = self.conv_1(inputs)
      out = self.conv_2(out)
      out = self.conv_3(out)
      return out

layer = FullyNativeKerasLayersModel()
layer(tf.ones(shape=(10, 10, 10, 10)))
[v.name for v in layer.weights]
['fully_native_keras_layers_model/model/conv2d_16/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_16/bias:0',
 'fully_native_keras_layers_model/model/conv2d_17/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_17/bias:0',
 'fully_native_keras_layers_model/model/conv2d_18/kernel:0',
 'fully_native_keras_layers_model/model/conv2d_18/bias:0']

请记得进行测试以确保新更新的检查点的行为仍然符合预期。在此过程的每个增量步骤中应用验证数字正确性指南中介绍的技术,以确保您的迁移代码正确运行。

处理建模 shim 未涵盖的 TF1.x 到 TF2 行为更改

本指南中介绍的建模 slim 可以确保使用 get_variabletf.compat.v1.layersvariable_scope 语义创建的变量、层和正则化损失在使用 Eager Execution 和 tf.function 时继续像以前一样有效,无需依赖集合。

本文未涵盖您的模型前向传递可能依赖的所有 TF1.x 特定语义。在某些情况下,shim 可能不足以让您的模型前向传递在 TF2 中自行运行。阅读 TF1.x 与 TF2 行为指南,详细了解 TF1.x 与 TF2 之间的行为差异。