这是 DeepMind Scaling Book 系列的第 8 部分。
在 TPU 上服务 LLaMA 3 (Serving LLaMA 3 on TPUs)
How To Scale Your Model Part 8 (Part 7: Inference | Part 9: Profiling)
让我们仔细看看我们如何在 TPU v5e 上服务 LLaMA 3-70B 模型。不同的模型在 roofline 下服务的成本有多高?它们的 KV cache 有多大?我们应该使用什么样的 batch size?在推理过程中参数和激活是如何分片的?让我们通过一些粗略的估算来计算生产中的延迟和吞吐量。
本节将探讨服务 LLaMA-3 需要什么,以及如何有效地完成。就像前面的“应用”部分一样,在查阅答案之前,试着用纸和笔自己计算一下!
LLaMA 的服务故事是怎样的? (What’s the LLaMA Serving Story?)
让我们回顾一下 LLaMA 3-70B 的样子(参考第 6 节):
| hyperparam | value |
|---|---|
| \(n_\text{layers}\) (L) | 80 |
| \(d_\text{model}\) (D) | 8,192 |
| \(d_{ff}\) (F) | 28,672 |
| \(n_\text{heads}\) (N) | 64 |
| \(n_\text{kv heads}\) (K) | 8 |
| \(d_\text{qkv}\) (H) | 128 |
| \(n_\text{embeddings}\) (V) | 128,256 |
我们应该在什么硬件上服务? 答案基本上是,无论哪个 FLOPs / 美元最便宜。出于这个原因,我们通常希望在 TPU v5e 上服务,这是我们目前的专用推理芯片。
每个 TPU v5e 有 16GB 的 HBM,这将要求我们相当积极地分片我们的模型。让我们先考虑一些可能对我们有用的基本量:
Question: LLaMA 3-70B 的每 token KV caches 有多大?你可以假设我们将它们存储在 int8 中。这决定了在给定的拓扑上我们的 batch size 可以有多大。
点击这里查看答案!
LLaMA 3-70B 有 8 个 KV heads,所以每 token 的大小是 2 * K * H * L = 2 * 8 * 128 * 80 = 160kB。
注意这有多大! 如果我们有 32k tokens 的序列长度(这很常见),这使用 162e3 * 32,768 = 5.3GB / sequence。对于 BS=240,这是 1.3TB!由于 TPU v5e 每个只有 16GB,我们需要大约 (70e9 + 1.3e12) / 16e9 = 86 个 TPU v5e 芯片才能甚至容纳这么多内存。还要注意,这与 70GB 的模型参数相比有多大。
Question: 假设我们要在 int8 中以 batch size 32 和 8192 序列长度服务 L3 70B(参数和 KVs)。这将总共使用多少内存?我们可以服务的最小切片是多少?
点击这里查看答案!
由于我们的 KVs 在 int8 中是 160e3 字节,我们的总 KV 内存是 160e3 * 8192 * 32 = 41.9e9 字节。我们的参数是 70e9 字节,因为我们每个参数有 1 个字节。因此,我们的总内存使用量是 41.9e9 + 70e9 = 112GB。
我们可以使用的最小切片将有 112e9 / 16e9 = 7 个 TPU,或者(四舍五入到偶数大小),TPU v5e 4x2。这将是一个紧密的配合,考虑到其他开销,我们可能无法完全适应,所以我们可能至少需要一个 4x4(或者降低 batch size)。
Question: 在此 batch size 和 TPU v5e 4x2 上的量化下,我们预计每解码步的延迟大致是多少?什么吞吐量(tokens / sec / chip)。至于 4x4 呢?假设我们在 bfloat16 中执行 FLOPs 并且一切都是完全分片的。
点击这里查看答案!
我们可以调用上一节中的公式:
\[\begin{align*} \tiny \text{Theoretical Step Time (General)} = \underbrace{\frac{\text{Batch Size} \times \text{KV Cache Size}}{\tiny \text{Total Memory Bandwidth}}}_{\text{Attention (always bandwidth-bound)}} + \underbrace{\max\left(\frac{2 \times \text{Batch Size} \times \text{Parameter Count}}{\text{Total FLOPs/s}}, \frac{\text{Parameter Size}}{\text{Total Memory Bandwidth}}\right)}_{\tiny \text{MLP (can be compute-bound)}} \end{align*}\]严格看内存带宽,我们的步进时间基本上是 (KV size + param size) / (8 * HBM bandwidth) = 112e9 / (8 * 8.1e11) = 17ms。所以理论上我们的步进时间大约是 17ms。 我们的吞吐量将是 32 / .017 = 1882 tokens / sec,或 1882 / 8 = 235 tokens / sec / chip。
思考吞吐量 (Thinking about throughput)
当我们优化吞吐量时,我们希望通过尽可能大的 batch size 来使计算受限。
Question: 在 TPU v5e 上使用 bfloat16 权重和激活,我们的 batch sizes 需要多大才能使 matmuls 受计算限制?如果我们在 int8 权重但执行 bfloat16 FLOPs 会怎样?
点击这里查看答案!
如第 7 节所述,对于任何 $B \ll D, F$ 的 bfloat16 matmul,我们有
\[\begin{equation*} T_\text{math} > T_\text{comms} \leftrightarrow \frac{2BDF}{2DF} \geq \frac{\text{TPU bfloat16 FLOPs/s}}{\text{HBM bandwidth}} = 240 \end{equation*}\]当我们的权重在 int8 时,我们在分母中失去了一个 2 倍因子,所以我们有 $2BDF / DF = 2B > 240$,或者同样 $B > 120$,是从前临界 batch size 的一半。这真的很有帮助!当我们做 int8 权重和 int8 FLOPs 时,我们必须使用 int8 值的 TPU FLOPs/s,从 1.97e14 (bfloat16) 变为 3.94e14,几乎翻了一番。这意味着我们回到了大约 $B > 240$ 的起点。
Question: 我们可以使用的最小 TPU v5e 拓扑是多少,用于服务 LLaMA 3-70B,使用 bfloat16, int8 和 int4 (KVs 和参数都用) 和 8k 上下文?对于这个,你可以认为 KV caches 可以忽略不计。
点击这里查看答案!
这很容易!如果我们接受微小的 batch size,那么唯一的限制是将参数内存放入 HBM,即 ceil(num_params * sizeof(dtype) / HBM per TPU。
| dtype | param size | KV size / token (bytes) | min TPU v5es | actual min slice | remaining HBM for KV caches | num KV caches @ 8k |
|---|---|---|---|---|---|---|
| bf16 | 140GB | 324kB | 8.75 | 4x4 = 16 chips | 116 | 43 |
| int8 | 70GB | 162kB | 4.38 | 4x2 = 8 chips | 58 | 43 |
| int4 | 35GB | 81kB | 2.81 | 2x2 = 4 chips | 29 | 43 |
这很酷!它告诉我们,如果我们愿意,我们可以将 LLaMA 70B 放在 TPU v5e 2x2 上。除了你会注意到 KV cache 的数量非常少。那是我们的 batch size!这意味着我们将获得糟糕的 FLOPs 利用率。我们会非常乐意使用更大的拓扑结构,以便将我们的 batch size 推高到 240。
Takeaway: 我们总是可以通过询问从 HBM 加载所有模型参数到 MXU 需要多长时间来下限解码延迟。当我们的 KV caches 很小时,你可以认为每一层都是一块一块地加载权重然后丢弃它们。除非我们使用非常大的 batch size 或大量的设备间通信,否则这通常是一个合理的界限 (在 1.5x 以内)。当我们的 batch size 较大时,我们需要模拟 KV cache 加载,因为它主导了参数。
Question: 现在让我们深入探讨分片问题。假设我们想要在 TPU v5e 4x8 上以 bfloat16 服务。在生成期间,我们应该在 TPU v5e 4x8 上为我们的模型使用什么分片?我们可以避免受通信限制吗?
点击这里查看答案!
如前一节所述,在生成期间我们实际上只有一个分片选项:模型并行。我们可以做多少直到我们受通信限制?正如我们在前一节中讨论的那样,我们的模型大致在以下情况下变得受通信限制
\[Y > \frac{F \cdot M_Y}{2200}\]对于 LLaMA 3-70B,我们有 F = 28,672,所以如果我们做 2 个轴的模型分片,这大约给我们 \(Y = 28672 \cdot 2 / 2200 = 26\),所以一般而言我们可以扩展到大约 16 个芯片而不受通信限制,这让我们使用 4x4 但不能使用 4x8。
Takeaway: 实际上我们不能用纯模型并行在 4x8 上服务。 我们在这里能做的最好是 4x2 或 也许 是 4x4。
关于预填充? (What about prefill?)
我们在这里大多忽略了预填充,因为它要简单得多。让我们把几个概念放在一起,思考端到端的图景。
Question: 假设我们在预填充期间实现 40% 的 FLOPs 利用率。在 16 个 TPU v5e 芯片上预填充 8192 需要多长时间?
点击这里查看答案!
在 8k tokens 时,我们是坚实的计算受限,所以我们只需要推理 FLOPs。我们知道我们的模型有 70e9 参数,所以每个前向传递使用 2 * 70e9 * B FLOPs。假设 40% MFU (FLOPs utilization),这给我们大约 2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s 的运行时间。与我们之前看到的数字相比,这实际上是相当多的!
可视化延迟-吞吐量权衡 (Visualizing the Latency Throughput Tradeoff)
坚持使用 LLaMA 70B,让我们实际看看生成期间不同 batch size 的延迟和吞吐量。
- 看看成本和延迟之间的权衡有多大。 以每 token 延迟加倍为代价,我们可以实现每 token 成本大约 100 倍的降低。
- 注意在 2k 上下文时,吞吐量实际上在达到 BS 120 roofline 时稳定在每芯片约 1 token / ms。随着序列长度增加,我们不再能将此 batch size 放入内存,所以我们永远不会达到完全饱和点。
- 注意在大 batch size 下,相同吞吐量的延迟要高得多,因为 KV 加载变得主导(而不是参数加载)。
我们可以通过将成本和延迟的来源分解为参数加载时间、KV 加载时间和 FLOPs 时间来更好地理解这一点。
这说明了很多。你可以看到最初参数加载代表了绝大多数延迟,直到 batch size 变得足够大,FLOPs 和 KV 加载变得更加显著。值得注意的是,在所有大于 2048 的序列长度下,我们在 KV cache 加载上花费的时间比在 FLOPs 上花费的时间更多!所以虽然我可以通过增加 batch size 来提高硬件利用率,但在长上下文长度下,KV 加载总是主导总步进时间。
Takeaway: 对于 LLaMA 3-70B,我们在几乎所有这些配置中都受到强烈的 KV cache 内存带宽限制(和 HBM 限制),突出了减少 KV cache 大小对于生成吞吐量的重要性。还要注意这里延迟/吞吐量权衡有多么剧烈。
代码很简单。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
num_chips = 16 # we fix 16 as the amount of total model parallelism we do
param_size = 70e9 # int8 means 1 byte per param
sequence_length = 8192 # can vary this
hbm_bandwidth = 8.20E+11 # v5e
flops = 1.97E+14 # v5e
param_size = bytes_per_param * param_count
def kv_cache_size(bs):
return 2 * bs * 128 * 8 * 80
def min_topology(bytes):
return 2 ** np.ceil(np.log2(bytes / 16e9))
def get_max_batch_size(max_num_chips: int = 16):
# for num_chips in topo_sizes:
batch_sizes = np.arange(1, 1024, 4)
kv_sizes = kv_cache_size(sequence_length * batch_sizes)
num_chips = min_topology(kv_sizes + param_size)
max_idx = np.where(num_chips <= max_num_chips)[0][-1]
return max_idx
max_idx = get_max_batch_size(num_chips, sequence_length, param_size) # get the largest batch size that can fit
batch_sizes = np.arange(1, 512, 1)[:max_idx]
kv_sizes = kv_cache_size(sequence_length * batch_sizes)
kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)
param_comms_time = param_size / (num_chips * hbm_bandwidth)
param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])
flops_time = 2 * param_count * batch_sizes / (num_chips * flops) # roughly true in a 2ND sense
mlp_time = np.maximum(flops_time, param_comms_time)
attn_time = kv_comms_time # always bandwidth-bound for generate
latency = 1000 * (mlp_time + attn_time)
throughput = batch_sizes / (latency * num_chips)
注意我们如何非常明确地将延迟分解为两个来源:KV 加载和参数加载,以及延迟如何受 FLOPs 或 comms 限制,以较大者为准。
练习题 (Worked Problems)
Question 1: LLaMA 3-405B 每个前向传递使用多少 FLOPs per-token?
Question 2: 假设我们要使用 int8 权重和 int8 KV caches 以 BS240 服务 LLaMA 3-8B。有多少字节用于 (a) 模型参数 (b) KV caches 和 (c) 峰值工作激活(大概)?我们可以运行它的最小拓扑是什么?
Question 3: 你将如何在 TPU v5e 上服务 LLaMA 3-405B?假设 int8 权重和 bfloat16 FLOPs。