这是 DeepMind Scaling Book 系列的第 7 部分。

关于 Transformer 推理的一切 (All About Transformer Inference)

How To Scale Your Model Part 7 (Part 6: Training LLaMA | Part 8: Serving LLaMA)

在 Transformer 上执行推理可能与训练非常不同。部分原因是因为推理增加了一个需要考虑的新因素:延迟 (latency)。在本节中,我们将全面探讨从从模型中采样单个新 token 到作为推理引擎的一部分在许多加速器切片上高效扩展大型 Transformer。

Transformer 推理的基础 (The Basics of Transformer Inference)

所以你已经训练了一个 Transformer,你想用它来生成一些新的序列。归根结底,基准测试分数的上升和损失曲线的下降只是衡量是否会发生有趣事情的代理!

采样在概念上很简单。我们输入一个序列,我们最喜欢的 Transformer 会吐出 \(\log p(\text{next token}_i \vert \text{previous tokens})\),即所有可能的下一个 token 的对数概率。我们可以从这个分布中采样并获得一个新的 token。追加这个 token 并重复这个过程,我们就会得到一个作为提示 (prompt) 延续的 token 序列。

我们刚刚描述了 Transformer 采样的朴素实现,虽然它有效,我们在实践中从不这样做,因为我们每次生成 token 时都会重新处理整个序列。这个算法在 FFW 上是 \(O(n^2)\),在注意力机制上是 \(O(n^3)\) 来生成 \(n\) 个 tokens!

我们如何避免这种情况? 事实证明,我们可以从每次前向传递中保存一些中间激活,这让我们可以避免重新处理以前的 tokens。具体来说,由于给定的 token 在点积注意力期间只关注以前的 tokens,我们可以简单地将每个 token 的 key 和 value 投影写入一个称为 KV cache 的新数据结构中。

考虑到这一点,推理有两个关键部分:

  • Prefill (预填充): 给定一个长提示 (prompt),我们同时处理提示中的所有 tokens,并将结果激活(特别是 key-value 投影)保存在 “KV cache” 中。我们还保存最后一个 token 的 logits。
  • Generation (生成): 给定一个 KV cache 和先前的 logits,我们从 logits 中增量采样一个 token,将该 token 喂回 Transformer,并为下一步产生一组新的 logits。我们还将那个新 token 的 KV 激活追加到 KV cache 中。我们重复此操作,直到我们遇到特殊的 <EOS> token 或达到某个最大长度限制。

通过使用 KV cache 采样,我们将生成 $n$ 个 tokens 的时间复杂度降低到 FFW 上的 \(O(n)\) 和注意力上的 \(O(n^2)\),因为我们从不重新处理以前的 token。然而,生成一个序列仍然需要许多次前向传递——这正是当你查询 Gemini 或 ChatGPT 时发生的事情,结果正流回给你。

我们实际上想要优化什么? (What do we actually want to optimize?)

在继续之前,值得强调推理的一个全新方面:延迟。虽然在训练期间我们只关心吞吐量(每秒每芯片处理的总 token 数),但在推理期间,我们必须担心我们产生 token 的速度(首字延迟 (Time to First Token, TTFT)每输出 token 延迟 (TPOT))。

  • 离线批量推理: 仅关心推理的批量成本,对单个样本的延迟视而不见。
  • Chat 接口/流式任务: 需要在规模上廉价运行,同时具有低 TTFT 并生成足够快的 token 以超过人类阅读速度。
  • 边缘推理: 只需要以尽可能低的延迟一次服务一个用户。

最大化硬件利用率仍然至关重要,有助于降低成本和 TTFT,但与训练不同,它并 不一定 转化为所有上下文中单个用户的更好体验。

线性操作:是什么限制了我们? (Linear operations: what bottlenecks us?)

我们所有的线性操作在概念上都是相同的,无论它们存在于 MLP 块还是注意力中。它们的算术强度取决于 batch size。对于单个矩阵乘法 $\text{bf16[B, D]} @ \text{bf16[D, F]}$:

