这是 DeepMind Scaling Book 系列的第 10 部分。
在 JAX 中对 TPU 编程 (Programming TPUs in JAX)
How To Scale Your Model Part 10 (Part 9: Profiling | Part 11: Conclusions)
我们将在本节讨论如何在 JAX 中编写高效的 TPU 代码,包括如何使用 pmap、分片约束 (sharding constraints) 和 shard_map。你可以在 Google Colab 上运行本节中的代码示例。
JAX 中的并行是如何工作的? (How Does Parallelism Work in JAX?)
JAX 支持三种多设备编程思想:
- Compiler, take the wheel! (Auto Sharding): 让 XLA 编译器自动分割数组并决定添加什么通信。
- JAX, take the wheel! (Explicit Sharding): 你告诉编译器每个张量应该如何分片,JAX 处理分片传播。
- Just let me write what I mean, damnit! (Manual Sharding): 你使用
shard_map手动控制每个设备上的数据和通信。
| Mode | View? | Explicit sharding? | Explicit Collectives? |
|---|---|---|---|
| Auto | Global | ❌ | ❌ |
| Explicit | Global | ✅ | ❌ |
| Manual | Per-device | ✅ | ✅ |
1. 自动分片模式 (Auto sharding mode)
jax.jit 在 JAX 中扮演两个角色。如名所示,它将 Python 函数“即时”编译为字节码。但如果输入是分片的,或者用户指定了 in_sharding 或 out_sharding,它也允许 XLA 将计算分布在多个设备上,并根据需要添加通信。
例如,这是你如何编写一个分片的矩阵乘法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import jax
import jax.numpy as jnp
# Running on an TPU v5e 4x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
# This tells JAX to use this mesh for all operations, so you can just specify the PartitionSpec P.
jax.set_mesh(mesh)
# We create a matrix W and input activations In sharded across our devices.
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))
def matmul_square(In, W):
return jnp.einsum('bd,df->bf', jnp.square(In), W)
# We can explicitly compile the sharded matmul function here. This adds all the
# necessary comms (e.g. an AllReduce after the matmul).
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()
out = jit_matmul(In, W)
这是非常神奇的! 无论我们的程序多么复杂,Shardy 和 jit 都会尝试为所有中间激活找到分片,并根据需要添加通信。
2. 显式分片模式 (Explicit sharding mode)
显式分片(或“类型中的分片”)看起来很像自动分片,但分片传播发生在 JAX 级别!每个 JAX 操作都有一个分片规则,该规则获取 op 参数的分片并为 op 的结果生成分片。
1
2
3
4
5
6
7
8
import jax
import jax.numpy as jnp
import jax.sharding as shd
# Running on an TPU v5e 2x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'),
axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
# ...
对于某些操作,如何对结果进行分片是不明确的,在这种情况下,JAX 会抛出一个 trace-time 错误,并要求程序员显式提供 out_sharding 参数。
3. 手动分片模式 (Manual sharding mode via shard_map)
shard_map (或 pmap) 允许你编写在单个设备上运行的代码,并显式地进行通信(如 psum, all_gather)。
1
2
3
4
5
6
7
8
9
10
11
12
import jax
import jax.numpy as jnp
x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=jax.P(('x', 'y')))
# This function will operate on 1/8th of the array.
@jax.shard_map(in_specs=jax.P(('x', 'y')), out_specs=jax.P())
def slice_and_average(x):
assert x.shape == (512 // 8,)
return jax.lax.pmean(x[:4], axis_name=('x', 'y'))
out = slice_and_average(x)
为什么这样做而不是 jax.jit? 如果我们使用 jax.jit,slice_and_average 将看到数组的全局视图。我们将不得不切出这个非均匀切片,然后执行 XLA 必须正确解释的平均值。这里我们看到本地视图,只编写我们需要的通信。
练习题 (Worked Problems)
Problem 1: 编写一个 JAX 函数,计算每个 (X, Y) 分片内的平均值。
Problem 2: 让我们一起做一个基本的“混合专家 (mixture of experts)”模型。
Problem 3: 让我们调整示例以执行完整的 Transformer 堆栈。