本文档将介绍 XLA 中的广播语义如何工作。
什么是广播?
广播是使具有不同形状的数组获得兼容形状以便进行算术运算的过程。这一术语取自 Numpy 广播。
在具有不同秩的多维数组之间,或者具有不同但兼容形状的多维数组之间执行运算时可能需要广播。请思考加法 X+v
,其中 X
为矩阵(秩为 2 的数组)、v
为向量(秩为 1 的数组)。要执行逐元素加法,XLA 需要对向量 v
进行“广播”,通过复制数次 v
来使它的秩与矩阵 X
的秩相同。向量的长度必须至少与矩阵的一个维度相匹配。
例如:
|1 2 3| + |7 8 9|
|4 5 6|
矩阵维度为 (2,3),向量维度为 (3)。通过复制一行向量来进行广播,得到:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
在 NumPy 中,这种方法被称为广播。
原则
XLA 语言应尽可能严格而显式,避免隐式和“魔术”特征。此类特征可能会使一些计算的定义略为方便,但代价是要在用户代码中增加更多的假设,而长期则会难以更改。如有必要,可以在客户端级别的包装器中添加隐式和魔术特征。
关于广播,具有不同秩的数组之间的运算需要显式广播规范。这与有可能推断出规范的 NumPy 有所不同。
将低秩数组广播至高秩数组
标量始终可以在没有广播维度显式规范的情况下对数组进行广播。在标量和数组之间进行逐元素二元运算,表示对数组中的每个元素应用与标量的运算。例如,标量与矩阵之间的加法将生成一个矩阵,其中每个元素均为标量与相应输入矩阵元素的和。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
大多数广播需求都可以通过在二元运算中使用维度元组来捕获。当运算的输入具有不同的秩时,此广播元组可指定高秩数组中的哪个/哪些维度与低秩数组相匹配。
请思考上一示例,我们将标量与 2x3 矩阵相加改为将三维向量与 2x3 矩阵相加。如果不指定广播,此运算无效。要正确地请求矩阵-向量加法运算,请将广播维度指定为 (1),这表示向量的维度与矩阵的维度 1 相匹配。对于矩阵,如果将维度 0 视为行、维度 1 视为列,那么将广播维度指定为 (1) 就表示向量的每个元素各成一列,其大小与矩阵中的行数匹配:
|7 8 9| ==> |7 8 9|
|7 8 9|
作为更复杂的示例,请思考将 3 元素向量(维度为 (3))与 3x3 矩阵(维度为 (3,3))相加。此示例可以采用两种广播方式:
(1) 可以将广播维度指定为 1。每个向量元素各成一列,复制向量作为矩阵的每一行。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) 可以将广播维度指定为 0。每个向量元素各成一行,复制向量作为矩阵的每一列。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
注:2x3 矩阵与 3 元素向量相加时,广播维度 0 无效。
广播维度可以是一个元组,用于描述如何将低秩形状广播至高秩形状。例如,给定 2x3x4 长方体和 3x4 矩阵,广播元组 (1,2) 表示将矩阵与长方体的维度 1 和 2 匹配。
在给定 broadcast_dimensions
参数的情况下,可以在 XlaBuilder
中的二元运算中使用此类广播。有关示例,请参阅 XlaBuilder::Add。在 XLA 源代码中,此类广播有时被称为“InDim”广播。
正式定义
广播特性支持将低秩数组与高秩数组进行匹配,方法是指定要匹配的高秩数组维度。例如,对于维度为 MxNxPxQ 的数组,T 维向量可以按如下方式进行匹配:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
每种情况下,T 都必须等于高秩数组的匹配维度。然后将向量的值从匹配的维度广播至所有其他维度。
要将 TxV 矩阵匹配到 MxNxPxQ 数组,需要使用一对广播维度:
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
广播元组中的维度顺序必须符合低秩数组维度与高秩数组维度进行匹配的顺序。元组中的第一个元素对应于高秩数组中需要与低秩数组维度 0 相匹配的维度。第二个元素则对应于维度 1,依此类推。广播维度的顺序必须严格递增。例如,在上一示例中,将 V 匹配至 N 而将 T 匹配至 P 是非法的;同样,将 V 同时匹配至 P 和 N 也是非法的。
广播具有退化维度的相似秩数组
有一个相关的广播问题:广播两个具有相同的秩但维度大小不同的数组。与 NumPy 的规则类似,这种广播仅在数组兼容的情况下可行。当两个数组的所有维度全部兼容时,即表示这两个数组兼容。在以下情况下,两个维度兼容:
- 两者相等,或者
- 其中一者为 1(“退化”维度)
两个兼容的数组相遇时,所得形状在每个维度索引的两个输入之间取最大值。
例如:
- (2,1) 和 (2,3) 广播至 (2,3)。
- (1,2,5) 和 (7,2,5) 广播至 (7,2,5)。
- (7,2,5) 和 (7,1,5) 广播至 (7,2,5)。
- (7,2,5) 和 (7,2,6) 不兼容,因此无法广播。
有一种特殊情况同样支持广播,即每个输入数组在不同索引处均具有退化维度。在这种情况下,结果为“外积运算”:(2,1) 和 (1,3) 广播至 (2,3)。有关更多示例,请参阅与广播有关的 Numpy 文档。
广播组成
将低秩数组广播至高秩数组以及使用退化维度进行广播,这两种方式均可使用相同的二元运算执行。例如,可以使用广播维度值 (0),将大小为 4 的向量与 1x2 维矩阵相加。
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
首先,使用广播维度将向量广播至 2 秩(矩阵)。广播维度中的单个值 (0) 表示向量的零维度与矩阵的零维度相匹配。这将生成一个 4xM 维矩阵,所选 M 值应匹配 1x2 数组中相应的维度大小。因此,生成一个 4x2 矩阵:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
然后,“退化维度广播”将广播 1x2 矩阵的零维度以匹配右侧的相应维度大小:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
一个更为复杂的示例是,使用广播维度 (1, 2) 对 1x2 矩阵与 4x3x1 数组相加。首先,使用广播维度将 1x2 矩阵广播至 3 秩以生成中间 Mx1x2 数组,其中维度大小 M 由较大运算对象(4x3x1 数组)的大小确定,生成 4x1x2 中间数组。因为维度 1 和 2 由广播维度 (1, 2) 映射到原始 1x2 矩阵的维度,所以 M 的维度为 0(最左侧的维度)。可以使用退化维度广播对此中间数组与 4x3x1 矩阵相加,以生成 4x3x2 数组结果。