\[T_\text{math} = \frac{\text{Computation FLOPs}}{\text{Accelerator FLOPs/s}} = \frac{2BDF}{\text{Accelerator FLOPs/s}}\] \[T_\text{comms} = \frac{\text{Communication Bytes}}{\text{Bandwidth Bytes/s}} = \frac{2BD + 2FD + 2BF}{\text{Bandwidth Bytes/s}}\]

为了受计算限制,我们需要 \(T_\text{math} \geq T_\text{comms}\),或:

\[\frac{2BDF}{2BD + 2DF + 2BF} \approxeq \frac{2BDF}{2DF} = B \geq \frac{\text{Accelerator FLOPs/s}}{\text{Bandwidth Bytes/s}}\]

Takeaway: Transformer matmuls 是计算受限的 当且仅当 每副本 token batch size 大于 $B_\text{crit} = C / W_\text{hbm}$。对于 TPU v5e 上的 bf16 激活,这大约是 240 tokens。对于 H100,大约是 280 tokens。

预填充期间,所有矩阵乘法基本上总是计算受限的。因此,简单地最大化硬件利用率或 MFU 就足以最大化每芯片吞吐量和延迟。 生成期间,总 token batch size 必须大于 $B_{\text{crit}}$ 才能在线性/前馈操作上受计算限制(TPU v5e 上 bf16 参数为 240)。因为生成是逐个 token 串行发生的,这要求我们将多个请求批处理在一起,这很难!

关于 Attention? (What about attention?)

当我们看点积注意力操作时,事情变得更加复杂,特别是既然我们必须考虑 KV caches。 对于预填充,\(S=T\),所以算术强度是 \(\Theta(T)\)。这意味着只要我们的序列长度相当大,我们就很好! 但在生成期间,\(S \gg T = 1\),所以算术强度 \(\approx 1\)。这意味着我们实际上无法做任何事情来提高生成期间注意力的算术强度。所以我们在注意力期间基本上总是内存带宽受限的!

Takeaway: 在预填充期间,注意力通常对于任何合理的序列长度(大约 > 480 tokens)都是计算受限的,而在生成期间,我们的算术强度很低且恒定,所以我们总是内存带宽受限的。

LLM 延迟和吞吐量的理论估算 (Theoretical estimates for LLM latency and throughput)

对于生成期间的小 batch size(这很常见),我们可以通过假设我们在注意力和 MLP 块中都是内存带宽受限的来下限我们的每步延迟:

\[\begin{equation*} \text{Theoretical Min Step Time} = \frac{\text{Batch Size} \times \text{KV Cache Size} + \text{Parameter Size}}{\text{Total Memory Bandwidth}} \end{equation*}\]

同样,对于吞吐量:

\[\begin{equation*} \text{Theoretical Max Tokens/s} = \frac{\text{Batch Size} \times \text{Total Memory Bandwidth}}{\text{Batch Size} \times \text{KV Cache Size} + \text{Parameter Size}} \end{equation*}\]

随着我们的 batch size 增长,FLOPs 开始主导参数加载,所以实际上我们有更通用的方程:

\[\begin{align} \tiny \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\tiny \text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\tiny \text{MLP (can be compute-bound)}} \end{align}\]

Takeaway: 如果你关心生成吞吐量,请使用尽可能大的每芯片 batch size。任何高于 TPU 算术强度($B_\text{crit}$,通常为 120 或 240)的每芯片 batch size 都将最大化吞吐量。

关于内存? (What about memory?)

在推理期间,我们存储一份参数副本。主要区别是 KV cache。这些是所有过去 tokens 的 keys 和 value 投影。总大小为

\[\text{KV cache size} = 2 \cdot \text{bytes per float} \cdot H \cdot K \cdot L \cdot T\]

