这是 DeepMind Scaling Book 系列的第 2 部分。
关于 TPU 的一切 (How to Think About TPUs)
How To Scale Your Model Part 2 (Part 1: Rooflines | Part 3: Sharding)
这一节完全关于 TPU 如何工作,它们如何联网以实现多芯片训练和推理,以及这对我们最喜欢的算法性能有何影响。这对 GPU 用户也有很多有用的东西!
你可能也会喜欢阅读这篇关于 NVIDIA GPU 的新章节 第 12 节!
TPU 是什么? (What Is a TPU?)
TPU 基本上是一个专门用于矩阵乘法的计算核心(称为 TensorCore),连接到一堆高速内存(称为高带宽内存或 HBM)1。 图示如下:
Figure: TPU 芯片的基本组件。TensorCore 是左边的灰色框,包含矩阵乘法单元 (MXU)、向量单元 (VPU) 和向量内存 (VMEM)。
你可以认为 TensorCore 基本上就是一个非常好的矩阵乘法机器,但它还有一些其他值得注意的功能。TensorCore 有三个关键单元:
- MXU (Matrix Multiply Unit) 是 TensorCore 的核心。对于大多数 TPU 代次,它每 8 个周期使用脉动阵列执行一次
bfloat16[8,128] @ bf16[128,128] -> f32[8,128]矩阵乘法2(详见 附录 B)。- 在 TPU v5e 上以 1.5GHz 运行时,每个 MXU 约为
5e13bf16 FLOPs/s。大多数 TensorCores 有 2 或 4 个 MXU,所以例如 TPU v5e 的总 bf16 FLOPs/s 为2e14。 - TPU 还支持具有更高吞吐量的低精度矩阵乘法(例如,每个 TPU v5e 芯片可以做
4e14int8 OPs/s)。
- 在 TPU v5e 上以 1.5GHz 运行时,每个 MXU 约为
-
VPU (Vector Processing Unit) 执行通用的数学操作,如 ReLU 激活或向量间的逐元素加法或乘法。归约(求和)也在这里执行。附录 A 提供了更多细节。
- VMEM (Vector Memory) 是位于 TensorCore 中靠近计算单元的片上暂存器。它比 HBM 小得多(例如 TPU v5e 上为 128 MiB),但对 MXU 具有高得多的带宽。VMEM 的运作有点像 CPU 上的 L1/L2 缓存,但更大且由程序员控制。HBM 中的数据需要先复制到 VMEM 中,TensorCore 才能对其进行任何计算。
TPU 在矩阵乘法方面非常非常快。这是它们的主要工作,而且它们做得很好。TPU v5p 是迄今为止最强大的 TPU 之一,每个核心可以做 2.5e14 bf16 FLOPs/s,或者每芯片 5e14 bf16 FLOPs/s。一个包含 8960 个芯片的 Pod 可以做 4 exaflops/s。这是非常多的。这是世界上最强大的超级计算机之一。Google 有很多这样的机器3。
上面的图还包括其他一些组件,如 SMEM 和标量单元,它们用于处理控制流,并在 附录 A 中简要讨论,不管是理解关键。另一方面,HBM 很重要且相当简单:
- HBM (High Bandwidth Memory) 是一大块快速内存,用于存储供 TensorCore 使用的张量。HBM 通常有几十 GB 的容量(例如,TPU v5e 有 16GiB HBM)。
- 当需要进行计算时,张量通过 VMEM(见下文)从 HBM 流入 MXU,结果从 VMEM 写回 HBM。
- HBM 和 TensorCore 之间的带宽(通过 VMEM)称为“HBM 带宽”(通常约为 1-2TB/sec),它限制了在内存受限的工作负载中计算的速度。
通常,所有 TPU 操作都是流水线化和重叠的。 为了执行 matmul $X \cdot A \to Y$,TPU 首先需要将矩阵 $A$ 和 $X$ 的块从 HBM 复制到 VMEM,然后将它们加载到 MXU 中,MXU 将 8x128 的块(对于 $X$)和 128x128 的块(对于 $A$)相乘,然后将结果块逐块复制回 HBM。为了高效地做到这一点,matmul 是流水线化的,以便与 MXU 工作重叠地进行 VMEM 的复制。这允许 MXU 继续工作而不是等待内存传输,保持 matmuls 是计算受限的,而不是内存受限的。
这是一个关于如何从 HBM 执行逐元素乘积的示例:
Figure: 一个动画,显示了在 TPU 上执行的逐元素乘积,字节从 HBM 加载。注意字节是如何分块流出内存的,并且部分结果是如何在不等待完整数组实体化的情况下流水线写回的。
矩阵乘法看起来几乎完全相同,除了它将加载到 MXU 而不是 VPU/Vector 单元,并且加载和存储将以不同的顺序发生,因为相同的权重块用于多个激活块。你可以看到数据块流入 VMEM,然后流入 VREGs(向量寄存器),然后流入 Vector Unit,然后回到 VMEM 和 HBM。正如我们将看到的,如果从 HBM 到 VMEM 的加载速度慢于 Vector Unit(或 MXU)中的 FLOPs,我们就变成了“带宽受限”,因为我们的 VPU 或 MXU 缺乏工作。
关键要点 (Key takeaway): TPU 非常简单。它们将权重从 HBM 加载到 VMEM,然后从 VMEM 加载到脉动阵列,每秒可以执行大约 200 万亿次乘加运算。HBM $\leftrightarrow$ VMEM 和 VMEM $\leftrightarrow$ 脉动阵列带宽设定了 TPU 可以高效执行哪些计算的基本限制。
VMEM 和算术强度:VMEM 比 HBM 小得多,但它对 MXU 的带宽要高得多。正如我们在 第 1 章 中看到的,这意味着如果算法可以将所有输入/输出放入 VMEM,它遇到通信瓶颈的可能性就会小得多。当计算的算术强度较差时,这尤其有用:VMEM 带宽约为 HBM 带宽的 22 倍,这意味着从/向 VMEM 读取/写入的 MXU 操作只需要 10-20 的算术强度即可实现峰值 FLOPs 利用率。这意味着如果我们能将权重放入 VMEM 而不是 HBM,我们的矩阵乘法可以在更小的 batch size 下受限於 FLOPs。这意味着本质上具有较低算术强度的算法仍然可以是高效的。VMEM 实在太小了,这通常是一个挑战4。

