这是 DeepMind Scaling Book 系列的第 3 部分。
分布式矩阵及其乘法 (Sharded Matrices and How to Multiply Them)
How To Scale Your Model Part 3 (Part 2: TPUs | Part 4: Transformer Math)
当我们训练大型机器学习模型时,我们必须将它们的参数或输入拆分(或“分片”)到许多加速器上。由于 LLM 主要由矩阵乘法组成,理解这一点归结为理解当矩阵分布在不同设备上时如何进行乘法运算。我们基于 TPU 通信原语的成本,建立了一套简单的分片矩阵乘法理论。
分区符号和集体操作 (Partitioning Notation and Collective Operations)
当我们在 10,000 个 TPU 或 GPU 上训练 LLM 时,从抽象上讲,我们所做的计算与在一个设备上训练时是一样的。区别在于我们的数组无法放入单个 TPU/GPU 的 HBM(高带宽内存)中,所以我们必须拆分它们1。我们称之为“分片 (sharding)”或“分区 (partitioning)”我们的数组。扩展的艺术在于弄清楚如何分片我们的模型,以便计算保持高效。
这是一个分片在 4 个 TPU 上的 2D 数组 A 的示例:
Figure: 一个形状为 A[I, J] 的示例数组分片到 4 个设备上。两个维度都均匀分片到 2 个设备上,分片为 A[IX, JY]。每个 TPU 持有总内存的 1/4。
注意分片数组仍然具有相同的全局或逻辑形状,比如 (4, 128),但它也有一个 设备本地形状,比如 (2, 64),这给出了每个 TPU 持有的实际字节大小。现在我们将把它推广到任意数组。
一种统一的分片符号 (A unified notation for sharding)
我们使用命名轴符号 (named-axis notation) 的变体来描述张量是如何在设备上分块的:我们假设存在称为 设备网格 (device mesh) 的设备 2D 或 3D 网格,其中每个轴都被赋予了 网格轴名称,例如 X, Y, 和 Z。然后,我们可以通过描述数组的每个命名维度如何在物理网格轴上分区来指定矩阵数据如何在设备网格上布局。我们称这种分配为 分片 (sharding)。
例子(上面的图表):对于上面的图表,我们有:
- Mesh: 上面的设备网格
Mesh(devices=((0, 1), (2, 3)), axis_names=('X', 'Y')),这告诉我们有一个 2x2 网格中的 4 个 TPU,轴名称为 $X$ 和 $Y$。 - Sharding: $A[I_X, J_Y]$,这告诉我们将第一个轴 $I$ 沿网格轴 $X$ 分片,第二个轴 $J$ 沿网格轴 $Y$ 分片。这种分片告诉我们每个分片持有数组的 $1 / (\lvert X\rvert \cdot \lvert Y\rvert)$。
综上所述,我们要知道数组的本地形状(单个设备持有的分片大小)是 $(\lvert I\rvert / 2, \lvert J\rvert / 2)$,其中 $\lvert I\rvert$ 是 A 的第一个维度的大小,$\lvert J\rvert$ 是 A 的第二个维度的大小。
Pop Quiz [2D sharding across 1 axis]: 考虑一个形状为 fp32[1024, 4096] 的数组,分片为 $A[I_{XY}, J]$,网格为 {'X': 8, 'Y': 2}。每个设备持有多少数据?在 H100 上从 HBM 加载这个数组需要多长时间(假设每芯片 3.4e12 内存带宽)?
点击这里查看答案。
$A[I_{XY}, J]$ 将第一维 (I) 沿 X 和 Y 硬件轴分片。在这个例子中,本地形状是 $(\lvert I\rvert /(\lvert X\rvert \cdot \lvert Y\rvert), \lvert J\rvert)$。对于给定的例子,全局形状是 fp32[1024, 4096],所以本地形状是 fp32[64, 4096]。
由于每个 GPU 拥有 4 * 64 * 4096 = 1MiB 字节,这大约需要 1e6 / 3.4e12 = 294ns,尽管由于各种开销,实际上可能会多得多,因为它太小了。
可视化这些分片: 让我们通过查看分片到 4 个设备上的 2D 数据数组来尝试可视化这些分片:

我们将矩阵的 完全复制 (fully-replicated) 形式简单地写为 $A[I, J]$,没有分片分配。这意味着 每个 设备都包含整个矩阵的完整副本。

