このドキュメントでは、XLA でセマンティクスのブロードキャストがどのように機能するかについて見ていきます。
ブロードキャストとは
ブロードキャストとは、さまざまな形状の配列に、算術演算と互換性のある形状を持たせるプロセスです。この用語は、Numpy ブロードキャストから借用しています。
異なるランクの多次元配列間、または形状が異なるが互換性のある多次元配列間の演算には、ブロードキャストが必要になる場合があります。加算 X+v
について考えてみましょう。ここで、 X
は行列(ランク 2 の配列)であり、v
はベクトル(ランク 1 の配列)です。要素ごとの加算を実行するには、XLA は、v
を特定の回数複製することにより、ベクトル v
を行列 X
と同じランクに「ブロードキャスト」する必要があります。ベクトルの長さは、行列の次元の少なくとも 1 つと一致する必要があります。
以下に例を示します。
|1 2 3| + |7 8 9|
|4 5 6|
行列の次元は (2,3)、ベクトルの次元は (3) です。ベクトルは、行に複製して次のようにブロードキャストされます。
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
Numpy では、これはブロードキャストと呼ばれます。
原則
XLA 言語は可能な限り厳密かつ明示的であり、暗黙的な機能や「魔法」の機能を回避します。このような機能により、一部の計算の定義が多少簡単になる場合がありますが、ユーザーコードに組み込まれる仮定が増え、長期的な変更が困難になることがあります。暗黙的な魔法の機能は必要に応じて、クライアントレベルのラッパーに追加できます。
ブロードキャストに関しては、異なるランクの配列間の演算に関する明示的なブロードキャスト仕様が必要です。これは、可能な場合に仕様を推測する Numpy とは異なります。
下位の配列を上位の配列にブロードキャストする
スカラーは、ブロードキャスト次元を明示的に指定しなくても、常に配列を介してブロードキャストできます。スカラーと配列の間の要素ごとの二項演算は、配列内の各要素にスカラーを使用した演算を適用することを意味します。たとえば、行列にスカラーを追加するということは、各要素が対応する入力行列の要素とスカラーの合計である行列を生成することを意味します。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
ほとんどのブロードキャストでは、二項演算で次元のタプルを使用します。演算への入力のランクが異なる場合、このブロードキャストタプルは、下位ランク配列と一致する上位ランク配列の次元を指定します。
前の例を考えてみましょう。(2,3) 行列にスカラーを追加する代わりに、次元 (2,3) の行列に次元 (3) のベクトルを加算します。ブロードキャストを指定しないと、この演算は無効になります。行列とベクトルの加算を正しく要求するには、ブロードキャストの次元を (1) に指定します。これは、ベクトルの次元が行列の次元 1 と一致することを意味します。2D では、次元 0 が行と見なされ、次元 1 が列と見なされる場合、これは、ベクトルの各要素が、行列の行数に一致するサイズの列になることを意味します。
|7 8 9| ==> |7 8 9|
|7 8 9|
より複雑な例として、3 要素のベクトル (次元 (3)) を3x3 行列 (次元 (3,3)) に加算することを検討します。この例では、ブロードキャストを実行する方法が 2 つあります。
(1) ブロードキャスト次元 1 を使用します。各ベクトル要素は列になり、ベクトルは行列の各行に複製されます。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) ブロードキャスト次元 0 を使用します。各ベクトル要素は行になり、ベクトルは行列の各列に複製されます。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
注意: 2x3 行列を 3 要素ベクトルに追加する場合、ブロードキャスト次元 0 は無効です。
ブロードキャスの次元は、小さいランクの形状が大きいランクの形状にブロードキャストされる方法を説明するタプルになります。たとえば、2x3x4 の直方体と 3x4 の行列が与えられた場合、ブロードキャストタプル (1,2) は、行列を直方体の次元 1 と 2 に一致させることを意味します。
このタイプのブロードキャストは、broadcast_dimensions
引数が指定されている場合、XlaBuilder
のバイナリ演算で使用されます(例:XlaBuilder::Add)。XLAソースコードでは、このタイプのブロードキャストは「InDim」ブロードキャストと呼ばれることもあります。
正式な定義
ブロードキャスト属性を使用すると、上位配列のどの次元を照合するかを指定することにより、下位配列を上位配列に照合できます。たとえば、次元が MxNxPxQ の配列の場合、次元が T のベクトルは次のように一致させることができます。
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
いずれの場合も、T は上位の配列の一致する次元と等しくなければなりません。ベクトルの値は、一致した次元から他のすべての次元にブロードキャストされます。
TxV 行列を MxNxPxQ 配列に一致させるために、ブロードキャスト次元のペアを使用します。
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
ブロードキャストタプルの次元の順序は、下位の配列の次元が上位の配列の次元と一致すると予想される順序である必要があります。タプルの最初の要素は、上位配列のどの次元が下位配列の次元 0 と一致する必要があるかを示します。2 番目の要素は、次元 1 に一致する必要があります。ブロードキャスト次元の順序は厳密に増加する必要があります。たとえば、前の例では、V を N に、T を P に一致させることはできません。また、V を P と N の両方に一致させることもできません。
縮退した次元での同様のランクの配列のブロードキャスト
関連するブロードキャストの問題として、ランクが同じで次元サイズが異なる 2 つの配列をブロードキャストすることがあります。Numpy のルールと同様に、これは配列が互換性のあるの場合にのみ可能です。2 つの配列は、すべてのサイズに互換性がある場合に互換性があります。次の場合、2 つの次元に互換性があります。
- それらは等しいか、
- それらの 1 つが 1(「縮退」次元)
互換性のある 2 つの配列がある場合、結果の形状は、すべての次元インデックスで 2 つの入力の中で最大のものになります。
例:
- (2,1) と (2,3) は (2,3) にブロードキャスト
- (1,2,5) と (7,2,5) は (7,2,5) にブロードキャスト
- (7,2,5) と (7,1,5) は (7,2,5) にブロードキャスト
- (7,2,5) と (7,2,6) は互換性がないのでブロードキャストできない。
各入力配列が異なるインデックスで縮退した次元を持つという特殊なケースが発生した場合は、サポートされます。この場合、結果は「外部演算」になります。(2,1) と (1,3) は (2,3) にブロードキャストされます。その他の例については、ブロードキャストに関する Numpy のドキュメントを参照してください。
ブロードキャストの構成
下位配列から上位配列へのブロードキャストおよび縮退次元を使用したブロードキャストは、どちらも同じ 2 項演算で実行できます。たとえば、サイズ 4 のベクトルとサイズ 1x2 の行列は、ブロードキャスト次元値 (0) を使用して加算できます。
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
最初に、ベクトルはブロードキャスト次元を使用してランク 2 (行列) までブロードキャストされます。ブロードキャスト次元の単一の値 (0) は、ベクトルの次元ゼロが行列の次元ゼロと一致することを示します。これにより、サイズ 4xM の行列が生成されます。ここで、値 M は、1x2 配列の対応する次元サイズに一致するように選択されます。したがって、4x2 行列が生成されます。
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
次に、「縮退次元ブロードキャスト」は、右側の対応する次元サイズに一致するように、1x2 行列の次元ゼロをブロードキャストします。
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
より複雑な例は、(1, 2) のブロードキャスト次元を使用してサイズ 4x3x1 の配列にサイズ 1x2 の行列を加算する場合です。最初に、1x2 行列がブロードキャスト次元を使用してランク 3 までブロードキャストされ、中間 Mx1x2 配列が生成されます。次元サイズ M は、4x1x2 中間配列を生成する大きい方のオペランド(4x3x1 配列)のサイズによって決定されます。ブロードキャスト次元が (1, 2) であるため、次元 1 と 2 が元の 1x2 行列の次元にマップされるため、M は次元 0(左端の次元)にあります。この中間配列は、縮退した次元のブロードキャストを使用して 4x3x1 行列に加算し、4x3x2 配列の結果を生成します。