一个 TPU 芯片通常(但并非总是)由两个 TPU 核心组成,它们共享内存,可以被认为是一个具有两倍 FLOPs 的大型加速器(称为“megacore”配置)。自 TPU v4 以来一直如此。较旧的 TPU 芯片具有独立的内存,均被视为两个独立的加速器(TPU v3 及更旧)。像 TPU v5e 这样的推理优化芯片每个芯片只有一个 TPU 核心。

芯片以 4 个一组排列在“托盘 (tray)”上,通过 PCIe 网络连接到 CPU 主机。这是大多数读者熟悉的格式,4 个芯片(8 个核心,虽然通常被视为 4 个逻辑 megacores)通过 Colab 或单个 TPU-VM 暴露。对于像 TPU v5e 这样的推理芯片,每个主机有 2 个托盘,而不是 1 个,但也只有每个芯片 1 个核心,给我们 8 个芯片 = 8 个核心5。

PCIe 带宽是有限的: 就像 HBM $\leftrightarrow$ VMEM 链路一样,CPU $\leftrightarrow$ HBM PCIe 连接具有特定的带宽,限制了你可以多快地从主机内存加载到 HBM 或反之亦然。TPU v4 的 PCIe 带宽每个方向为 16GB / second,所以比 HBM 慢近 100 倍。我们可以将数据加载/卸载到主机 (CPU) RAM,但速度并不快。
TPU 网络 (TPU Networking)
芯片通过 ICI 网络连接在 Pod 中。在老一代(TPU v2 和 TPU v3)、推理芯片(例如 TPU v5e)和 Trillium (TPU v6e) 中,ICI (“inter-chip interconnects”) 连接 4 个最近的邻居(具有边缘连接以形成 2D 环面)。TPU v4 和 TPU v5p 连接到最近的 6 个邻居(形成 3D 环面)。注意这些连接 不 经过它们的主机,它们是芯片之间的直接链路。

环面结构将任意两个节点之间的最大距离从 $N$ 减少到 $N / 2$,使通信更快。TPU 还有一个“twisted torus (扭曲环面)”配置,将环面包裹在莫比乌斯带状拓扑中,以进一步减少节点之间的平均距离。
TPU pods (通过 ICI 连接) 可以变得非常大: 最大 pod 大小(称为 superpod)对于 TPU v4 是 16x16x16,对于 TPU v5p 是 16x20x28。这些大型 pods 由 4x4x4 芯片的可重构立方体组成,通过 光学环绕链路6 连接,我们可以重新配置以连接非常大的拓扑。

