```python
from flax.linen import partitioning as nn_partitioning

def forward(x):
    # ...run some code here.
    # note x here has shape [batch x seq_len x embed]
    x = nn_partitioning.with_sharding_constraint(x, ("batch", None, "embed"))
    # ...continue more code here.
    return x
```
As of this writing in late 2022, large language models (LLMs) can easily exceed 10B+ parameters (with the largest at 100B+ parameters). Pure data parallel strategies are no longer viable, as the model itself no longer fits on a single device. Fortunately, research and engineering on scaling have not slowed down; in the JAX (Bradbury et al., 2018) ecosystem in particular we now have `pjit`, enabling an orthogonal way to parallelize models called tensor parallelism. In this post, we'll explore the mathematical underpinnings of tensor parallelism, and learn how to implement it for a 15B param language model using `pjit`.
Intro: Parallelism
Data Parallelism
Until recently, large scale training of deep learning models has primarily used data parallelism:
- Each device stores a full copy of the model, and receives a “shard” of the batch (if the full batch is 8 training examples split across 2 devices, each device receives 4 examples).
- Each device independently computes the loss, and its gradient (w.r.t. the parameters) using its data shard.
- Once per step, the devices synchronize their gradients, and each updates its own copy of the model.
As long as a full copy of the model1 fits on device, this general strategy scales gracefully to the typical maximum of 8 GPUs on a single host, and was the likely strategy used to train the “big” (213 million params) Transformer in the original Attention is All You Need (Vaswani et al., 2017, p. 7) paper2. Properly optimized, data parallelism scales to even hundreds of GPUs across multiple hosts.
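To make this concrete, here's a minimal sketch of a single data-parallel training step with `jax.pmap`, using a toy linear model and a made-up `loss_fn` (illustration only, not code from this post's codebase):

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; stands in for the real model's loss.
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    # Each device computes loss + gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # The one synchronization point per step: average gradients across devices.
    grads = jax.lax.pmean(grads, axis_name="batch")
    # Plain SGD update; each device applies it to its own (identical) copy.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

n_dev = jax.local_device_count()
# Replicate the params across devices; shard the batch along its leading axis.
params = jax.device_put_replicated({"w": jnp.ones((16, 1))}, jax.local_devices())
x = jnp.ones((n_dev, 4, 16))   # [devices, per-device batch, features]
y = jnp.ones((n_dev, 4, 1))
params, loss = train_step(params, x, y)
```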
However, data parallelism isn’t enough when the model itself can no longer fit on a single device. This is where model parallelism comes in.
Tensor Parallelism
Model parallelism is when we split the model itself across multiple devices. Tensor Parallelism (“sharding”) is one of two ways to do this; the other is Pipeline Parallelism (“pipelining”). The latter is briefly discussed at the end, but the focus here really is on the former.
Tensor parallelism is the answer to this question: what if we could compute the activations of every layer of our model, distributed across all our devices?
Suppose we have 4 devices: with standard data parallelism we make each device compute all the embedding dimensions for 1/4th of the batch:
But perhaps we could make each device compute 1/4th the embedding dimensions for the entire batch, like this:
Even more: instead of sharding on one axis, we could shard both axes. What if we arranged these 4 devices in a \(2\times2\) mesh, such that the first (top left) device computed 1/2 the embedding dimensions for 1/2 the batch?
This is the big idea behind tensor parallelism: arranging our devices into a 2D mesh, and then sharding both our weights and activations on both axes, for all the layers. That is, each device holds a single “shard” of every layer in the model. When done properly, it is possible to run calculations with only one full copy of the model distributed across all the devices.
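For concreteness, here's a minimal sketch of building such a \(2\times2\) mesh in JAX (assuming 4 devices are available, e.g. 4 TPU cores); we'll build a larger mesh just like this later:

```python
import numpy as np
import jax
from jax.experimental import maps

# Arrange 4 physical devices into a 2x2 logical mesh, with axes named "X" and "Y".
devices = np.asarray(jax.devices()[:4]).reshape(2, 2)
mesh = maps.Mesh(devices, axis_names=("X", "Y"))
print(mesh.devices.shape)  # (2, 2)
```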
We’ll start from the ground up: at the level of the dot products themselves, we’ll see how sharding allows us to do very large matrix multiplies by trading increased communication for reduced memory use. Then we’ll scale it up to a full model in JAX, implementing sharding with `pjit` on a 15B language model for inference, with a focus on the exact code changes, keeping them minimal.
Partitioning: Intuition
Intro: Dot products
Let’s start with an observation: any dot product between two vectors can be broken down into the sum of multiple smaller dot products. Suppose:
\[ a = \begin{bmatrix} 1 \\ 0 \\ 2 \\ -1 \end{bmatrix}, b = \begin{bmatrix} -1 \\ 2 \\ 0 \\ 2 \end{bmatrix} \]
Then, the dot product of these two vectors of length 4 is \[a \cdot b = (1 \times -1) + (0 \times 2) + (2 \times 0) + (-1 \times 2) = -3\]
But we could easily re-write that expanded calculation as \[\textcolor{BurntOrange}{\underbrace{[(1 \times -1) + (0 \times 2)]}_\text{-1}} + \textcolor{Plum}{\underbrace{[(2 \times 0) + (-1 \times 2)]}_\text{-2}}\]
Each of these two terms individually is also a dot product of two vectors of length 2. Recoloring the original vectors, we can imagine them as composed of two “partitioned”-vectors:
\[ a = \begin{bmatrix} \textcolor{BurntOrange}{1} \\ \textcolor{BurntOrange}{0} \\ \textcolor{Plum}{2} \\ \textcolor{Plum}{-1} \end{bmatrix} \;\; b = \begin{bmatrix} \textcolor{BurntOrange}{-1} \\ \textcolor{BurntOrange}{2} \\ \textcolor{Plum}{0} \\ \textcolor{Plum}{2} \end{bmatrix} \]
Now, say I wanted my friend to help out with this tedious calculation. If I calculated the dot product with the first partition of each vector (getting back \(\textcolor{BurntOrange}{-1}\)), they’d only need to return the result (\(\textcolor{Plum}{-2}\)) of their partition (and not their entire sub-vectors) for me to calculate the full dot product, \((-1)+(-2)=-3\).
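As a quick sanity check, here's a minimal sketch of that decomposition in JAX: the two partial dot products sum to the full one.

```python
import jax.numpy as jnp

a = jnp.array([1., 0., 2., -1.])
b = jnp.array([-1., 2., 0., 2.])

# Full dot product, computed directly.
full = jnp.dot(a, b)            # -3.0

# Same result from two "partitioned" dot products: my half plus my friend's half.
mine = jnp.dot(a[:2], b[:2])    # -1.0
theirs = jnp.dot(a[2:], b[2:])  # -2.0
print(full, mine + theirs)      # -3.0 -3.0
```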
Intro: Matrix multiplies
Let’s build on this with another observation: In a matrix multiply \(AB=C\), \(C\) is simply a storage mechanism for the pairwise dot-products of all the (row) vectors of \(A\) and (column) vectors of \(B\)3. Specifically, let:
\[A = \begin{bmatrix} \textcolor{LimeGreen}{1} & \textcolor{LimeGreen}{0} & \textcolor{LimeGreen}{2} & \textcolor{LimeGreen}{-1} \\ 2 & 1 & 0 & -2 \end{bmatrix} \;\; B = \begin{bmatrix} 0 & \textcolor{LimeGreen}{-1} \\ 1 & \textcolor{LimeGreen}{2} \\ 2 & \textcolor{LimeGreen}{0} \\ 0 & \textcolor{LimeGreen}{2} \end{bmatrix}\;\; AB = C = \begin{bmatrix} 4 & \textcolor{LimeGreen}{-3} \\ 1 & -4 \end{bmatrix} \]
\(A\)’s first row vector and \(B\)’s second column vector should seem familiar: we just took their dot products. And as expected, the element of the first row, second column of \(C\) is that dot product \(-3\). This perspective also neatly explains two facts about matrix multiplication:
- Why \(C\) is a \(2 \times 2\) matrix: \(A\) has two row vectors, and \(B\) has two column vectors, resulting in a \(2 \times 2\) matrix to capture all the pairwise dot products. (Likewise, if \(A\) had \(3\) row vectors, \(C\) would be of shape \(3 \times 2\)).
- Why the “inner axes” (\(A\) being \(2 \times \textcolor{LimeGreen}{4}\), \(B\) being \(\textcolor{LimeGreen}{4} \times 2\)) have to match: we can’t take dot products of vectors of different lengths. Take note of this “inner axes” terminology; we’re about to build on it right now!
Both combined, we have the general rule for the shapes: \(\underbrace{A}_{n\times d} \underbrace{B}_{d\times m} = \underbrace{C}_{n\times m}\)
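The same trick carries over to matrix multiplies: split the shared inner axis (the \(d\) of length 4 here) into partitions, compute partial products independently, and sum them. A minimal sketch with the matrices above:

```python
import jax.numpy as jnp

A = jnp.array([[1., 0., 2., -1.],
               [2., 1., 0., -2.]])   # shape (2, 4)
B = jnp.array([[0., -1.],
               [1., 2.],
               [2., 0.],
               [0., 2.]])            # shape (4, 2)

C = A @ B                            # full (2, 2) result

# Split the *inner* axis: A's columns and B's rows, into two partitions.
C_partial_1 = A[:, :2] @ B[:2, :]    # "my" shard's partial product
C_partial_2 = A[:, 2:] @ B[2:, :]    # "my friend's" partial product

# Summing the partial products (an AllReduce, if the shards live on
# different devices) recovers the full result.
print(jnp.allclose(C, C_partial_1 + C_partial_2))  # True
```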
Partitioning: In JAX
Now that we have a deep understanding of how to shard a neural network, let’s write some code! We’re in luck here: I’m currently working on a port of the ESM2 (Lin et al., 2022) protein language model into Flax (Heek et al., 2020), so many of the examples are pulled directly from the working codebase.
For context: the model built here is a BERT (Devlin et al., 2019) style, encoder-only Transformer. It is a 15B param model with 48 layers and 5120-dimensional embeddings, as seen in the previous section. Each encoder layer has two sublayers as previously described: a self-attention sublayer followed by a feedforward network sublayer. Let’s start by understanding `pjit`’s programming model, then use it to progressively build the full model.
The `pjit` programming model
A good way to think of `pjit` is as a supercharged `jax.pmap`. If you recall, `pmap` runs the same program on multiple devices, each with a different shard of the input data over the batch axis. `pjit` is more flexible: it allows us to shard both the data and the weights (and, when training, even the optimizer states) in whatever configuration we please over a mesh. To do so, `pjit` requires three things from us:
- A mesh specification, mapping the “logical” devices on the 2D (or higher-D) mesh to the physical devices available.
- The sharding spec of all tensors being passed as input to, and returned as output from the function.
- Sharding constraints for select intermediate tensors inside the function. This isn’t strictly necessary (XLA GSPMD will try to find a viable layout), but can lead to improved memory usage.
Note what isn’t here: JAX doesn’t need us to insert any of the collective ops we discussed. It uses a constraint based model, where we specify sharding constraints for the “big”, memory intensive tensors, and it automatically determines the sharding pattern for all other intermediate tensors in the function, as well as any collective ops that need to be inserted to meet these constraints.
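To make those three requirements concrete before tackling the full model, here's a minimal, hypothetical sketch of `pjit` sharding a single matmul over a \(2\times4\) mesh (the names and shapes are made up for illustration; it assumes 8 devices, e.g. a TPUv2-8):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec as P

# 1. Mesh specification: map the 8 physical devices onto a 2x4 logical mesh.
devices = np.asarray(jax.devices()).reshape(2, 4)
mesh = maps.Mesh(devices, axis_names=("X", "Y"))

def matmul(x, w):
    # 3. In a bigger function we'd add with_sharding_constraint on the
    #    memory-heavy intermediates here; with one op there's nothing to constrain.
    return x @ w

# 2. Sharding specs for the inputs and the output:
#    x: [batch, embed] sharded (X, Y); w: [embed, hidden] sharded (Y, X);
#    out: [batch, hidden] sharded (X, Y).
sharded_matmul = pjit(
    matmul,
    in_axis_resources=(P("X", "Y"), P("Y", "X")),
    out_axis_resources=P("X", "Y"),
)

x = jnp.ones((32, 512))
w = jnp.ones((512, 2048))

with mesh:
    out = sharded_matmul(x, w)  # XLA inserts the needed collectives for us.
```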
Applying constraints to a FFN
In a previous section, we looked at the sharding spec the GSPMD paper proposed for the FFN in a transformer layer. To summarize in a table:
| Tensor | Sharding Spec [shape] |
|---|---|
| \(\text{activation: embedding}\) | \(X, \_, Y\) [batch, seq_len, embed] |
| \(W_{\text{in}}\) | \(X, Y\) [embed, hidden] |
| \(\text{activation: hidden}\) | \(X, \_, Y\) [batch, seq_len, hidden] |
| \(W_{\text{out}}\) | \(Y, X\) [hidden, embed] |
This sharding spec applies to any generic, dense Transformer. The code below is the second sublayer (the FFN) of an encoder layer in the ESM2 model. We apply the sharding spec to the weights via the `shard_axes` argument on each `partitioning.Dense`, and to the activations via `with_sharding_constraint`:
```python
# ... we apply a layer norm and multi-head attention before this.
# Create second residual block (LayerNorm + MLP)
residual = x
x = nn.LayerNorm(name="final_layer_norm", epsilon=1e-5)(x)

# Create + apply first MLP layer with weight + activation sharding constraints.
x = partitioning.Dense(
    self.ffn_embed_dim,
    name="fc1",
    shard_axes={"kernel": ("embed_kernel", "hidden")},
)(x)
x = nn_partitioning.with_sharding_constraint(x, ("batch", None, "hidden"))

# Don't approximate gelu to avoid divergence with original PyTorch.
x = nn.gelu(x, approximate=False)

# Create + apply second MLP layer with weight + activation sharding constraints.
x = partitioning.Dense(
    self.embed_dim,
    name="fc2",
    shard_axes={"kernel": ("hidden", "embed_kernel")},
)(x)
x = nn_partitioning.with_sharding_constraint(x, ("batch", None, "embed"))

x = residual + x
return x
```
The activation sharding specs are applied as in the initial example: we just call `with_sharding_constraint`. But there are two new things:

- There's a new `shard_axes` argument being passed into each layer definition.
- We're using the `partitioning.Dense` layer instead of the standard `nn.Dense`.

Let me elaborate on what's going on here.
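The full details live in the `esmjax` codebase, but the gist is a small mix-in: `partitioning.Dense` is just `nn.Dense` plus a `ShardMixIn` that accepts the `shard_axes` dict and "sows" the axis names as metadata alongside the params. Here's a rough, simplified sketch of the idea (the actual implementation may differ in details):

```python
from typing import Dict, Optional, Tuple

import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class ShardMixIn:
    """Adds a `shard_axes` argument, sowing the given axis names as metadata."""
    shard_axes: Optional[Dict[str, Tuple[str, ...]]] = None

    def param(self, name, *init_args):
        # Create the parameter as usual...
        param = super().param(name, *init_args)
        # ...then record its logical axis names in a "params_axes" collection,
        # so they're available after .init() alongside the params themselves.
        if self.shard_axes and name in self.shard_axes:
            self.sow(
                "params_axes",
                f"{name}_axes",
                nn_partitioning.AxisMetadata(names=self.shard_axes[name]),
                reduce_fn=lambda _, axes: axes,
            )
        return param


class Dense(ShardMixIn, nn.Dense):
    pass
```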
Putting it all together
We omitted a key detail in the opening example: in the real forward pass (the `.apply` method) we need to pass in both `esm_sharded_params` and the data `batch`. Since the params are an input argument, they also need a sharding spec. The `params` in Flax are a `PyTree` (specifically, a nested dict), so the sharding spec is a nested dict with the same structure. There's some plumbing here, so let's go through it step by step:
Because the `ShardMixIn` sows the sharding metadata into the module, this metadata is available at model initialization with the `.init` method. Let's initialize the 15B model and inspect the shapes of the parameters of layer 42:
```python
import functools

import flax.linen as nn
import jax
import jax.numpy as jnp

from esmjax.modules import modules
from esmjax.modules import partitioning

embed_dim = 5120
num_heads = 40
num_layers = 48

embedding = nn.Embed(33, embed_dim)
block_fn = functools.partial(modules.EncoderLayer, num_heads, embed_dim, embed_dim * 4)
esm2 = modules.ESM2(embedding, block_fn, num_layers)

key = jax.random.PRNGKey(0)
arr = jnp.array([[0, 1, 2]])
```
Below, we can see that the 5120-dimensional embeddings are projected to produce embeddings for 40 heads, with 128 dims each.
```python
# jax.eval_shape replaces all actual arrays with ShapeDtypeStruct
# This avoids memory use, *and* allows us to inspect the param shapes.
params = jax.eval_shape(esm2.init, key, arr)
params['params']['42']['self_attn']
```
```
FrozenDict({
    k_proj: {
        bias: ShapeDtypeStruct(shape=(40, 128), dtype=float32),
        kernel: ShapeDtypeStruct(shape=(5120, 40, 128), dtype=float32),
    },
    out_proj: {
        bias: ShapeDtypeStruct(shape=(5120,), dtype=float32),
        kernel: ShapeDtypeStruct(shape=(40, 128, 5120), dtype=float32),
    },
    q_proj: {
        bias: ShapeDtypeStruct(shape=(40, 128), dtype=float32),
        kernel: ShapeDtypeStruct(shape=(5120, 40, 128), dtype=float32),
    },
    v_proj: {
        bias: ShapeDtypeStruct(shape=(40, 128), dtype=float32),
        kernel: ShapeDtypeStruct(shape=(5120, 40, 128), dtype=float32),
    },
})
```
We can also see the axis metadata generated when calling the `.init` method:
```python
params['params_axes']['42']['self_attn']
```
```
FrozenDict({
    k_proj: {
        kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
    },
    out_proj: {
        kernel_axes: AxisMetadata(names=('heads', None, 'embed_kernel')),
    },
    q_proj: {
        kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
    },
    v_proj: {
        kernel_axes: AxisMetadata(names=('embed_kernel', 'heads', None)),
    },
})
```
Only the params that we've specified a sharding constraint over appear in this `PyTree`. To pass it into `pjit`, we use a utility function that converts the names into mesh axes and replicates the structure of the full params: the `AxisMetadata` are replaced with proper `PartitionSpec`s, and all other params have their sharding pattern set to `None`, meaning full replication.12
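The `rules` used in this conversion are just (logical axis name, mesh axis) pairs. The actual `partitioning.DEFAULT_TPU_RULES` isn't shown here, but judging from the specs above and below, it presumably looks something like this (a hypothetical reconstruction, not the real definition):

```python
# Hypothetical reconstruction of the logical-name -> mesh-axis rules,
# consistent with the PartitionSpecs shown in this post:
DEFAULT_TPU_RULES = [
    ("batch", "X"),         # activations: batch axis sharded over X
    ("embed", "Y"),         # activations: embedding axis sharded over Y
    ("hidden", "Y"),        # activations + fc1 kernel: hidden axis over Y
    ("heads", "Y"),         # attention kernels: heads axis over Y
    ("embed_kernel", "X"),  # weight kernels: embedding axis over X
]
```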
```python
params, params_axes = params.pop("params_axes")
esm_axes = partitioning.get_params_axes(params,
                                        params_axes,
                                        rules=partitioning.DEFAULT_TPU_RULES)

esm_axes['params']['42']['self_attn']
```
```
FrozenDict({
    k_proj: {
        bias: None,
        kernel: PartitionSpec('X', 'Y', None),
    },
    out_proj: {
        bias: None,
        kernel: PartitionSpec('Y', None, 'X'),
    },
    q_proj: {
        bias: None,
        kernel: PartitionSpec('X', 'Y', None),
    },
    v_proj: {
        bias: None,
        kernel: PartitionSpec('X', 'Y', None),
    },
})
```
We now pass this sharding spec (`esm_axes`) into the `pjit` definition. With that, we have a fully sharded inference method, distributing the computation of this 15B model across all 8 cores of a TPU. You can find a fully runnable notebook here.
```python
apply_fn = pjit.pjit(
    esm.apply,
    in_axis_resources=(esm_axes, P("X", None)),
    out_axis_resources=P("X", None, "Y"),
)

with maps.Mesh(mesh.devices, mesh.axis_names), nn_partitioning.axis_rules(
    partitioning.DEFAULT_TPU_RULES
):
    embeds = apply_fn(esm_sharded_params, batch)
```
Program Trace
Up to now, we've handwaved away the fact that there's a lot of communication going on in this forward pass. How much of the forward pass do we spend on these collective communication ops? The short answer: on a TPUv2-8, about 20%13. The way to answer this is a program trace, and JAX makes this easy. Here's the full trace of all 48 layers of the ESM2 15B model on TPU3 (of 0 to 7) of a TPUv2-8, taking about 4.25s to complete inference with a batch size of 32:
Let's zoom in on layer 42. We can tell from this trace that the FFN sublayer (fc1 and fc2) takes more time to execute than the self-attention sublayer. On the top bar we have the `XLA Ops`, the direct, device-level ops being executed. Most of these are `fusion` ops, a combination of fused multiplies and elementwise ops (e.g. addition, subtraction):
Let's zoom in further on the FFN sublayer. Amusingly, XLA has decided to allgather `fc2`'s weight matrix (the selected purple box) before the matmuls of `fc1` and `fc2`. This is the power of JIT compilation: XLA is able to re-order operations as needed for better performance. It has also inserted a reduce-scatter over the results (the rightmost, blue fusion op). Overall, the FFN sublayer takes 54ms, of which 8ms are spent on collective communication and data reformatting ops, about ~15%14:
In summary, for a ~20% performance tradeoff, we can now run inference with only one copy of the model stored across all our devices! I want you to explore this trace yourself, so here's the link to Perfetto; just hit `Open trace file`, upload this trace, and go play around!
Conclusion: Beyond Tensor Parallelism
Tensor parallelism is powerful, allowing us to scale from 1 GPU/TPU to all 8 connected GPUs/TPUs on a host, and, when using larger slices of a TPU pod, even further (PaLM was trained with only data and tensor parallelism, no pipelining, on two full TPUv4 pods of 3072 chips each; Sec. 4). There are three concluding thoughts I'd like to leave you with:
Pipeline parallelism: Given the large volume of communication involved, tensor parallelism is only viable when there is fast (ideally 1TB/s+) interconnect between devices. This is true for TPUs all the way up to an entire pod; however, GPUs only have fast interconnect (e.g. NVLink) in groups of 8 on a single host. Between hosts, the interconnect is slower (e.g. commercial cloud is typically on the order of ~100GB/s 15), meaning a different strategy is necessary.
If tensor parallelism is slicing a model “horizontally” (each layer is sharded across all devices), pipeline parallelism is slicing it “vertically” (device 1 holds layers 0 and 1, device 2 holds layers 2 and 3, and so on). The only communication is when activations move between layers, not inside a layer. The problem is that it leads to “bubbles” where devices are inactive (Y. Huang et al., 2019). Large GPU clusters tend to use tensor parallelism for the 8 GPUs connected on a single host, and pipeline parallelism between hosts, to make the best of both strategies.
More automation: `pjit` is incredibly flexible, capable of accommodating any sharding pattern we can come up with. The GSPMD paper covers even more “strange” cases, such as sharding convolutions over spatial axes across multiple devices. However, we still do need to specify a sharding pattern, which can be challenging for non-experts. There's a lot of exciting work going on in frameworks such as Alpa (Zheng et al., 2022), which automate this entirely, and I'm excited to see where this line of research is headed.

Larger models? Scaling language models is a strategy that continues to work, with no clear signs of slowing down. But a substantial fraction of scaling up goes into learning factual knowledge about the world rather than the semantics of language. Retrieval-augmented models such as RETRO (Borgeaud et al., 2022) and Atlas (Izacard et al., 2022) are much smaller (the largest RETRO model is only 7.5B params). However, they introduce a new axis (retrieval time) to the current trio of compute, memory use, and intra/inter-host communication, and I'm curious to learn where the bottlenecks will arise as this strategy is scaled up.
Further Reading
If you’d like to keep learning more about parallelism at scale, here’s a couple places to help you get started:
- How to Train Really Large Models on Many GPUs? (Weng, 2021): This is a great blog post providing a “big picture” overview of the multiple types of parallelism possible on GPU clusters, as well as other memory saving strategies.
- Scalable Training of Language Models using JAX pjit and TPUv4 (Yoo et al., 2022): A technical report from Cohere detailing how they use `pjit` with data and tensor parallelism to scale their training on TPUv4 pods (without needing pipeline parallelism).
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (Shoeybi et al., 2019): This paper explored sharding the data and weights strictly on the outer axes (Case 2 only), motivated by a need to minimize inter-device communication on large GPU clusters.
- Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (Narayanan et al., 2021): This is a follow-up paper, exploring large scale parallelism on GPU clusters with a fusion of tensor parallelism (Case 2 only) combined with pipeline parallelism.
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Rajbhandari et al., 2020): This paper looked at the technical optimizations needed to store only one copy of a model across all devices (Zero-DP, Stage 3), finding this increases communication volume by only 1.5x over baseline data parallelism.
- GSPMD: General and Scalable Parallelization for ML Computation Graphs (Xu et al., 2021): We used the sharding spec introduced in this paper; the paper as a whole discusses at length how user annotations are propagated across a computation graph, and the technical considerations involved.
Acknowledgements
Really grateful to Robert Dumitru for reading an early draft of this, providing honest feedback on the bits that “didn’t quite make sense”, and how they could be clearer. Access to TPUs was generously provided through the TPU Research Cloud.
Thanks for taking the time to read through this, and I hope you learned something new! Best place to leave feedback would be either my Twitter, or this form here. See you around!
References
Footnotes
and a copy of the optimizer state and gradients when we're training, so potential total memory use of up to 3x the size of the model itself.↩︎
Specifically, in section 5.2 the authors note “We trained our models on one machine with 8 NVIDIA P100 GPUs.”↩︎
Why the row vectors of \(A\) and the column vectors of \(B\)? Why not just the column vectors of both? This is mostly due to convention, as confusing as it can be for newcomers. I like the mnemonic RAC-B (“rack b”): rows of A, columns of B.↩︎
For an \(N \times N\) matrix, matrix multiplication has time complexity \(O(N^3)\) while matrix addition is \(O(N^2)\). For a large enough matrix, the speedup from the parallel matrix multiplies can outweigh the cost of communicating then adding afterwards.↩︎
You may also recognize this as just `jax.lax.psum`, which it is!↩︎
Of course, not accounting for the communication overhead of the AllReduce. For systems that are compute-bound, splitting the actual multiplication across multiple devices and syncing them together (due to the \(O(N^3)\) compute costs of a matrix multiply vs the \(O(N^2)\) memory transfer) is still often worth it.↩︎
Again, the outer axes in a matrix multiply are individual vectors; the inner axes are the feature dimensions of those vectors, as we saw in the intro. Feel free to scroll back if you need to recap!↩︎
Note that in math contexts this is often written as \(Wx\), where \(x\) is a column vector. In Python, data points are stored as row vectors, which means \(X\) has to be on the left hand side (because the feature dimension needs to be the inner dimension, the one we take dot products over!).↩︎
Roughly speaking. The dot product is sensitive to the magnitude of the vectors too, not just their direction. If a particular weight vector is very large, it will have large dot products even if the data vector isn’t that similar. However, with proper regularization, most weight vectors should be within a “reasonable” range, enabling comparison.↩︎
This is specifically the “2D finalized” sharding pattern presented in Table 1.↩︎
Or in the math textbook way, “this is left as an exercise to the reader”. But really, I think working it out would be a neat exercise!↩︎
On a larger model, we’d shard even the biases and layer norms, but on this scale it’s fine not to. They’re a lot smaller than the weights.↩︎
A quick estimate counting both the communication ops (e.g. allgather, fused reduce-scatters) as well as data formatting ops.↩︎
It's worse in the self-attention sublayer (29%), which also takes less time overall, resulting in an average of 20%. That would be the better sublayer to target for improvement!↩︎
Although, newer offerings such as the p4d.24xlarge or BM.GPU4.8 have considerably better inter-node bandwidth. At the same time, the A100 GPUs themselves are much faster, which means the inter-node bandwidth must keep up just to avoid becoming a bottleneck.↩︎