我们可以用下标网格轴表示这些维度之一已跨网格轴分区。例如 $A[I_X, J]$ 意味着 I 逻辑轴已跨 X 网格维度分区,但 J 维度 未 分区,并且块在 Y 网格轴上保持 部分复制。

$A[I_X, J_Y]$ 意味着 I 逻辑轴已跨 X 网格轴分区,并且 J 维度已跨 Y 网格轴分区。

我们在下图中说明了其他可能性:

这里 $A[I_{XY}, J]$ 意味着我们将 X 和 Y 网格轴视为一个更大的扁平维度,并将 I 命名轴跨所有设备分区。多个网格轴下标的顺序很重要,因为它指定了跨网格的分区遍历顺序。

最后,注意我们 不能 有多个命名轴沿 同一 网格维度分片。例如 $A[I_X, J_X]$ 是一种荒谬的、禁止的分片。一旦网格维度用于分片数组的一个维度,它在某种意义上就被“花费”了。
Pop Quiz: 设 A 是一个形状为 int8[128, 2048] 的数组,分片为 $A[I_{XY}, J]$,网格为 Mesh({‘X': 2, ‘Y': 8, ‘Z': 2})(总共 32 个设备)。A 每个设备使用多少内存?A 在所有设备上总共使用多少内存?
点击这里查看答案。
答案: 我们的数组 A 在 X 和 Y 上分片,在 Z 上复制,所以每设备的形状为 int8[128 / (2 * 8), 2048] = int8[8, 2048],大小为 8 * 2048 = 16,384 字节。因为它是跨 Z 复制的,而在 Z 平面内它是跨 X 和 Y 完全分片的,所以有 2 个原始数组的完整副本(每个 Z 平面一个)。所以所有设备的总大小是:原始数组大小 × Z 副本数 = 128 * 2048 * 2 = 512 KiB 总计。或者,我们可以验证为:32 个设备 × 16,384 字节/设备 = 512 KiB 总计。
我们如何在代码中描述这一点? (How do we describe this in code?)
到目前为止,我们避免谈论代码,但现在是一个很好的先睹为快的机会。JAX 使用一种命名分片语法,非常接近我们上面描述的抽象语法。我们将在 第 10 章 中更多地讨论这个问题,但这里有一个快速预览。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import jax
import jax.numpy as jnp
# 创建我们的 mesh!我们在一个名字为 'X' 和 'Y' 的 TPU v2-8 4x2 切片上运行。
assert len(jax.devices()) == 8
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
# 一个小工具函数来帮助定义我们的分片。PartitionSpec 是我们的
# 分片(从轴到名称的映射)。
def P(*args):
return jax.NamedSharding(mesh, jax.sharding.PartitionSpec(*args))
# 我们将 A 和 B 都在非收缩维度上分片,将 A 在收缩维度上分片。
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))
# 我们可以在这些分片数组上执行 matmul!out_shardings 告诉我们希望
# 输出如何分片。JAX/XLA 为我们处理其余的分片。
y = jax.jit(lambda A, B: jnp.einsum('BD,DF->BF', A, B), out_shardings=P('X', 'Y'))(A, B)
JAX 最酷的地方在于这些数组表现得好像它们没有分片一样!B.shape将告诉我们全局或逻辑形状 (2048, 8192)。我们必须实际查看 B.addressable_shards 才能看到它是如何在本地分片的。
分片数组的计算 (Computation With Sharded Arrays)
如果你有一个分布在许多设备上的数据数组,并希望对其执行数学运算,那么分片数据和计算的相关开销是多少?
显然,这取决于所涉及的计算。
- 对于 逐元素 操作,对分布式数组进行操作 没有开销。
- 当我们希望跨许多设备上的元素执行操作时,事情变得复杂。值得庆幸的是,对于大多数机器学习,几乎所有计算都以矩阵乘法的形式发生,并且它们相对容易分析。
本节的其余部分将讨论如何乘以分片矩阵。作为一阶近似,这涉及移动矩阵块,以便你可以完全乘以或求和每个块。每种分片将涉及不同的通信。 例如,$A[I_X, J] \cdot B[J, K_Y] \to C[I_X, K_Y]$ 可以在没有任何通信的情况下相乘,因为 收缩维度 (J,我们实际求和的那个) 是未分片的。但是,如果我们想要输出未分片(即 $A[I_X, J] \cdot B[J, K_Y] \to C[I, K]$),我们需要将 $A$ 和 $B$ 或 $C$ 复制到每个设备(使用 AllGather)。这两种选择有不同的通信成本,所以我们需要计算这个成本并选择最低的一个。
你可以从“分块矩阵乘法”的角度来思考这个问题。
要理解这一点,回忆“分块矩阵”的概念是很有帮助的,即矩阵的嵌套矩阵:
\[\begin{pmatrix} a_{00} & a_{01} & a_{02} & a_{03} \\ a_{10} & a_{11} & a_{12} & a_{13} \\ a_{20} & a_{21} & a_{22} & a_{23} \\ a_{30} & a_{31} & a_{32} & a_{33} \end{pmatrix} = \left( \begin{matrix} \begin{bmatrix} a_{00} & a_{01} \\ a_{10} & a_{11} \end{bmatrix} \\ \begin{bmatrix} a_{20} & a_{21} \\ a_{30} & a_{31} \end{bmatrix} \end{matrix} \begin{matrix} \begin{bmatrix} a_{02} & a_{03} \\ a_{12} & a_{13} \end{bmatrix} \\ \begin{bmatrix} a_{22} & a_{23} \\ a_{32} & a_{33} \end{bmatrix} \end{matrix} \right) = \begin{pmatrix} \mathbf{A_{00}} & \mathbf{A_{01}} \\ \mathbf{A_{10}} & \mathbf{A_{11}} \end{pmatrix}\]矩阵乘法的一个很好的性质是,当矩阵乘数以块的形式写出时,乘积可以按照标准规则写成块 matmuls 的形式:
\[\begin{pmatrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{pmatrix} \cdot \begin{pmatrix} B_{00} & B_{01} \\ B_{10} & B_{11} \end{pmatrix} = \begin{pmatrix} A_{00}B_{00} + A_{01}B_{10} & A_{00}B_{01} + A_{01}B_{11} \\ A_{10}B_{00} + A_{11}B_{10} & A_{10}B_{01} + A_{11}B_{11} \end{pmatrix}\]这意味着实现分布式矩阵乘法归结为在网络上移动这些分片块,对这些块执行 本地 矩阵乘法,并对结果求和。问题于是变成了要添加什么通信,以及它有多昂贵。
方便的是,我们可以将所有可能的分片归结为我们需要考虑的大约 4 种情况,每种情况都有关于我们需要添加什么通信的规则。
- Case 1: 两个输入都没有沿收缩维度分片。我们可以无通信地乘以本地分片。
- Case 2: 一个输入有分片的收缩维度。我们通常沿收缩维度 “AllGather” 分片输入。
- Case 3: 两个输入都沿收缩维度分片。我们可以乘以本地分片,然后 “AllReduce” 结果。
- Case 4: 两个输入都有沿同一轴分片的非收缩维度。我们在不先 AllGather 其中一个输入的情况下无法进行。
你可以把这些看作是只需遵循的规则,但理解为什么这些规则成立以及它们有多昂贵也是很有价值的。我们将详细介绍每一种情况。
Case 1: 两个乘数都没有分片的收缩维度
引理: 当乘以分片矩阵时,计算是有效的,输出遵循输入的分片,除非 收缩维度被分片或两个矩阵沿同一轴分片。例如,这可以正常工作:
\[\mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow \mathbf{C}[I_X, K_Y]\]没有任何通信,并产生跨 X 和 Y 硬件维度分片的张量。试着想一想为什么会这样。基本上,计算与分片 无关,因为每个 batch 条目都有一些轴的本地块,它可以乘法和归约。任何这些情况都可以正常工作并遵循此规则:
\[\begin{align*} \mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I, K] \\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I_X, K]\\ \mathbf{A}[I, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I, K_Y]\\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I_X, K_Y] \end{align*}\]因为 A 和 B 都没有分片的收缩维度 J,我们可以简单地执行输入的本地块矩阵乘法,结果将 已经 按照所需的输出分片进行分片。
Case 2: 一个乘数有分片的收缩维度
让我们考虑当一个输入 A 沿收缩维度 J 分片,而 B 完全复制时该怎么办:
\[\mathbf{A}[I, J_X] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\]我们不能简单地乘以 A 和 B 的本地块,因为我们需要对 A 的整个收缩维度求和,该维度被拆分到 X 轴上。通常,我们首先 “AllGather” A 的分片,以便每个设备都有完整的副本,然后才与 B 相乘:
\[\textbf{AllGather}_X[I, J_X] \rightarrow \mathbf{A}[I, J]\] \[\mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\]这样,实际的乘法可以在每个设备上完全完成。
Takeaway: 当乘以矩阵时,如果其中一个矩阵沿收缩维度分片,我们通常先 AllGather 它,这样收缩不再分片,然后做本地 matmul。
注意,当 B 没有也沿 X 分片时,我们也可以做本地部分 matmul,然后求和(或 AllReduce)分片的部分和,这在某些情况下可能更快。见下面的 问题 4。
什么是 AllGather? AllGather 是我们将讨论的第一个核心 MPI 通信原语。AllGather 移除 轴上的分片,并将分布在设备上的分片重新组装到该轴上的 每个 设备上。
\[\textbf{AllGather}_{XY}(A[I_{XY}, J]) \rightarrow A[I, J]\]我们不必移除给定维度的所有下标,例如 \(A[I_{XY}, J] \rightarrow A[I_Y, J]\) 也是一个 AllGather,只是在一个轴上。还请注意,我们也可能希望使用 AllGather 来移除 非收缩 维度分片,例如在矩阵乘法中:
\[A[I_X, J] \cdot B[J, K] \rightarrow C[I, K]\]我们可以最初 AllGather A 以移除输入分片,或者我们可以做分片 matmul 然后 AllGather 结果 C。
AllGather 实际上是如何执行的? 为了围绕单个 TPU 轴(一个环)执行 1 维 AllGather,我们基本上让每个 TPU 沿环传递其分片,直到每个设备都有副本。
Figure: 一个动画,显示了如何在 8 个 TPU 或 GPU 设备组周围执行 AllGather。
我们可以单向或双向执行 AllGather(上面显示的是两个方向)。如果我们做一个方向,每个 TPU 通过 $N - 1$ 跳发送大小为 $\text{bytes} / N$ 的块。如果我们做两个方向,我们有 $\lfloor \frac{N}{2} \rfloor$ 跳,大小为 $2 \cdot \text{bytes} / N$。
这需要多长时间? 让我们以双向 AllGather 为例计算它需要多长时间。设 $V$ 为数组中的字节数,X 为收缩维度上的分片数。那么从上面的图表看,每一跳在每个方向发送 $V / \lvert X\rvert$ 字节,所以每一跳花费:
\[T_{hop} = \frac{2 \cdot V}{X \cdot W_\text{ici}}\]其中 $W_\text{ici}$ 是 双向 ICI 带宽。我们需要发送总共 $\lvert X\rvert / 2$ 跳来到达每个 TPU,所以总归约花费:
\[T_{total} = \frac{2 \cdot V \cdot X}{2 \cdot X \cdot W_\text{ici}}\] \[T_{total} = \frac{V}{W_\text{ici}}\]注意这 不依赖于 X! 这有点惊人,因为这意味着即使我们的 TPU 只是本地连接,连接的局部性并不重要。我们只是受限于每个链路的速度。
Takeaway: 当在吞吐量受限的机制中执行 AllGather(或 ReduceScatter 或 AllReduce)时,实际通信时间仅取决于数组的大小和可用带宽,而不取决于我们的数组分片到的设备数量!
关于 ICI 延迟的说明: 无论数据量如何,ICI 链路上的每一跳都有一些固有的开销。这通常约为 1us。这意味着当我们的数组 $A$ 非常小且每一跳花费少于 1us 时,我们可以进入一个“延迟受限”机制,此时计算 确实 依赖于 X。
点击这里查看完整细节。
设 $T_\text{min}$ 为单跳的最短时间。那么
\[T_{hop} = \max \left[ T_{min}, \frac{2 \cdot V}{X \cdot W_\text{ici}} \right]\] \[T_{total} = \max \left[ \frac{T_{min} \cdot X}{2}, \frac{V}{W_\text{ici}} \right]\]因为我们执行 $X / 2$ 跳。对于大的归约或 gathers,我们完全是带宽受限的。我们发送的数据如此之多,以至于每一跳的开销基本上可以忽略不计。但是对于小数组(例如从模型采样时),这是不可忽略的,ICI 带宽也不相关。我们纯粹受延迟限制。
这是一个在 TPU v5e 8x16 切片上的 AllGather 带宽的实证测量。

