TensorFlow Addons Layers: WeightNormalization


Overview

This notebook demonstrates how to use the weight normalization layer and how it can improve convergence.

WeightNormalization

A simple reparameterization to accelerate training of deep neural networks:

Tim Salimans, Diederik P. Kingma (2016)

By reparameterizing the weights in this way we improve the conditioning of the optimization problem and we speed up convergence of stochastic gradient descent. Our reparameterization is inspired by batch normalization but does not introduce any dependencies between the examples in a minibatch. This means that our method can also be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although our method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, the computational overhead of our method is lower, permitting more optimization steps to be taken in the same amount of time.

https://arxiv.org/abs/1602.07868
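Concretely, the paper decouples each weight vector w into a direction v and a learned scalar magnitude g:

$$ \mathbf{w} = \frac{g}{\lVert \mathbf{v} \rVert}\,\mathbf{v} $$

As a quick illustrative sketch (not part of the original notebook; the values of v and g below are made up), the same computation in NumPy shows that the norm of the effective weight is exactly g:

import numpy as np

v = np.array([3.0, 4.0])       # direction parameter (hypothetical values)
g = 2.0                        # learned magnitude (hypothetical value)
w = g * v / np.linalg.norm(v)  # reparameterized weight: w = g * v / ||v||
print(w)                       # [1.2 1.6]
print(np.linalg.norm(w))       # 2.0 -- the norm of w equals g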



Setup

pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt
# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10

Build Models

# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
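Note that tfa.layers.WeightNormalization also accepts a data_init argument (default True), which enables the data-dependent initialization scheme described in the paper. The snippet below is a minimal sketch, not part of the original notebook, showing how to turn it off for a single wrapped layer:

# Minimal sketch: wrap one layer without data-dependent initialization
# (data_init=True is the tfa default).
dense = tf.keras.layers.Dense(120, activation='relu')
wn_dense = tfa.layers.WeightNormalization(dense, data_init=False)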

Load Data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 6s 0us/step
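As a quick sanity check (an addition to the original notebook), CIFAR-10 should load as 50,000 training and 10,000 test RGB images of shape 32x32, with one-hot labels after to_categorical:

# Verify dataset and label shapes after preprocessing.
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)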

Train Models

reg_model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.6481 - accuracy: 0.3959 - val_loss: 1.5032 - val_accuracy: 0.4505
Epoch 2/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.3405 - accuracy: 0.5207 - val_loss: 1.2814 - val_accuracy: 0.5434
Epoch 3/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.2279 - accuracy: 0.5627 - val_loss: 1.2407 - val_accuracy: 0.5507
Epoch 4/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.1465 - accuracy: 0.5939 - val_loss: 1.2135 - val_accuracy: 0.5650
Epoch 5/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.0887 - accuracy: 0.6156 - val_loss: 1.1859 - val_accuracy: 0.5871
Epoch 6/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.0320 - accuracy: 0.6338 - val_loss: 1.1298 - val_accuracy: 0.6070
Epoch 7/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9882 - accuracy: 0.6486 - val_loss: 1.1232 - val_accuracy: 0.6112
Epoch 8/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9439 - accuracy: 0.6669 - val_loss: 1.1192 - val_accuracy: 0.6120
Epoch 9/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9070 - accuracy: 0.6784 - val_loss: 1.1235 - val_accuracy: 0.6168
Epoch 10/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.8794 - accuracy: 0.6892 - val_loss: 1.1383 - val_accuracy: 0.6184
wn_model.compile(optimizer='adam', 
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.5910 - accuracy: 0.4199 - val_loss: 1.3740 - val_accuracy: 0.5097
Epoch 2/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.3134 - accuracy: 0.5292 - val_loss: 1.2580 - val_accuracy: 0.5454
Epoch 3/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.1945 - accuracy: 0.5743 - val_loss: 1.2153 - val_accuracy: 0.5669
Epoch 4/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.1125 - accuracy: 0.6043 - val_loss: 1.1817 - val_accuracy: 0.5820
Epoch 5/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.0498 - accuracy: 0.6264 - val_loss: 1.1949 - val_accuracy: 0.5805
Epoch 6/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.9947 - accuracy: 0.6477 - val_loss: 1.1413 - val_accuracy: 0.5956
Epoch 7/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.9514 - accuracy: 0.6634 - val_loss: 1.1308 - val_accuracy: 0.6085
Epoch 8/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.9040 - accuracy: 0.6810 - val_loss: 1.1338 - val_accuracy: 0.6137
Epoch 9/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.8686 - accuracy: 0.6927 - val_loss: 1.1266 - val_accuracy: 0.6156
Epoch 10/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.8343 - accuracy: 0.7054 - val_loss: 1.1073 - val_accuracy: 0.6233
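The two runs end within half a percentage point of each other on validation accuracy (0.6184 vs. 0.6233 above). To compare the trained models on the held-out data directly (a small addition using the standard Keras evaluate API):

# Evaluate both models on the test set; returns [loss, accuracy].
reg_loss, reg_acc = reg_model.evaluate(x_test, y_test, verbose=0)
wn_loss, wn_acc = wn_model.evaluate(x_test, y_test, verbose=0)
print(f'Regular ConvNet test accuracy:    {reg_acc:.4f}')
print(f'WeightNorm ConvNet test accuracy: {wn_acc:.4f}')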
reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']

# Plot training accuracy per epoch for both models.
plt.plot(np.arange(1, epochs + 1), reg_accuracy,
         color='red', label='Regular ConvNet')

plt.plot(np.arange(1, epochs + 1), wn_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Training accuracy')
plt.legend()
plt.grid(True)
plt.show()

[Figure: training-accuracy curves for the regular and WeightNorm ConvNets]
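The curves above track training accuracy; a natural follow-up (not in the original notebook) is to plot the validation accuracy recorded in the same History objects:

# Compare validation accuracy per epoch for both models.
reg_val_accuracy = reg_history.history['val_accuracy']
wn_val_accuracy = wn_history.history['val_accuracy']

plt.plot(np.arange(1, epochs + 1), reg_val_accuracy,
         color='red', label='Regular ConvNet')
plt.plot(np.arange(1, epochs + 1), wn_val_accuracy,
         color='blue', label='WeightNorm ConvNet')
plt.title('WeightNorm Validation Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Validation accuracy')
plt.legend()
plt.grid(True)
plt.show()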