这可以很快变得非常大。对于 LLaMA-13B,单个 8192 序列的 KV cache 在 bf16 下是 6.7GB。仅仅 4 个这样的序列就超过了我们参数的内存使用量!

提高生成吞吐量和延迟的技巧 (Tricks for Improving Generation Throughput and Latency)

  • Grouped multi-query attention (aka GMQA, GQA): 我们可以减少 KV heads 的数量,并与多个 Q heads 共享它们。这有效地增加了注意力计算的算术强度。
  • Quantization: 推理通常对参数和 KV 的精度不太敏感。通过量化参数和 KV cache(例如到 int8, int4, fp8 等),我们可以节省内存带宽。
  • Paged Attention: 这是一种改进,将 KV caches 存储在 OS 风格的页表中,主要避免了填充 KV caches。这意味着每个 batch 只使用它需要的内存。

在多个加速器上分发推理 (Distributing Inference Over Multiple Accelerators)

预填充 (Prefill): 从 roofline 的角度来看,预填充几乎与训练相同,几乎所有相同的技术和权衡都适用——模型(Megatron)并行,序列分片(对于足够长的上下文),流水线,甚至 FSDP 都是可行的!

生成 (Generation): 生成是一个更复杂的野兽。

  1. FSDP 是不可能的: 因为我们在将参数和 KV caches 从 HBM 加载到 MXU 时受内存限制,我们不想通过 ICI 移动它们。
  2. 没有理由做数据并行: 纯数据并行没有帮助,因为它复制了我们的参数。
  3. 没有序列 = 没有序列分片
  4. 这主要留给我们模型分片的变体

Takeaway: 我们在生成期间唯一的选择是模型并行的变体。我们旨在移动激活而不是 KV caches 或参数,因为后者更大。

设计高效的推理引擎 (Designing an Effective Inference Engine)

连续批处理 (Continuous batching): 传统的批处理要求所有序列同时完成。连续批处理(或迭代级调度)允许在个别序列完成后立即插入新序列,大大提高了吞吐量。

前缀缓存 (Prefix caching): 如果多个请求共享相同的前缀(例如系统提示),我们可以只计算一次 KV Cache 并重用它。这对于带有长系统提示的聊天机器人非常有效。

练习题 (Worked Problems)

Question 1: 上述模型有多少参数?它的 int8 每 token KV caches 有多大?

点击这里查看答案。

Parameter count: 18.4 billion parameters. KV caches 每 token 是 2 * 64 * 8 * 256 = 262kB

Question 2: 假设我们想在一个 TPUv5e 4x4 切片上服务这个模型,并且可以完全分片我们的 KV cache。我们可以适应的最大 batch size 是多少?

点击这里查看答案。

Batch size 为 7。如果 $K=1$,约为 56。

Question 3: 假设参数在 TPU v5e 4x4 切片上完全分片,将所有参数从 HBM 加载到 MXU 需要多长时间?

点击这里查看答案。

大约 1.3ms

Appendix A: How real is the batch size > 240 rule?

我们在上面提供的简单规则,即我们的 batch size 必须大于 240 tokens 才能受计算限制,大致正确,但忽略了 TPU 在其他操作不使用所有可用 HBM 时预取权重的能力。

Appendix B: 2D Weight Stationary sharding

随着拓扑结构的增长,如果我们能够访问更高维度的网格(如 TPU 的网格),则可以进一步细化这一点,称为 “2D Weight Sharding”

Appendix C: Latency bound communications

我们的通信在 \(\text{total bytes} < W_{ICI} \times 1e-6\) 时变得受延迟限制。例如,对于 $Y$ 上的模型并行,我们在 int8 中当 \(Y > BD / 45,000\) 时受限。

Appendix D: Speculative Sampling

投机采样是另一种用吞吐量换取更好每 token 延迟的强大杠杆。然而,在 batch size 受限的情况下(例如小硬件占用空间或大 KV caches),它变成了双赢。

来源

Inference at Scale - Part 7