XLA:機器學習的最佳化編譯器

XLA (加速線性代數) 是特定領域專用的線性代數編譯器,不但可加快 TensorFlow 模型的速度,而且可能完全不需要變更原始碼,

讓執行速度和記憶體用量都有所改善。舉例來說,在 BERT 中,如果 MLPerf 提交內容使用 8 伏打 V100 GPU 和 XLA,則效能可提升約 7 倍,批量也可減少約 5 倍:

簡介

當 TensorFlow 程式執行時,TensorFlow 執行程式會個別執行所有作業。每個 TensorFlow 作業都有預先編譯的 GPU 核心實作,做為執行程式的分派目標。

XLA 提供了另一種執行模型的模式,將 TensorFlow 圖形編譯為一系列專為特定模型建立的運算核心。這些核心專屬於模型,因此可以利用模型特定的資訊進行最佳化。我們來看看 XLA 在簡單的 TensorFlow 運算中所執行的最佳化:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

如果在沒有 XLA 的情況下執行,圖形會啟動三個核心:分別用於乘法、加法及減法。但 XLA 可將圖形最佳化,因此只需啟動單一核心即可計算出結果。做法是將加法、乘法和減法「融合」成一個 GPU 核心。此外,這個融合作業不會將 y*zx+y*z 產生的中間值寫入記憶體,而是將這些中間運算的結果直接「串流」給使用者,並將這些結果完整保留在 GPU 寄存器中。融合是 XLA 最重要的最佳化方式。記憶體頻寬往往是硬體加速器最稀缺的資源。因此,改善效能的最佳方法之一就是移除記憶體運算。

為 TensorFlow 模型啟用 XLA

使用 tf.function(jit_compile=True) 進行明確編譯

明確編譯 API 可讓您更精細地掌控要選擇哪些函式進行編譯。例如,下列用於執行 MNIST 訓練的 TensorFlow 函式是使用 XLA 編譯而成:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API 具有 must-compile 語意:不是使用 XLA 編譯整個函式,就是擲回 errors.InvalidArgumentError 例外狀況。XLA 目前無法編譯維度無法推論的函式 (也就是無法在不執行完整運算的情況下推論所有張量的維度)。舉例來說,下列函式無法編譯:

@tf.function
def not_compilable(x):
  return tf.unique(x)

不過,形狀在不同執行過程中可能會有所變化:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

如需更詳細的使用範例,請參閱教學課程 colab,以及使用 jit_compile=True教學影片

自動分群法

如要在 TensorFlow 模型中開始使用 XLA,而不進行任何變更,最簡單的方法是啟用「自動分群法」,這項功能會在可使用 XLA 進行編譯和執行的 TensorFlow 函式中,自動尋找「叢集」(已連結的子圖)。您可以藉由設定 TF_XLA_FLAGS 環境變數,在 GPU 上啟用自動分群法:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

自動分群法目前最適合用於 GPU 工作負載,但您也可以另外使用 --tf_xla_cpu_global_jit 旗標,在 CPU 上啟用自動分群法:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

如需詳細的使用範例,請參閱自動分群法教學課程 Colab

使用 tfcompile 進行 CPU 的 AOT (預先) 編譯

您也可以使用獨立的 tfcompile 工具,將 TensorFlow 圖轉換為可執行的程式碼 (僅支援 x86-64 CPU)。

檢查已編譯的程式

XLA 提供自我檢查設施,可讓你檢查產生的程式。如要傾印產生的程式,請使用環境變數 XLA_FLAGS

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

執行傾印後,您可以在 /tmp/generated 中找到下列檔案:

  • module_XXXX.*_optimizations.txt 產生的 XLA 程式,每個已編譯的叢集各有一個。在提交 XLA 錯誤報告時附上這些程式將會非常有幫助!

  • module_XXXX.ir-*.ll 產生的 LLVM 中繼表示法檔案,包含 NVPTX 內建函式。

  • module_XXXX.ptx 產生的 PTX 檔案。

你也可以使用下列程式碼,來傾印以視覺化的方式呈現 TensorFlow 圖形中 XLA 叢集嵌入的圖形:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

可重現的錯誤報告

如果錯誤報告包含所產生 XLA 程式的傾印,以及使用的自動分群法嵌入,就會更容易重現。 如要為使用自動分群法執行的 TensorFlow 程式產生這些資訊,請啟動:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

提報錯誤時,請附上 /tmp/generated 目錄的內容 (如上所示)。

請盡可能使用 replay_computation 將錯誤隔離到單一 XLA 程式,並透過產生的程式疊代執行該程式。

延伸閱讀

XLA 前端

除了 TensorFlow 以外,您還可以透過下列方式產生 XLA 程式:

  • JAX:Python + NumPy 程式的可組合轉換
  • Julia:用於科學運算的 Julia 語言
  • PyTorch:PyTorch 架構
  • Nx:Elixir 程式語言的數值運算庫

談話

使用 jit_compile=True 來透過 TF 運用 XLA

XLA 總覽