也可以请求更小的拓扑(例如 2x2x1, 2x2x2),尽管没有环绕。这是一个重要的警告,因为它通常会使大多数通信的时间增加一倍。任何完整立方体的倍数(例如 4x4x4 或 4x4x8)都将由光学交换机提供环绕7。

TPU v5e 和 Trillium pods 由单个 16x16 2D 环面组成,在任何尺寸为 16 的轴上都有环绕(意味着 8x16 在长轴上有环绕)。TPU v5e 和 v6e (Trillium) 无法扩展到 16x16 环面之外,但 pods 仍然可以通过标准数据中心网络 (DCN) 相互通信,DCN 连接 TPU 主机。同样,可以请求更小的拓扑,而在 $<16$ 的维度上没有环绕。

这种最近邻连接是 TPU 和 GPU 之间的一个关键区别。GPU 通过一系列交换机连接,这些交换机近似于每个 GPU 之间的点对点连接,而不是像 TPU 那样的本地连接。通常,节点内的 GPU(H100 为 8 个 GPU,B200 NVL72 多达 72 个)是直接连接的,而更大的拓扑结构需要在每个 GPU 之间进行 O(log(N)) 跳。一方面,这意味着 GPU 可以在少量跳数内发送任意数据。另一方面,TPU 便宜得多(因为 NVLink Switch 很贵),布线更简单,并且可以扩展到更大的拓扑结构,因为每个设备的链路数量和每个设备的带宽是恒定的。在 这里 阅读更多。
ICI 相对于 DCN 非常快,但仍然比 HBM 带宽慢。 例如,一个 TPU v5p 具有:
- 每个芯片
2.5e12bytes/s (2.5 TB/s) 的 HBM 带宽。 - 每轴
9e10bytes/s (90 GB/s) 的 ICI 带宽,每芯片 3 个轴8。 - 每个 TPU
6.25e9bytes/s (6.25 GB/s) 的 DCN (出口) 带宽(通过每个主机上的 1-2 个 NIC)9。
这意味着当我们跨多个芯片分割模型时,我们需要小心避免用较慢的跨设备通信成为 MXU 的瓶颈。
多切片 (Multi-slice) 训练: 一组通过 ICI 连接的 TPU 称为一个 slice (切片)。不同的 slice 可以通过 DCN 相互连接,例如连接不同 pods 上的 slice。由于 DCN 是比 ICI 慢得多的连接,人们应该尽量限制计算等待来自 DCN 的数据的时间。DCN 是主机到主机的,所以要通过 DCN 将缓冲区从 TPU 传输到 TPU,我们需要先通过 PCIe 传输到主机,然后通过网络出口,然后通过目标主机网络入口,再通过 PCIe 进入 HBM。
关键要点 (Key Takeaways)
-
TPU 很简单,在大多数情况下可以被看作是一个连接到内存(超快)、通过 ICI 连接的其他芯片(相当快)以及通过 DCN 连接的数据中心其余部分(有些快)的矩阵乘法单元。
- 通信受限于我们各种网络带宽,按速度排序:
- HBM 带宽:在 TensorCore 及其关联的 HBM 之间。
- ICI 带宽:在一个 TPU 芯片及其最近的 4 或 6 个邻居之间。
- PCIe 带宽:在 CPU 主机及其关联的芯片托盘之间。
- DCN 带宽:在多个 CPU 主机之间,通常是不通过 ICI 连接的主机。
-
在切片内,TPU 仅通过 ICI 连接到它们最近的邻居。 这意味着切片中远距离芯片之间的 ICI 通信需要先跳过中间的芯片。
-
权重矩阵需要填充到至少大小 128(TPU v6 上为 256)的两个维度以填满 MXU(实际上,较小的轴被填充到 128)。
-
低精度矩阵乘法往往更快。 对于支持它的代次,TPU 可以比 bfloat16 FLOPs 快大约 2x/4x 做 int8 或 int4 FLOPs。VPU 操作仍然在 fp32 中执行。
- 为了避免成为 TPU 计算单元的瓶颈,我们需要 确保跨每个通道的通信量与其速度成正比。
TPU 规格 (TPU Specs)
以下是我们芯片的一些具体数字:
| Model | Pod size | Host size | HBM capacity/chip | HBM BW/chip (bytes/s) | FLOPs/s/chip (bf16) | FLOPs/s/chip (int8) |
|---|---|---|---|---|---|---|
| TPU v3 | 32x32 | 4x2 | 32GB | 9.0e11 | 1.4e14 | 1.4e14 |
| TPU v4p | 16x16x16 | 2x2x1 | 32GB | 1.2e12 | 2.75e14 | 2.75e14 |
| TPU v5p | 16x20x28 | 2x2x1 | 96GB | 2.8e12 | 4.59e14 | 9.18e14 |
| TPU v5e | 16x16 | 4x2 | 16GB | 8.1e11 | 1.97e14 | 3.94e14 |
| TPU v6e | 16x16 | 4x2 | 32GB | 1.6e12 | 9.20e14 | 1.84e15 |
Host size 指的是连接到单个主机的 TPU 拓扑(例如 TPU v5e 有一个连接到 4x2 拓扑中 8 个 TPU 的 CPU 主机)。以下是互连数据:
| Model | ICI BW/link (one-way, bytes/s) | ICI BW/link (bidi, bytes/s) |
|---|---|---|
| TPU v3 | 1e11 | 2e11 |
| TPU v4p | 4.5e10 | 9e10 |
| TPU v5p | 9e10 | 1.8e11 |
| TPU v5e | 4.5e10 | 9e10 |
| TPU v6e | 9e10 | 1.8e11 |
我们包括单向带宽和双向带宽,因为单向带宽更符合硬件事实,但双向带宽更常出现在涉及完整环的方程中10。
每个 TPU 的 PCIe 带宽通常约为 1.6e10 bytes / second(TPU v6e 为 3.2e10),而 DCN 带宽通常约为每个 TPU 6.25e9 bytes / second(TPU v6e 为 12.5e9,TPU v5e 为 3.125e9)。
练习题 (Worked Problems)
这些数字有点枯燥,但它们让你对模型性能做出基本的 roofline 估计。让我们做几个问题来解释为什么这很有用。你将在第 3 部分看到更多例子。
问题 1 [限制 LLM 延迟]: 假设你想从分布在 32 个 TPU v4p 上的 bf16 200B 参数模型中采样。将所有参数从 HBM 加载到脉动阵列需要多长时间?提示:使用上面的数字。
点击这里查看答案。
答案: 我们正在 32 个芯片上加载 sizeof(bf16) * 200e9 = 400e9 字节,意味着每个芯片 12.5e9 字节,每个具有 1.23e12 的 HBM 带宽。所以加载大约需要 10ms。
这很酷,因为 这是从模型采样的延迟的合理下限。每个采样步骤都需要从 HBM 加载所有参数,所以它不能小于 10 ms。实际上,使得 batch size 很小,这接近于可以实现的。
问题 2 [TPU 细节]: 考虑一个完整的 TPU v5e pod。总共有多少个 CPU 主机?多少个 TPU TensorCores?整个 pod 的总 FLOPs/s 是多少?总 HBM 是多少?对 TPU v5p pod 做同样的练习。
点击这里查看答案。
答案: 对于 TPU v5e,每个 pod 是 16x16,每个主机是 4x2 切片,所以我们有 16*16 / 8 = 32 个主机。对于 TPU v5e,每个 TPU 只有一个核心,所以我们有 256 个 TensorCores。总 FLOPs/s 是 16*16*2e14 = 5.1e16 (bfloat16)。每个芯片有 16GB HBM,所以那是 256 * 16 = 4TB 内存。
对于完整的 TPU v5p pod,我们有 16x20x28 个芯片,每个主机是 2x2x1,所以我们有 16*20*28 / 2*2 = 2,240 个主机。对于 TPU v5p,每个 TPU 两个 TensorCores,所以我们有 8960 * 2 = 17,920 个核心。总 FLOPs/s 是 8960 * 4.5e14 = 4e18 (bfloat16)。每个芯片有 96GB HBM,所以那是 8960 * 96 = 860TB 内存。
问题 3 [PCIe 运算强度]: 想象我们被迫将大权重矩阵 $A$(类型 $\text{bfloat16}[D, F]$)和一批激活 $x$(类型 $\text{bfloat16}[B, D]$)存储在主机 DRAM 中,并想对它们进行矩阵乘法。这在一个主机上运行,我们使用连接到它的单个 TPU v6e 芯片。你可以假设 $B \ll D$,且 $F = 4D$(我们将在以后的章节中看到为什么这些是合理的假设)。我们需要多小的 batch size $B$ 才能保持在 PCIe 上受限于 FLOPs?假设 PCIe 带宽为 1.5e10 bytes / second。
点击这里查看答案。
答案: 我们必须执行 $2BDF$ 次浮点运算,每个芯片每秒可以执行 9.2e14 次浮点运算。这需要 $2BDF / 9.2e14$ 秒来执行。我们必须从 DRAM 加载 $2DF + 2BD$ 字节,并写回 $2BF$ 字节。我们受 PCIe 传输速度的瓶颈限制,所以我们需要 $2 \cdot (BD + DF + BF) / 1.5e10$ 秒来向 TPU 传输数据。由于我们希望计算时间长于权重加载时间,假设我们可以将所有权重加载与计算重叠,我们希望 $2BDF / 9.2e14 > 2 \cdot (BD + DF + BF) / 1.5e10$。我们可以利用假设 $B \ll D$ 和 $F = 4D$ 简化这一点,得到
或
\[B > \frac{9.2 \times 10^{14}}{1.5 \times 10^{10}} \simeq 61{,}000\]问题 4 [通用 matmul 延迟]: 假设我们想将大小为 int8[16384, 4096] 的权重矩阵乘以大小为 int8[B, 4096] 的激活矩阵,其中 B 是某个未知的 batch size。假设我们开始是在 1 个 TPUv5e 上。
- 作为 B 的函数,这个乘法需要多长时间?提示:计算从 HBM 加载数组需要多长时间以及乘法实际需要多长时间可能会有所帮助。哪个是瓶颈?
- 如果我们想从 VMEM 运行此操作怎么办?作为 B 的函数,它需要多长时间?
点击这里查看答案。
答案: (1) 我们需要执行的浮点运算次数是 $2 \cdot 4096 \cdot 16384 \cdot B = 1.3 \times 10^{8} \cdot B$。所以 $T_{\text{math}} = (1.3 \times 10^{8} \cdot B) / 3.94 \times 10^{14}$ 秒。我们需要从 HBM 加载 $16384 \cdot 4096 + 4096 \cdot B$ 字节到 VMEM,并从 VMEM 写回 $16384 \cdot B$ 字节到 HBM。这意味着 $T_{\text{comms}} = (6.7 \times 10^{7} + 2 \times 10^{4} \cdot B) / 8.1 \times 10^{11}$ 秒。假设通信和计算尽可能重叠,整个乘法将大约耗时
\[\max\{T_{\text{math}}, T_{\text{comms}}\} = \max\left\{ \frac{6.7 \times 10^{7} + 2 \times 10^{4} \cdot B}{8.1 \times 10^{11}}, \frac{1.3 \times 10^{8} \cdot B}{3.94 \times 10^{14}} \right\}\]当 $\frac{6.7 \times 10^{7} + 2 \times 10^{4} \cdot B}{8.1 \times 10^{11}} < \frac{1.3 \times 10^{8} \cdot B}{3.94 \times 10^{14}}$ 时,或者等价地 $B > 271$ 时,我们将受限于 FLOPs。这比我们在下面推导出的 240 数字稍大,因为我们考虑了 \(D\) 和 \(F\) 的全部影响。
(2) 如果我们改从 VMEM 加载,让我们考虑 HBM $\leftrightarrow$ VMEM 带宽的 22 倍作为 VMEM 到 MXU 的带宽。这将我们的数据加载分母从 8.1e11 变为 1.78e13,我们得到 $B > 11$。注意在实践中,我们不能将所有 VMEM 带宽都用于加载 $W$,所以实际上它将更接近 20。
问题 5 [ICI 带宽]: 假设我们有一个 TPU v5e 4x4 切片。假设我们想从 TPU{0,0} 发送一个类型为 bfloat16[8, 128, 8192] 的数组到 TPU{3, 3}。假设 TPU v5e 的每跳延迟是 $1\mu s$。
- 第一个字节多久到达目的地?
- 总传输需要多长时间?
点击这里查看答案。
答案: 在 TPUv5e 中我们有 2D 连接。因为我们只有 4x4 切片(没有大小为 16 的轴),我们没有环绕连接。因此,我们的目标芯片可以从两个端口接收数据,同样我们的源芯片可以从两个端口发送数据。我们要传输的数据量是 2 * 8 * 128 * 8192 = 1.7e7 字节。我们可以同时从两个端口传输(即向右发送一半数组,向下发送一半),所以我们得到每秒传输 2 * 4.5e10 = 9e10 字节,这意味着大约需要 1.7e7 / 9e10 = 188us 才能传输整个数组(假设我们是带宽受限的)。在 4x4 切片中,我们在芯片 $(0, 0)$ 和 $(3, 3)$ 之间有六跳,因为对于少于 16 个芯片的轴没有环绕链路。由于每一跳的延迟约为 $1\mu s$,第一个字节将在大约 6us 后到达,总传输将耗时 188us。
问题 6 [汇总,困难]: 想象你有一个大矩阵 A: int8[128 * 1024, 128 * 1024] 均匀分片在 TPU v5e 4x4 切片上,但在每个芯片上都卸载到主机 DRAM。假设你想把整个数组复制到 TPU{0, 0} 并乘以一个向量 bf16[8, 128 * 1024]。这需要多长时间?提示:使用上面的数字。
点击这里查看答案。
答案: 让我们从我们要执行的操作开始。我们的数组大约是 16GB。从上面的表看,TPU v5e 主机有 4x2 拓扑,所以 4x4 有 2 个主机。因此,由于我们的数组是均匀分片的,每个主机实际上包含 1/2 的数组块,或 8GB。我们需要把这些块都复制到 TPU{0,0},这给了我们两个选择:
- 我们可以通过 DCN 复制,然后通过 PCIe 将整个未分片数组加载到 HBM。
- 我们可以将我们的分片数组加载到它们对应的 TPU 上,然后通过 ICI 执行 gather,然后在 TPU{0,0} 上执行 matmul。
很明显选项 (2) 更好。DCN 比 ICI 慢,我们要更愿意通过许多 PCIe 链路(主机 0 上的 8 个)加载大数组,而不仅仅是几个。这是系统一部分的图表。如上所述,注意 TPU 通过 ICI 连接到邻居(即使跨主机),所有 TPU 通过 PCIe 连接到它们的主机 CPU,主机通过 DCN 连接。

