这是 DeepMind Scaling Book 系列的第 9 部分。

如何分析 TPU 程序 (How to Profile TPU Programs)

How To Scale Your Model Part 9 (Part 8: Serving LLaMA | Part 10: Programming TPUs)

到目前为止,本系列完全是理论性的:基于硬件 rooflines 的粗略计算。这种理解能让你走得很远,但很多优化归结为实际细节:XLA 编译器如何工作,以及如何使用 JAX/Tensorboard Profiler 等分析工具来弄清楚当它失败时该怎么办。我们在这里讨论这个。

JAX Profiler:多用途 TPU 分析器 (The JAX Profiler: A Multi-Purpose TPU Profiler)

JAX 提供了一个多用途的 TPU 分析器,其中包含一堆有用的工具,用于了解程序运行时 TPU 上发生了什么。你可以使用 jax.profiler 模块来跟踪正在运行的程序,并记录从每个子组件的持续时间、每个程序的 HLO、内存使用情况等所有内容。例如,此代码将 trace 转储到 /tmp/tensorboard 中的文件,可以在 TensorBoard 中查看。

1
2
3
4
5
6
import jax
with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (1024, 1024))
  y = x @ x
  y.block_until_ready()

一旦进入 TensorBoard,profiler 有几个关键选项卡可以帮助你理解你的程序:

  1. Trace Viewer 以时间轴的形式显示 TPU 上实际发生的详细时间轴。
  2. Graph Viewer 显示 HLO 图,让你看到程序的哪些部分相互馈送以及事物是如何分片的。
  3. Memory Profile 和 Memory Viewer: 这些显示你的程序使用了多少内存。

Trace Viewer

Trace Viewer 可能是 profiler 中最有用的部分。 Trace Viewer 显示了每个 TPU 核心上所有操作的时间顺序时间轴。

  1. 顶行 (XLA Ops) 显示实际的 TPU 操作(名称是 HLO 名称)。其他所有内容都是基于 jax.named_scope, jax.named_call 和 Python 堆栈跟踪的近似跟踪。
  2. 通过点击一个 XLA op,我们可以查看它在代码中的来源(对于理解 trace 很有用)并查看指向 Graph viewer 的链接。

Tip: 你可以使用“视频游戏”风格的控件导航 Trace Viewer,A/D 向左和向右平移,W/S 放大和缩小。这些控件使导航变得容易得多。

如何阅读 XLA op (How to read an XLA op)

HLO 实际上并不难读,它非常有用于理解上面的 trace 的给定部分对应什么。这是一个名为 fusion.3 的示例 op。

1
%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3

让我们将其分解为各个部分。

  • Op Name: fusion.3
    • Dot 或 fusion op 是一组包含至多 1 个矩阵乘法和可能一堆相关逐点 VPU-ops 的操作。
  • Shape/layout: bf16[32,32,4096]
    • 这是 op 的输出形状。我们可以看到 dtype 是 bf16(每个参数 2 字节)并且 [32,32,4096] 是形状。
  • Layout: {2,1,0:T(8,128)(2,1)}
    • {2,1,0:T(8,128)(2,1)} 告诉我们内存中轴的顺序(列优先,行优先等)和数组填充。
  • Memory location: S(1)
    • S(1) 告诉我们这个数组存在于 VMEM 中。S(0)(有时省略)是 HBM。S(2) 和 S(3) 是其他内存空间。
  • Arguments: bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32
    • 这个 op 有一个输入,一个名为 fusion.32 的 bf16 数组,具有特定的形状。这告诉我们什么函数馈送到这个函数。

Graph Viewer

虽然上面的一些融合看起来很复杂,但 XLA Graph Viewer 使它们更容易解析。盯着一堆 HLO 图并尝试将 HLO ops 映射到你正在分析的代码上是非常有帮助的。通过将鼠标悬停在一个框上,你通常会看到定义该函数的代码行。

看一个真实的例子 (Looking at a real(ish) example profile)

让我们分解一个简单的 Transformer 配置文件,从 FFW 块开始。

我们预计这需要多长时间? 首先,我们的每个数据并行分片的 batch size 是 8 * 1024 = 8192,所以我们应该稳固地受计算限制。在 8 个 TPUv2 核心上,我们预计它需要大约 2 * 32 * 1024 * 8192 * 32768 / (23e12 * 8) = 95.6ms,这几乎正是它所花费的时间(96ms)。太棒了!这意味着我们获得了极好的 FLOPs 利用率!

关于通信? 你会注意到隐藏在第二个 matmul 末尾的小融合。这是一个 ReduceScatter。在 TPUv2 4x2 上做 ReduceScatter,应该只需要在 1.2e11 双向带宽上跳一次。数组大小为 2*32*1024*8192,batch 轴分片 4 路,所以每个分片是 2*8*1024*8192=134MB。所以这应该大约需要 1.1ms。实际上需要多长时间? 配置文件中报告了 1.13ms。所以我们要么非常接近 roofline!

Memory Profile

Memory Profile 使得很容易将程序内存视为时间的函数。这对于调试 OOM 很有帮助。我们可以看到大约 7.5GB 分配给模型参数,大约 10GB 空闲。所以我们可以容纳更多的内存。

练习题 (Worked Problems)

Question 1: 查看 Colab/profile 并弄清楚这里看起来有什么可疑之处以及发生了什么。你能确切地告诉我正在发生什么计算以及每个操作在做什么吗?每个矩阵的真实形状是什么,它们是如何分片的?

点击这里查看答案。

这是两个矩阵乘法。你可以看到一个 reduce,两个大 fusion,和一个 all-reduce。通过观察带有 replica_groups=\{\{0,16,32,48,64,80,96,112\}, ...\} 的最终 AllReduce,我们可以看出我们在做 8 路模型并行。

Question 2: Transformer Colab 实现了一个简单的模拟 Transformer。获取朴素 Transformer 与 GSPMD 分区的基准。每个部分需要多长时间?应该只花多长时间?正在使用什么分片。尝试修复分片!

来源

Profiling - Part 9