tf.function과 함께 XLA 사용하기

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기

이 튜토리얼에서는 TensorFlow 모델을 훈련하여 MNIST 데이터세트를 분류합니다. 여기서 훈련 함수는 XLA를 사용하여 컴파일합니다.

먼저, TensorFlow를 로드하고 즉시 실행을 활성화합니다.

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

그런 다음, 몇 가지 필요한 상수를 정의하고 MNIST 데이터세트를 준비합니다.

# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000

# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)

마지막으로, 모델과 옵티마이저를 정의합니다. 이 모델은 단일 밀집 레이어를 사용합니다.

layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()

훈련 함수 정의하기

훈련 함수에서 위에 정의된 레이어를 사용하여 예측된 레이블을 얻은 다음, 옵티마이저를 사용하여 손실의 그래디언트를 최소화합니다. XLA를 사용하여 계산을 컴파일하려면 experimental_compile=True를 사용하여 tf.function 내에 배치합니다.

@tf.function(experimental_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

모델 훈련 및 테스트하기

훈련 함수를 정의했으면 모델을 정의합니다.

for images, labels in train_ds:
  if optimizer.iterations > TRAIN_STEPS:
    break
  train_mnist(images, labels)

마지막으로, 정확성을 확인합니다.

images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.872, shape=(), dtype=float32)