注意: ここで説明するタイルレイアウトはプレリリース機能であり、エラーは無視できます。
図 1
図 1 は、配列 F32 [3,5] が 2x2 タイリングでメモリにどのように配置されるかを示しています。このレイアウトの形状は、F32[3,5]{1,0:T(2,2)} として記述されます。ここで、1,0 は物理的な次元の順序(レイアウトの minor_to_major フィールド)に関連し、コロンの後の (2,2) は、2x2 タイルによる物理的次元のタイリングを示します。
直感的にタイルが形状をカバーするようにレイアウトされ、各タイル内で、上記の例のように要素がタイルなしでレイアウトされます。例の右側の部分は、メモリ内のレイアウトを示しています。元の配列の境界が均一でなくても、レイアウトには完全な 2x2 のタイル型にするために白いパディング要素が追加されます。
パディングの追加要素には、特定の値を含める必要はありません。
形状とタイルが与えられた場合のタイリングの線形インデックス式
タイリングなしの場合、配列境界 d =(dn, dn-1, ... , d1) (d1 は最も小さい次元)の配列内の要素 e =(en, en-1, ... , e1) は、次の位置にメジャーからマイナーな順序で配置されます。
linear_index(e, d)
= linear_index((en, en-1, ... , e1), (dn, dn-1, ... , d1))
= endn-1...d1 + en-1dn-2...d1 + ... + e1
ここでは表記を簡単にするために、タイルの次元数は配列と同じであると想定します。XLA のタイリングの実装では、これは、最初の最もメジャーな次元を変更せず、最もマイナーな次元にのみタイリングを適用することにより、より少ない次元のタイリングに一般化されます。そのため、指定されたタイリングには、タイリングされている形状の物理的次元の接尾辞がついています。
サイズ (tn, tn-1, ... , t1) のタイリングが使用される場合、インデックス (en, en-1, ... , e1) の配列内の要素は以下のように最終的なレイアウトの位置にマップされます。
linear_index_with_tile(e, d, t)
= linear_index((⌊e/t⌋, e mod t), (⌈d/t⌉, t)) (算術は要素ごと、(a,b) は連結)
= linear_index((⌊en/tn⌋, ... , ⌊e1/t1⌋, en mod tn, ... , e1 mod t1), (⌈dn/tn⌉, ... , ⌈d1/t1⌉, tn, tn-1, ... , t1))
= linear_index((⌊en/tn⌋, ... , ⌊e1/t1⌋), (⌈dn/tn⌉, ... , ⌈d1/t1⌉))∙tntn-1...t1 + linear_index((en mod tn, ... , e1 mod t1), (tn, tn-1, ... , t1))
レイアウトは、次の 2 つの部分で構成されていると考えます。(⌊en/tn⌋, ... , ⌊e1/t1⌋) は、サイズ(⌈dn/tn⌉, ... , ⌈d1/t1⌉) のタイルの配列内のタイルインデックスに対応します。(en mod tn, ... , e1 mod t1) はタイル内インデックスに対応します。ceil 関数は⌈di/ti⌉ に表示されます。これは、タイルが大きな配列の境界を超えた場合、図 1 のようにパディングが挿入されるためです。タイルとタイル内の要素は、タイリングせずに再帰的にレイアウトされます。
図 1 の例では、要素 (2,3) には、結合された座標ベクトル (1, 1, 0, 1) に対して、タイルインデックス (1,1) とタイル内インデックス (0,1) があります。タイルインデックスには境界 (2, 3) があり、タイル自体は結合ベクトル (2, 3, 2, 2) に対して (2, 2) です。論理形状のインデックス (2, 3) を持つ要素のタイル付き線形インデックスは次のようになります。
linear_index_with_tile((2,3), (3,5), (2,2))
= linear_index((1,1,0,1), (2,3,2,2))
= linear_index((1,1), (2,3)) ∙ 2 ∙ 2 + linear_index((0,1), (2,2))
= (1 ∙ 3 + 1) ∙ 2 ∙ 2 + (0 ∙ 2 + 1)
= 17.
pad-reshape-transpose としてのタイリング
タイリングベースのレイアウトは次のように動作します。
次元 (dn, dn-1, ... , d1) の配列を考えてみます(d1 は最もマイナーな次元)。サイズ (tn, tn-1, ... , t1) (t1 は最もマイナーな次元) のタイリングでレイアウトされている場合、そのタイリングは、次のように pad-reshape-transpose の観点から説明できます。
- 配列は (⌈dn/tn⌉∙tn, ... , ⌈d1/t1⌉∙t1) にパッディングされます。
- 各次元 i は (⌈di/ti⌉, ti) に分割されます。つまり、配列は次のように再形状されます。
(⌈dn/tn⌉, tn, ... , ⌈d1/t1⌉, t1)。
この再形状自体には物理的なレイアウトの変更はないため、この再形状はビットキャストです。タイルを明示的に考えていない場合、この再形状は、パッド付きの形状と同じ数の要素を持つ任意の形状を表現できます。ここでの例は、この方法でタイルを表現する方法です。 - 転置は、tn, ... , t1を相対的な順序を維持しながら、最も小さい次元に移動することによって行われます。最も大きい次元から最も小さい次元への順序は
(⌈dn/tn⌉, ... , ⌈d1/t1⌉, tn, ... , t1) になります。
最終的な形状には接頭辞
(⌈dn/tn⌉, ... , ⌈d1/t1⌉) があります。これは、各次元のタイル数を表します。配列内の要素 (en, ... , e1) は、最終的な形状
(⌊en/tn⌋, ... , ⌊e0/t0⌋, en mod tn, ... , e1 mod t1)でこの要素にマップされます。要素の線形インデックスが予想どおり上記の式に従っていることが明らかです。
繰り返しタイリング
XLA のタイリングは、繰り返し適用することでさらに柔軟になります。
図 2
図 2 は、サイズ 4x8 の配列が 2 レベルのタイリング(最初に 2x4、次に 2x1)によりどのようにタイリングされるかを示しています。この繰り返されるタイリングを (2,4)(2,1) として表します。各色は 2x4 タイルを示し、それぞれの赤い境界ボックスは 2x1 タイルです。数字は、タイル形式の要素のメモリ内の線形インデックスを示します。この形式は、最初のタイルが大きいことを除いて、TPU の BF16 に使用される形式と一致します。つまり、タイリングは (8,128)(2,1) であり、2x1 による 2 番目のタイリングの目的は 2 つの 16 ビット値を収集し、TPU のアーキテクチャに合わせて 1 つの 32 ビット値を形成することです。
2 番目以降のタイルは、タイル間のマイナーな次元を両方を参照できることに注意してください。これは、この例の (8,128)(2,1) のように、タイル間のデータを再配置するだけですが、前のタイリングからのメジャーなタイル間の次元を参照することもできます。
タイルを使用した次元の組み合わせ
XLA のタイリングは、次元の組み合わせもサポートしています。たとえば、F32[2,7,8,11,10]{4,3,2,1,0} の次元を最初に F32[112,110]{1,0} と組み合わせてから、(2,3) とタイリングできます。使用するタイルは (∗,∗,2,∗,3) です。ここで、タイルのアスタリスクは、その次元を取得し、それを次に小さい次元と組み合わせることを意味します。複数の隣接する次元は、1 つの次元にまとめることができます。包含された次元は、タイルのその次元の -1 のタイル値で表されますが、次元サイズとしてのタイルでは無効です。
より正確には、形状の次元 i がタイルのアスタリスクを介して削除された場合、事前のタイリングの定義が適用される前に、その次元はタイル化されている形状とタイルベクトルの両方から削除されます。形状の次元 i-1 は、配列の境界が di-1 から didi-1 に増加しています。この手順は、タイルベクトルのアスタリスクごとに繰り返されます。