注意,我们不仅达到了大约 95% 的声称峰值带宽 (4.5e10),而且我们在大约 10MB 时达到了这个峰值,当 16 路分片时,这意味着每设备大约 500kB(旁注:这比 GPU 好得多)。
当我们跨多个轴 AllGather 时会发生什么? 当我们跨多个轴 gather 时,我们有多个 ICI 维度可以执行 gather。例如,AllGatherXY([B, DXY]) 在两个硬件网格轴上操作。这将可用带宽增加了 $N_\text{axes}$ 倍。
当考虑延迟时,我们最终得到一般规则:
\[T_{total} = \max \left[ \frac{T_{min} \cdot \sum_{i} |X_i|}{2}, \frac{V}{W_\text{ici} \cdot N_\text{axes}} \right]\]Pop Quiz 2 [AllGather time]: 使用 Part 2 中的数字,在具有 2D 网格 {'X': 8, 'Y': 4} 的 TPUv5e 上执行 AllGatherY([EY, F]) → [E, F] 需要多长时间?$E = 2048$, $F = 8192$ (bfloat16)?如果是 $E=256, F=256$ 呢?
点击这里查看答案。
答案: 让我们从计算一些基本量开始:
1) TPU v5e 对于其 2 个轴中的每一个都有 4.5e10 bytes/s 的单向 ICI 带宽。 2) 在 bfloat16 中,对于 (a),我们有 $A[E_Y, F]$,所以每个设备持有形状为 bfloat16[512, 8192] 的数组,即 512 * 8192 * 2 = 8.4MB。总数组大小为 2048 * 8192 * 2 = 34MB。
对于部分 (1),我们可以使用上面的公式。由于我们在一个轴上执行 AllGather,我们有 $T_{\text{comms}} = \text{34e6} / \text{9e10} = \text{377us}$。
对于部分 (2) 每个分片大小为 64 * 256 * 2 = 32kB。 32e3 / 4.5e10 = 0.7us,所以我们是延迟受限的。因为我们有 3 跳,这将大约需要 3 * 1us = 3us。
Case 3: 两个乘数都有分片的收缩维度
第三种基本情况是当两个乘数都在其收缩维度上分片,沿同一个网格轴:
\[\textbf{A}[I, J_X] \cdot \textbf{B}[J_X, K] \rightarrow C[I, K]\]在这种情况下,本地 分片块矩阵乘法至少是 可能 执行的,因为它们将共享相同的收缩索引集。但是每个乘积只代表最终所需乘积的 部分和,并且沿 X 维度的每个设备将留下此最终所需乘积的不同的 部分和。这非常常见,以至于我们扩展了我们的符号来明确标记这种情况:
\[\textbf{A}[I, J_X] \cdot_\text{LOCAL} \textbf{B}[J_X, K] \rightarrow C[I, K] \{\ U_X \}\]符号 { UX } 读作“沿 X 网格轴 未归约 (unreduced)”,指的是操作的这种状态在某种意义上是“不完整的”。
这可以看作是关于矩阵乘法和外积的以下结果:
\[A \cdot B = \sum_{i=1}^{P} \underbrace{A_{:,i} \otimes B_{i,:}}_{\in \mathbb{R}^{n \times m}}\]其中 ⊗ 是外积。因此,如果轴 X 上的 TPU i 有 A 的第 i 列和 B 的第 i 行,我们可以做一个本地矩阵乘法来获得 \(A_{:,i} \otimes B_{i,:} \in \mathbb{R}_{n\times m}\)。这个矩阵在每个条目中都有 A • B 在该条目处的和的第 i 项。我们仍然需要对 P 执行求和,我们在网格轴 X 上分片了它,以获得完整的 A • B。
我们可以使用跨 X 轴的完整 AllReduce 来补救这一点:
\[\begin{align*} A[I, J_X] \cdot_\text{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] \{ U_X \} \\ \textbf{AllReduce}_X C[I, K] \{ U_X \} \rightarrow &\ C[I, K] \end{align*}\]AllReduce 移除部分和,导致沿轴的 每个 设备具有相同的完全求和的值。AllReduce 接受一个在某个轴上未归约(部分求和)的数组,并通过传递这些分片并累积结果来执行求和。
\[\textbf{AllReduce}_Y A[I_X, J] \{U_Y\} \rightarrow A[I_X, J]\]AllReduce 有多贵? 一个 AllReduce 执行的心理模型是每个设备将其分片发送给邻居,并求和它收到的所有分片。显然,这比 AllGather 贵,因为每个“分片”与完整数组有相同的形状。通常,AllReduce 是 AllGather 成本的两倍。 看出这一点的一种方法是注意 AllReduce 可以表示为另外两个原语的组合:ReduceScatter 和 AllGather。
\[\begin{align*} \textbf{ReduceScatter}_{Y,J} : A[I_X,J] \{U_Y\} \rightarrow &\ A[I_X, J_Y] \\ \textbf{AllGather}_Y : A[I_X, J_Y] \rightarrow &\ A[I_X, J] \end{align*}\]什么是 ReduceScatter? 就像 AllReduce 移除下标一样,ReduceScatter 对未归约/部分求和的数组求和,然后沿同一网格轴分散(分片)不同的逻辑轴。$[F]\{U_Y\} \to [F_Y]$。动画展示了这是如何完成的:注意它与 AllGather 非常相似,但不是保留每个分片,而是将它们加在一起。因此,它的延迟大致相同。