现在让我们看看每一部分需要多长时间:
-
PCIe load: 我们正在通过 16 个 PCIe 链路加载 16GB 的块,每个具有
1.5e10bytes/second 带宽。因此这将大约耗时 66ms。 -
ICI copy: 每个 TPU 现在有 16GB / 16 = 1GB 的数组。我们的 ICI 带宽是每条链路
9e10bytes/second 双向,你会从上面的图表中注意到,在这种拓扑中,对于 TPU{0,0},TPU v5e 上的 4 条 ICI 链路中只有 2 条在使用。由于 TPU{0,0} 需要沿 2 个轴以4.5e10bytes/s/link 接收总共 15GB,我们可以将时间下限设为15e9 / (4.5e10 * 2) = 167ms。实际上这可能无法实现,因为负载非常不均匀,但可能在 2 倍以内。正如你将在第 2 节(注:原文指 Sharding 章节)中看到的,执行完整的 AllGather 也将大约耗时16e9 / (4.5e10 * 2),所以这接近最优。 -
HBM $\rightarrow$ MXU load: 为了执行我们最后的 matmul,我们需要将这些 16e9 字节加上 bf16[8, 128 * 1024] 数组(另外 2MB,可忽略)通过 HBM 带宽加载到 MXU,这将耗时
16e9 / 8.1e11 = 19ms。 -
FLOPs: 我们正在执行总共 $2 \cdot 8 \cdot 128 \cdot 1024 \cdot 128 \cdot 1024 = 2.7 \times 10^{11}$ FLOPs,因为我们可以执行
1.97e14bf16 FLOPs/s,我们得到 1.3ms。
总时间的上限是所有这些时间的总和,但由于 TPU 显通常可以重叠这些操作,我们可以将其视为受最慢部分瓶颈限制的流水线问题。假设这是真的,那么答案大约是 150-200ms。
Part 2 就到这里!对于 Part 3,涵盖分区和跨 TPU 通信,点击这里。
附录 (Appendix)
Appendix A: 更多关于 TPU 内部 (More on TPU internals)
在这里,我们将更深入地探讨 TPU 的内部操作。除非另有说明,我们将提供 TPU v5p 的规格。
VPU
VPU 是 TPU 的向量算术核心。VPU 由一个二维 SIMD 向量机 (VPU) 组成,执行逐元素算术操作,如 vadd(向量加法)或 vmax(逐元素最大值),以及一组向量寄存器 (VREGs),用于保存 VPU 和 MXU 的数据。
VREGs: 每个 TPU v5p 核心有 64 个 32-bit VREGs(TPU v4 中为 32 个),每个核心总共有大约 64 * 8 * 128 * 4 = 256kB 的 VREG 内存(或整个芯片的 2 倍,因为我们有两个核心)。TPU v5p 每个周期可以从 VMEM 加载 3 个寄存器,每个周期向 VMEM 写入 1 个寄存器。
VPU: VPU 是一个形状为 (8, 128) 的 2D 向量算术单元,其中 128 维被称为 lane axis,8 维被称为 sublane axis。v5 上的每个 (lane, sublane) 对包含 4 个标准的浮点 ALU,它们相互独立。VPU 在每个 ALU 中以一个周期执行大多数算术指令(如 vadd 或向量加法),延迟为 2 个周期,所以例如在 v5 中,你可以在每个周期从 VREGs 将 4 对 f32 值相加。典型的 VPU 指令可能看起来像 {v2 = vadd.8x128.f32 v0, v1},其中 v0 和 v1 是输入 VREGs,v2 是输出 VREG。
所有 lanes 和 sublanes 每个周期以纯 SIMD 方式执行相同的程序,但每个 ALU 可以执行不同的操作。所以我们可以在单个周期内处理例如 1 个 vadd 和 1 个 vsub,每个都对两个完整的 VREG 进行操作并将输出写入第三个。
Pop Quiz [计算 VPU 吞吐量]: 使用上述信息,计算 TPU v5p 可以执行多少向量 FLOPs/s。TPU v5p 的时钟速度大约为 1.75GHz。
点击这里查看答案。
答案: 每个周期,每个核心可以在 8 * 128 ALU 上执行 4 个向量指令。这为整个芯片提供了 8 * 128 * 4 * 2 FLOPs/cycle,或 8 * 128 * 4 * 2 * 1.75e9 = 1.4e13 FLOPs/s。注意这比 2e14 的 MXU FLOPs/s 小多少(大约 10 倍)。
Reductions: 通常,跨 sublane 维度的通信或归约比跨 lane 维度更容易。例如,VPU 支持 intra-lane shuffle 操作,可以在大约一个周期内沿大小为 8 的轴滚动。这可以用来沿 sublane 维度执行高效的归约(只需 shuffle 4, 2, 和 1 并做 3 对逐元素求和)。
Cross-lane 归约要难得多,涉及一个称为 XLU 或 “cross lane unit” 的独立硬件单元,它很慢且相当昂贵。
与 GPU 的比较: 对于熟悉 NVIDIA GPU 的人来说,VPU 中的每个 ALU 类似于一个 CUDA 核心,单个 VPU lane 类似于一个 “Warp Scheduler”,即通常执行 SIMD 算术的 32 个 CUDA 核心的集合。lane 内的归约相当容易,但如果我们需要跨 lane,我们需要传输至少 VMEM/XLU/SMEM,这要慢得多。详见 GPU 章节。
标量核心 (Scalar Core)
标量核心是 TPU 的控制单元。它提取和分派所有指令并执行从 HBM 到 VMEM 的传输,并且可以编程做标量元数据工作。因为标量核心是单线程的,这的一个副作用是 TPU 的每个核心每个周期只能创建一个 DMA 请求。
为了说明这一点,单个标量核心控制一个 VPU(由 4096 个 ALU 组成)、4 个 MXU、2 个 XLU 和多个 DMA 引擎。每个计算单元的高度偏斜控制是硬件效率的来源,但也限制了以任何有趣的方式进行数据依赖向量化的能力。
Appendix B: 脉动阵列是如何工作的? (How does a systolic array work?)
TPU MXU 的核心是一个 128x128 脉动阵列(TPU v6e 上为 256x256)。当完全饱和时,脉动阵列每 8 个时钟周期可以执行一次 bfloat16[8,128] @ bf16[128x128] -> f32[8,128]11 乘法。
- 在其核心,脉动阵列是一个 2D
128x128(=16,384) 的 ALU 网格,每个都能执行乘加操作。 - 权重 (W,
128x128输入) 从上方传递下来(称为 RHS),而输入 (X,8x128输入) 从左侧传入(称为 LHS)。
这是一个简化的动画,显示了将一组权重(蓝色)与一组激活(绿色)相乘。你会注意到权重 (RHS) 首先部分加载,对角线地,然后激活被送入,也是对角线地。在下面的每一帧中,我们将所有重叠的绿色和蓝色单元相乘,将结果与从上方传入的任何残差求和,然后将结果依次向下传递一个单元。

