这是 DeepMind Scaling Book 系列的第 10 部分。

在 JAX 中对 TPU 编程 (Programming TPUs in JAX)

How To Scale Your Model Part 10 (Part 9: Profiling | Part 11: Conclusions)

我们将在本节讨论如何在 JAX 中编写高效的 TPU 代码,包括如何使用 pmap、分片约束 (sharding constraints) 和 shard_map。你可以在 Google Colab 上运行本节中的代码示例。

JAX 中的并行是如何工作的? (How Does Parallelism Work in JAX?)

JAX 支持三种多设备编程思想:

  1. Compiler, take the wheel! (Auto Sharding): 让 XLA 编译器自动分割数组并决定添加什么通信。
  2. JAX, take the wheel! (Explicit Sharding): 你告诉编译器每个张量应该如何分片,JAX 处理分片传播。
  3. Just let me write what I mean, damnit! (Manual Sharding): 你使用 shard_map 手动控制每个设备上的数据和通信。
Mode View? Explicit sharding? Explicit Collectives?
Auto Global
Explicit Global
Manual Per-device

1. 自动分片模式 (Auto sharding mode)

jax.jit 在 JAX 中扮演两个角色。如名所示,它将 Python 函数“即时”编译为字节码。但如果输入是分片的,或者用户指定了 in_shardingout_sharding,它也允许 XLA 将计算分布在多个设备上,并根据需要添加通信。

例如,这是你如何编写一个分片的矩阵乘法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import jax
import jax.numpy as jnp

# Running on an TPU v5e 4x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# This tells JAX to use this mesh for all operations, so you can just specify the PartitionSpec P.
jax.set_mesh(mesh)

# We create a matrix W and input activations In sharded across our devices.
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# We can explicitly compile the sharded matmul function here. This adds all the
# necessary comms (e.g. an AllReduce after the matmul).
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()

out = jit_matmul(In, W)

这是非常神奇的! 无论我们的程序多么复杂,Shardy 和 jit 都会尝试为所有中间激活找到分片,并根据需要添加通信。

2. 显式分片模式 (Explicit sharding mode)

显式分片(或“类型中的分片”)看起来很像自动分片,但分片传播发生在 JAX 级别!每个 JAX 操作都有一个分片规则,该规则获取 op 参数的分片并为 op 的结果生成分片。

1
2
3
4
5
6
7
8
import jax
import jax.numpy as jnp
import jax.sharding as shd

# Running on an TPU v5e 2x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'), 
                     axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
# ...

对于某些操作,如何对结果进行分片是不明确的,在这种情况下,JAX 会抛出一个 trace-time 错误,并要求程序员显式提供 out_sharding 参数。

3. 手动分片模式 (Manual sharding mode via shard_map)

shard_map (或 pmap) 允许你编写在单个设备上运行的代码,并显式地进行通信(如 psum, all_gather)。

1
2
3
4
5
6
7
8
9
10
11
12
import jax
import jax.numpy as jnp

x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=jax.P(('x', 'y')))

# This function will operate on 1/8th of the array.
@jax.shard_map(in_specs=jax.P(('x', 'y')), out_specs=jax.P())
def slice_and_average(x):
  assert x.shape == (512 // 8,)
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)

为什么这样做而不是 jax.jit? 如果我们使用 jax.jitslice_and_average 将看到数组的全局视图。我们将不得不切出这个非均匀切片,然后执行 XLA 必须正确解释的平均值。这里我们看到本地视图,只编写我们需要的通信。

练习题 (Worked Problems)

Problem 1: 编写一个 JAX 函数,计算每个 (X, Y) 分片内的平均值。

Problem 2: 让我们一起做一个基本的“混合专家 (mixture of experts)”模型。

Problem 3: 让我们调整示例以执行完整的 Transformer 堆栈。

来源

Programming TPUs in JAX - Part 10