每一跳的通信时间只是每分片字节数 $V / Y$ 除以带宽 $W_\text{ici}$,就像 AllGather 一样,所以我们有
\[T_{\text{comms per AllGather or ReduceScatter}} = \frac{V}{W_\text{ici}}\] \[T_{\text{comms per AllReduce}} = 2 \cdot \frac{V}{W_\text{ici}}\]Case 4: 两个乘数都有沿同一轴分片的非收缩维度
每个网格维度在分片张量时最多只能出现一次。执行上述规则有时会导致违反此规则的情况,例如:
\[A[I_X, J] \cdot B[J, K_X] \rightarrow C[I_X, K_X]\]这是无效的,因为沿维度 X 的给定分片,比如 i,将具有 C 的 (i, i) 分片,即对角线项。那么在所有分片中没有足够的信息来恢复结果除了对角线项以外的任何内容,所以我们不能允许这种分片。
解决这个问题的方法是 AllGather 其中一些维度。这里我们有两个选择:
\[\begin{align*} \textbf{AllGather}_X A[I_X, J] \rightarrow &\ A[I, J] \\ A[I, J] \cdot B[J, K_X] \rightarrow &\ C[I, K_X] \end{align*}\]或
\[\begin{align*} \textbf{AllGather}_X B[J, K_X] \rightarrow &\ B[J, K] \\ A[I_X, J] \cdot B[J, K] \rightarrow &\ C[I_X, K] \end{align*}\]在任何一种情况下,结果的形状中只会提到 X 一次。
深入了解 TPU 通信原语 (A Deeper Dive into TPU Communication Primitives)
之前的 4 种情况介绍了几个用于执行分片矩阵乘法的“核心通信原语”:
- AllGather: 移除分片下标,收集分片。
- ReduceScatter: 移除“未归约”后缀,通过在该轴上对分片求和,将数组沿第二个轴分片。
- AllReduce: 移除“未归约”后缀,使数组在该轴上未分片。
还有一个核心通信原语需要提及,它出现在混合专家 (MoE) 模型和其他计算中:AllToAll。
我们最后的通信原语:AllToAll
最后一个基本的集体操作是 AllToAll,或者更准确地说是 分片转置 或重新分片操作的特例。例如:
\[\textbf{AllToAll}_{X, J} A[I_X, J] \rightarrow A[I, J_X]\]AllToAll 通常需要在不同的分片布局方案之间重新排列分片布局。当考虑分片的 MoE 模型时,它们自然会出现。你可以把 AllToAll 想象成将下标从一个轴移动到另一个轴。因为 AllToAll 不需要在环上复制每个分片的所有数据,它实际上比 AllGather 便宜(便宜 1/4)2。

