XLA: 머신러닝을 위한 컴파일러 최적화

XLA(Accelerated Linear Algebra)는 잠재적으로 소스 코드를 변경하지 않고 TensorFlow 모델을 가속화할 수 있는 선형 대수학용 도메인별 컴파일러입니다.

속도와 메모리 사용량 면에서 향상됩니다. 예를 들어 Volta V100 GPU 8개를 사용하여 BERT MLPerf를 제출할 때 XLA를 사용한 경우 최대 7배의 성능 향상과 최대 5배의 배치 크기 향상을 얻었습니다.

소개

TensorFlow 프로그램이 실행될 때 모든 작업은 TensorFlow 실행자에 의해 개별적으로 실행됩니다. 각 TensorFlow 작업에는 실행자가 전달하는 사전 컴파일된 GPU 커널 구현이 있습니다.

XLA는 모델을 실행하는 대체 모드를 제공합니다. 즉, TensorFlow 그래프를 지정된 모델에 맞게 특별히 생성된 일련의 계산 커널로 컴파일합니다. 이러한 커널은 모델마다 고유하므로 최적화를 위해 모델별 정보를 활용할 수 있습니다. 예를 들어 다음과 같이 XLA가 간단한 TensorFlow 계산과 관련하여 실행하는 최적화를 살펴보겠습니다.

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

XLA를 사용하지 않고 실행하면 그래프는 각각 곱하기, 더하기 및 환산에 하나씩 세 가지 커널을 시작합니다. 그러나 XLA는 그래프를 최적화하여 단일 커널 실행에서 결과를 계산할 수 있습니다. 이 계산은 더하기, 곱하기 및 환산을 단일 GPU 커널로 '융합'하여 이루어집니다. 더욱이, 이 융합 작업은 y*zx+y*z에 의해 생성된 중간 값을 메모리에 기록하지 않습니다. 대신 이 중간 계산 결과를 GPU 레지스터에 완전히 유지하면서 사용자에게 직접 '스트리밍'합니다. 융합은 XLA의 가장 중요한 단일 최적화입니다. 메모리 대역폭은 일반적으로 하드웨어 가속기에서 가장 부족한 리소스이므로 메모리 작업을 삭제하는 것이 성능을 향상하는 가장 좋은 방법 중 하나입니다.

TensorFlow 모델에 XLA 사용 설정

tf.function(jit_compile=True)을 사용한 명시적 컴파일

명시적 컴파일 API는 컴파일할 함수를 선택하기 위한 보다 세밀한 제어를 제공합니다. 예를 들어 MNIST 학습을 실행하는 다음 TensorFlow 함수를 XLA로 컴파일합니다.

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API에는 컴파일해야 하는 시맨틱스가 있으며, 전체 함수가 XLA로 컴파일되거나 errors.InvalidArgumentError 예외가 발생합니다. XLA는 현재 측정기준을 추론할 수 없는 즉, 전체 계산을 실행하지 않고는 모든 텐서의 측정기준을 추론할 수 없는 함수를 컴파일할 수 없습니다. 예를 들어 다음 함수는 컴파일되지 않습니다.

@tf.function
def not_compilable(x):
  return tf.unique(x)

하지만 도형은 실행마다 다를 수 있습니다.

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

자세한 사용 사례는 튜토리얼 Colab을 참고하고 jit_compile=True 사용 방법은 튜토리얼 동영상을 참고하세요.

자동 클러스터링

아무것도 변경하지 않고 TensorFlow 모델에서 XLA를 사용하기 시작하는 간단한 방법은 자동 클러스터링을 사용 설정하는 것입니다. 자동 클러스터링은 TensorFlow 함수 내에서 XLA를 사용하여 컴파일하고 실행할 수 있는 클러스터(연결된 하위 그래프)를 자동으로 찾습니다. 다음과 같이 TF_XLA_FLAGS 환경 변수를 설정하여 GPU에서 자동 클러스터링을 사용 설정할 수 있습니다.

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

자동 클러스터링은 현재 GPU 워크로드에 최적화되어 있지만 다음과 같이 --tf_xla_cpu_global_jit 플래그를 추가로 사용하여 CPU에서 사용 설정할 수도 있습니다.

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

자세한 사용 예는 자동 클러스터링 튜토리얼 Colab을 참고하세요.

tfcompile로 CPU의 AOT(Ahead-of-time) 컴파일

TensorFlow 그래프를 실행 코드로 변환하는 독립형 tfcompile 도구를 사용할 수도 있습니다(x86-64 CPU만 해당).

컴파일된 프로그램 검사

XLA는 생성된 프로그램을 검사할 수 있는 내부 검사 기능을 제공합니다. 생성된 프로그램을 덤프하려면 다음과 같이 XLA_FLAGS 환경 변수를 사용합니다.

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

덤프가 완료되면 /tmp/generated에서 다음 파일을 찾을 수 있습니다.

  • module_XXXX.*_optimizations.txt - 생성된 XLA 프로그램으로, 컴파일된 각 클러스터당 하나씩 있습니다. XLA 버그 신고를 제출할 때 첨부하면 매우 유용합니다.

  • module_XXXX.ir-*.ll - NVPTX 내장 기능과 함께 LLVM 중간 표현으로 생성된 파일입니다.

  • module_XXXX.ptx - 생성된 PTX 파일입니다.

또한 다음을 통해 TensorFlow 그래프 내부에 XLA 클러스터 임베딩을 시각화하는 그래프를 덤프할 수 있습니다.

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

재생성 가능한 버그 보고서

버그 보고서에 생성된 XLA 프로그램의 덤프 및 사용된 자동 클러스터링 임베딩이 포함되어 있다면 버그 보고서를 훨씬 더 쉽게 재생성할 수 있습니다. 자동 클러스터링으로 실행되는 TensorFlow 프로그램을 대상으로 버그 보고서를 생성하려면 다음을 실행합니다.

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

버그를 제출할 때 /tmp/generated 디렉터리(위 참고)의 콘텐츠를 첨부합니다.

가능하면 replay_computation을 사용하고 생성된 프로그램에서 이를 반복적으로 실행하여 버그를 단일 XLA 프로그램으로 분리하려고 해야 합니다.

추가 자료

XLA 프런트엔드

TensorFlow 외에 다음으로도 XLA 프로그램을 생성할 수 있습니다.

  • JAX: Python+NumPy 프로그램의 구성 가능한 변환
  • Julia: 과학 컴퓨팅을 위한 Julia 언어
  • PyTorch: PyTorch 프레임워크
  • Nx: Elixir 프로그래밍 언어를 위한 숫자 컴퓨팅 라이브러리

관련 발표

jit_compile=True로 TF에서 XLA 사용

XLA 개요