这是此动画的更通用版本,显示输出从计算中流出:

这是一个图表,显示了如何在多个 RHS 和 LHS 数组之间进行流水线处理:

随着权重 (RHS) 和激活 (LHS) 的加载,会有一个初始流水线气泡。在该初始气泡之后,可以加载新的输入和权重,而无需额外的气泡。
这是一个 bf16[2, 3] x bf16[3, 3] 矩阵乘法的糟糕动画,你可以把它想象成一个 2x3 权重矩阵与批次 1 大小 3 的输入激活的 matmul。这与前几张幻灯片相比是旋转的,输入流向右边而不是下面,但你可以大致看到结构。

我们可以有效地将其流水线化以乘以大矩阵而不会有太大的流水线气泡。话虽如此,重要的是我们的矩阵形状要大于 MXU 的边尺寸,通常是 128x128。一些 TPU(自 TPU v3 以来)有多个 MXU,TPU v3 为 2 个,TPU v4/5 为 4 个,所以我们需要确保分块尺寸大于 128 * MXU 数量。这里 是这方面的一个好动画。
Trillium (TPU v6e) 有一个 256x256 脉动阵列,这意味着它可以执行 4 倍以上的 FLOPs / cycle。这也意味着你的张量尺寸需要两倍大才能完全利用 MXU。
这篇博文 有另一个关于固定权重矩阵的脉动阵列乘法的精彩动画。
脚注
-
TPU v6e (Trillium) 有一个 256x256 MXU,而所有前几代都使用 128x128 ↩︎
-
TPU,尤其是它们的脉动阵列,之所以是如此强大的硬件加速器,是因为矩阵乘法是少数几个使用 $O(n^3)$ 计算对应 $O(n^2)$ 字节的算法之一。这使得普通 ALU 很容易受限于计算而不是内存带宽。 ↩︎
-
我们有时会谈论 VMEM 预取,这指的是提前在 VMEM 中加载权重,以便我们可以掩盖 matmul 的加载成本。例如,在正常的 Transformer 中,我们有时可以在 attention 期间将我们的大前馈权重加载到 VMEM 中,如果我们受限于内存带宽,这可以隐藏权重加载的成本。这要求我们的权重足够小或分片足够多,以适应单个层进入 VMEM 且有剩余空间。 ↩︎
-
在 Cloud TPU VM 上,每个托盘作为单独 VM 的一部分暴露,所以再次可见 4 个核心。 ↩︎
-
光学交换机只是具有相同 ICI 带宽的可重构连接。它只是让我们连接立方体,同时保留环绕链路。 ↩︎
-
注意
2x2x4没有任何环绕,因为它们由光学交换机提供,只在完整立方体上可用。然而,TPU v5e 8x16 将 在长轴上有环绕,因为它不使用可重构光学网络。 ↩︎ -
上面的页面列出了 100 GB/s 的带宽,这与这里列出的略有不同。TPU ICI 链路具有略微不同的带宽,具体取决于正在执行的操作。你通常可以毫无顾虑地使用本文档中的数字。 ↩︎
-
TPU v6e 有 12.5e9 bytes/s,v5e 有 3.125e9 bytes/s。 ↩︎
-
我们所说的双向带宽是指可以在单个链路上双向发送的总字节数,或者同样地,从单个 TPU 沿特定轴发出的总字节数,假设我们可以有效地使用两个链路。当我们有一个功能环时,也就是我们在特定轴上有环绕连接时,这为真。这发生在当我们要么有一个完整的 16 轴推理芯片,要么有一个轴是 4 的倍数的训练芯片 (v*p) 时。我们更喜欢使用双向带宽,因为它经常出现在涉及双向通信的计算中。 ↩︎
-
如果你不熟悉这种记法,它的意思是:将一个 bfloat16 元素的
8x128矩阵乘以一个 bfloat16 元素的128x128矩阵,并将结果存储在一个 float32 元素的8x128矩阵中。 ↩︎