如果我们推广到 ND AllToAll,对于 AxBxC 网格上的 $V$ 字节数组,总成本为
\[T_\text{comms per AllToAll} = \frac{V \cdot \max(A, B, C, ...)}{4 \cdot N \cdot W_\text{ici}}\]其中 $W_\text{ici}$ 是通常的双向 ICI 带宽。对于 1D 网格,这减少到 $V / (4 \cdot W_\text{ici})$,这是 AllReduce 成本的 1/4。
更多关于 ReduceScatter
ReduceScatter 是一个比它最初看起来更基本的操作,因为它实际上是 AllGather 的导数,反之亦然。即如果在前向传递中我们有:
\[\textbf{AllGather}_X A[I_X] \rightarrow A[I]\]那么我们在反向模式导数 A’ 上执行 ReduceScatter 来推导分片的 A’:
\[\textbf{ReduceScatter}_X A'[I] \{ U_X \} \rightarrow A'[I_X]\]将 AllReduce 转换为 AllGather 和 ReduceScatter 还有一个方便的属性,即我们可以将最终的 AllGather 推迟到以后的某个时刻。
我们学到了什么? (What Have We Learned?)
- 数组的分片由 Mesh(命名 TPU 网格的物理硬件轴)和 Sharding(将网格轴名称分配给数组的逻辑轴)指定。
- 分片数组的算术运算与未分片数组完全相同,除非你沿分片轴执行收缩。在那之后,我们必须引入一些通信。
- TPU 使用大约 4 个核心通信原语:
- AllGather: $[A_X, B] \to [A, B]$
- ReduceScatter: $[A, B] \{U_X\} \to [A, B_X]$
- AllToAll: $[A, B_X] \to [A_X, B]$
- AllReduce: $[A_X, B]\{U_Y\} \to [A_X, B]$
- 这些操作的成本和延迟 不依赖于轴的大小(只要它们是带宽受限的),而只依赖于输入数组的大小和链路带宽。
| Operation | Description | Syntax | Runtime |
|---|---|---|---|
| AllGather | 沿轴收集分片数组的所有分片,移除下标。 | $[A_X, B] \to [A, B]$ | bytes / (bidirectional ICI bandwidth * num_axes) |
| ReduceScatter | 沿轴对部分求和的数组求和,并沿另一个轴对其分片(添加下标)。 | $[A, B] \{U_X\} \to [A_X, B]$ | 与 AllGather 相同 |
| AllReduce | 沿轴对部分求和的数组求和。移除 { Ux }。组合了 AllGather 和 ReduceScatter。 | $[A_X, B]\{U_Y\} \to [A_X, B]$ | 2 * AllGather |
| AllToAll | 收集(复制)一个轴并沿同一轴分片不同的维度。 | $[A, B_X] \to [A_X, B]$ | AllGather / 4 (对于双向环) |
一些练习题 (Some Problems to Work)
Question 1 [replicated sharding]: 一个数组分片为 $A[I_X, J, K, \ldots]$(即仅跨 $X$ 分片),网格为 Mesh({'X': 4, 'Y': 8, 'Z': 2})。$A$ 在所有芯片上占用的总字节数与一份数组副本的大小之比是多少?
点击这里查看答案。
我们的数组仅沿 X 分片,大小为 4,所以实际上每个分片的大小为 $[I / 4, J, K, \ldots] = \text{sizeof}(A) / 4$。由于我们的数组跨 Y 和 Z 复制,总大小为 $Y \cdot Z \cdot \text{sizeof}(A)$,所以总大小与单芯片大小之比为 $Y \cdot Z \cdot \text{sizeof}(A) / \text{sizeof}(A) = 16$。
Question 2 [AllGather latency]: 如果 $B=1024$ 和 $D=4096$ (bfloat16),在具有网格 Mesh({'X': 4, 'Y': 4, 'Z': 4}) 的 TPUv4p 4x4x4 切片上 $\text{AllGather}X([B_X, D_Y])$ 需要多长时间?$\text{AllGather}{XY}([B_X, D_Y])$ 呢?$\text{AllReduce}_Z([B_X, D_Y] {U_Z })$ 呢?
点击这里查看答案。
我们有 9e10 双向带宽。
- 因为我们只是在一个轴上 gather 而另一个是分片的,我们实际上是在 1 个轴上 gather $2BD / Y$ 字节。这将花费 $2BD / (\text{9e10} \cdot Y) = 2 \cdot 1024 \cdot 4096 / (\text{9e10} \cdot 4) = 23 \mu s$。
- 我们有两倍带宽,但 AllGather 完整数组,所以
T = 2BD / (2 * W) = 2*1024*4096 / (2 * 9e10) = 46us。这远离 4us 的延迟限制,所以我们很好。 - AllReduce 的成本是 AllGather 的两倍。每个分片大小为 $2BD / (X * Y)$,所以成本约为 $4BD / (X * Y * W)$,或大致
4 * 1024 * 4096 / (16 * 9e10) = 11.6us。
Question 3 [latency-bound AllGather]: 假设我们执行 $\text{AllGather}_X([B_X])$ 但 $B$ 非常小(比如说 128)。在 bfloat16 中,这需要多长时间?提示:你可能是延迟受限的。
点击这里查看答案。
我们的数组在 bfloat16 中总共只使用 256 字节,每设备仅 64 字节。由于我们在 TPU v4p 上有一个大小为 4 的轴,我们有环绕链接,所以我们可以双向发送。使用 4.5e10 单向带宽,每跳大约需要 64 / 4.5e10 ~ 0,所以我们绝对是延迟受限的。计算跳数,我们只需 2 跳即可完成完整 gather,所以大约 2us 是一个好的估计。
Question 4 [matmul strategies]: 为了执行 $X[B, D] \cdot_D Y[D_X, F] \to Z[B, F]$,在本节中我们告诉你执行 $\text{AllGather}_X(Y[D_X, F])$ 并乘以完全复制的矩阵(Case 2,策略 1)。或者,你可以乘以本地分片像 $X[B, D_X] \cdot_D Y[D_X, F] \to Z[B, F] \{U_X\}$(Case 4,策略 2),然后 $\text{AllReduce}_X(Z[B, F] \{ U_X\})$。每一个执行多少 FLOPs 和 comms?哪个更好,为什么?
点击这里查看答案。
让我们从基线(策略 1)开始。AllGather 的成本是 $2DF / W_\text{ici}$。一旦我们需要完全复制的数组,总计算时间是 $2BDF / C$(其中 $C$ 是加速器 FLOPs/s)。
相比之下,新策略(策略 2)对 $2BF$ 字节进行 AllReduce,成本为 $4BF / W_\text{ici}$,但执行的 FLOPs 少 $1 / X$。这意味着我们做 $2\cdot B\cdot D\cdot F / X$ FLOPs。因此,策略 2 的总时间大致为:
\[T_\text{total} = \max\left(\frac{2BDF}{X \cdot C}, \frac{4BF}{W_\text{ici}}\right)\]问题是:哪个更大? 策略 (2) 在 $D / (X \cdot C) > 2 / W_\text{ici}$ 时是计算受限的。如果 $B < 2550$,我们在两种情况下都是 comms 受限的,并且当 $D > 2B$ 时,策略 2 的 comms 时间更短。这通常是真的,所以如果我们的 batch 小,策略 2 有时会更好。当我们的 batch 大时,策略 2 通常更好,除非 $D$ 很小。
Appendix A: AllToAll
Question 10: Fun with AllToAll: 在上面的表格中,有人指出执行 AllToAll 的时间比 AllGather 或 ReduceScatter 低 4 倍(在吞吐量受限的情况下)。在这个问题中,我们将看到这个因子 4 来自哪里。
点击这里查看答案。
(1) Solution: 过程很简单:在算法的每一步,每个设备将矩阵的单分片“条带”发送给最近的邻居。这发生 $D-1$ 次。在单向情况下,总共 $N^2(1-\frac{1}{D})$ 个标量通过单个 ICI 链路。
(2) Solution: AllToAll 的关键区别在于,特定设备上的整个分片不需要传达给每个其他设备。总共传输的参数数量是 \((\text{size of A/B/C/D}) * (3 + 2 + 1)\)。也就是说,总共通过单个 ICI 链路传输的字节数是 $\frac{N^2(D-1)}{D \times 2}$。
(3) Solution: 因子为 $\frac{1}{4}$(相对于 AllGather 的 1/2,再结合双向带来的额外优势)。