Thoughts on Chinchilla

How to (and not to) interpret the scaling laws
Categories: LLM, Scaling

Published: October 11, 2023

In a field where the “big ideas” seem to change on a weekly basis, Chinchilla (Hoffmann et al., 2022) is a standout paper: it came out a little over 18 months ago and found the LLMs of the time to be massively undertrained relative to their model size, with the then-dominant scaling laws (Kaplan et al., 2020) suggesting that, on log-log scales, model size \(N\) be scaled ~3x (2.7x) faster than dataset size \(D\)1.

Pre-Chinchilla models (such as Gopher (Rae et al., 2022)) tended to use the 3:1 scaling implied by Kaplan et al. (2020), whereas post-Chinchilla models use 1:1 scaling. Fig 1 from Hoffmann et al. (2022).

Chinchilla, by accounting for the effect of the learning rate scheduler, proposed that model size and dataset size should in fact be scaled in a 1:1 ratio. This meant that the (then) largest 500B+ parameter models (such as PaLM 540B (Chowdhery et al., 2022) or MT-NLG 530B (Smith et al., 2022)) were substantially undertrained2, and would need several orders of magnitude more compute than was available for that size to be “optimal”3. This triggered an industry-wide shift towards smaller models trained for longer (for a given compute budget).

Let’s dive deeper into the scaling laws, how they were derived, their implications for LLM training runs, how to (and not to) interpret them.

What’s compute optimal?

Chinchilla’s focus is on training models that are “compute optimal”: in this context, creating a model with the lowest loss for a given, fixed amount of compute \(C\). There are two words in the phrase “compute optimal”, so let’s examine them both.

Compute

Before beginning a training run, we can estimate the total number of FLOPs (Floating Point Operations) a run will have available from just four factors:

FLOPs measure the number of individual arithmetic operations on floating point numbers a larger computation (such as a matrix multiply) will use.
  • The FLOPs/sec4 per chip.
  • The number of chips.
  • How long (in hours/days) we plan the run to be.
  • The MFU (Model FLOPs/sec Utilization) of the run. MFU is the % of maximum FLOPs/sec your chips can actually use towards training your model.5

So, suppose I have access to 64 H100s for 1 week, and initial tests reveal my MFU is 50%. Then, the total FLOPs available during this run is:

\[\frac{989 \times 10^{12}\text{ FLOPs}}{\text{second} \cdot \text{H100}} \times 604800 \text{ seconds} \times 64 \text{ H100} \times 50\% = 1.91 \times 10^{22} \text{ FLOPs}\]

Assuming the chips cost $4/hour, this is a ~$43K training run! That is a lot of money to be spending, even on a run that is quite small by LLM standards. You’d want this money spent optimally (under some definition of optimal), and this is where Chinchilla comes in.
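
To make this concrete, here’s a minimal sketch of that arithmetic in Python (the H100 peak FLOPs/sec, 50% MFU, and $4/chip-hour figures are the assumptions from the example above, not measured values):

```python
# Back-of-the-envelope FLOPs budget and cost for the example run above.
peak_flops_per_sec = 989e12   # H100 peak FLOPs/sec per chip, from the example
num_chips = 64
run_seconds = 7 * 24 * 3600   # one week = 604,800 seconds
mfu = 0.50                    # fraction of peak FLOPs actually used for training
dollars_per_chip_hour = 4.0

total_flops = peak_flops_per_sec * run_seconds * num_chips * mfu
total_cost = num_chips * (run_seconds / 3600) * dollars_per_chip_hour

print(f"FLOPs budget: {total_flops:.2e}")   # ~1.91e+22
print(f"Cost: ${total_cost:,.0f}")          # ~$43,008
```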

Optimal

There are two factors that influence the number of FLOPs a run uses: the size of the model \(N\), and the number of training tokens \(D\). A commonly used approximation (introduced in Kaplan et al. (2020)) to the number of FLOPs a training run uses (the cost \(C\)) is

\[C \approx 6ND\]

A derivation of this approximation can be found in Appendix B.

Note the directly inverse relationship between model size and tokens here: for a fixed FLOPs budget \(C\), doubling the model size \(N\) means it can only “see” half as many tokens \(D\) during training6. There’s a sweet spot to strike: a model not so small that it lacks the capacity to learn from all those tokens, but not so large that it barely sees enough tokens to make use of its added capacity.
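
As a quick illustration of this tradeoff, here’s a small sketch using the ~\(1.91 \times 10^{22}\) FLOPs budget from the earlier example (that budget, and the model sizes below, are just illustrative assumptions):

```python
# For a fixed FLOPs budget C, the C ~ 6*N*D approximation pins down
# how many tokens D a model of size N can see during training.
C = 1.91e22  # FLOPs budget from the 64x H100, 1-week example above

def tokens_for_budget(n_params: float, flops_budget: float) -> float:
    """Tokens a model with n_params parameters can train on under flops_budget."""
    return flops_budget / (6 * n_params)

for n_params in [1e9, 2e9, 4e9]:
    print(f"{n_params/1e9:.0f}B params -> {tokens_for_budget(n_params, C)/1e9:.0f}B tokens")
# 1B params -> 3183B tokens
# 2B params -> 1592B tokens  (doubling N halves D)
# 4B params -> 796B tokens
```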

Which N and D is best? The answer is in the scaling law literature (Kaplan et al., 2020, Sec 6.1): the “optimal” \((N_{\text{opt}}, D_{\text{opt}})\) are the ones that produce a model that achieves the lowest loss on a validation set of the pretraining data, subject to the fixed cost constraint (the green star above).

A deeper dive on the findings of Kaplan et al. (2020), and how Chinchilla built on those findings can be found in Appendix A.

Chinchilla’s general approach then, is to compute \((N_{\text{opt}}, D_{\text{opt}})\) for a few small values of \(C\), and use them to extrapolate \((N_{\text{opt}}, D_{\text{opt}})\) for the \(C\) equivalent to the full training run.

Chinchilla Scaling

The full Chinchilla paper uses three different methods (fixed model size, fixed FLOPs and parametric fitting) to estimate the scaling behavior between model size and data, with similar results across all three. We focus on approach 2, where they use a set of 9 different FLOPs counts (from \(6 \times 10^{18}\) to \(3 \times 10^{21}\)). The method is as follows:

Calculating an IsoFLOP curve

First, for a given \(C\), train multiple models, varying \(N\) and \(D\) such that the FLOPs count remains \(C\). Compute the validation loss \(L\) of each model, producing a plot like this:

Note that the Chinchilla paper uses a more detailed approach to calculating \(C\) than Kaplan et al. (2020)’s \(C\approx 6ND\) approximation, explained in Appendix F. The \(C\approx 6ND\) approximation is within 10% across two orders of magnitude (Table A4), so it’s still a good mental model!

Then, fit a parabola to the points \((\log N, L)\). This allows us to predict the loss of each model of size \(N\) trained under the fixed amount of compute \(C\). The authors call this an IsoFLOP curve7, and it allows us to find \(N_{\text{opt}}\), the model size with the lowest loss for that \(C\).
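
A minimal sketch of this fitting step, with made-up \((N, \text{loss})\) points standing in for the real training runs (none of these numbers are from the paper):

```python
import numpy as np

# Hypothetical (model size, validation loss) pairs from runs at one fixed C.
N = np.array([0.25e9, 0.5e9, 1e9, 2e9, 4e9])
L = np.array([2.85, 2.70, 2.64, 2.67, 2.80])

# Fit a parabola to (log N, L): this is the IsoFLOP curve for this compute budget.
a, b, c = np.polyfit(np.log10(N), L, deg=2)

# The minimum of the parabola gives N_opt, the best model size for this C.
log_n_opt = -b / (2 * a)
print(f"N_opt ~ {10**log_n_opt / 1e9:.2f}B params, "
      f"predicted loss {np.polyval([a, b, c], log_n_opt):.3f}")
```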

This process is then repeated for each of the 9 values of \(C\), resulting in the full IsoFLOP curves plot:

Figure 3 (left) from Hoffmann et al. (2022)

Model Scaling

Each of the 9 curves above has one value of \(N_{\text{opt}}\). The authors then fit a power law \(N_{\text{opt}} = AC^a\) to the points \((C, N_{\text{opt}})\). This is where the scaling law appears: the fitted exponent is \(a=0.49 \approx 0.5\), and the power law fits the empirically calculated \(N_{\text{opt}}\) values very tightly. This allows them to extrapolate the best model size (66B) for Chinchilla’s full run of \(5.76 \times 10^{23}\) FLOPs, two orders of magnitude larger than the largest IsoFLOP curve of \(3\times10^{21}\) FLOPs.

Figure 3 (center) from Hoffmann et al. (2022)
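
A sketch of this power-law fit, with illustrative \((C, N_{\text{opt}})\) values chosen to roughly follow the paper’s trend (they are not the actual measurements):

```python
import numpy as np

# Hypothetical (compute, optimal model size) pairs, one per IsoFLOP curve.
C = np.array([6e18, 1e19, 3e19, 6e19, 1e20, 3e20, 6e20, 1e21, 3e21])
N_opt = np.array([0.21e9, 0.28e9, 0.48e9, 0.67e9, 0.87e9,
                  1.5e9, 2.1e9, 2.8e9, 4.8e9])

# A power law N_opt = A * C^a is a straight line in log-log space,
# so fit a degree-1 polynomial to (log C, log N_opt).
a, logA = np.polyfit(np.log10(C), np.log10(N_opt), deg=1)
print(f"fitted exponent a = {a:.2f}")            # ~0.50 for these values

# Extrapolate to the full Gopher/Chinchilla budget of 5.76e23 FLOPs.
n_pred = 10**logA * (5.76e23)**a
print(f"predicted N_opt at 5.76e23 FLOPs: {n_pred/1e9:.0f}B params")
```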

Extrapolating two orders of magnitude in FLOPs is quite the jump, but there’s a precise reason \(C=5.76 \times 10^{23}\) is chosen for the full training run: the same amount of compute was used to train the preceding Gopher 280B (Rae et al., 2022) model. According to the scaling laws calculated here, a compute-optimal model for Gopher’s compute budget should be 4x smaller, trained on 4x more tokens.

This prediction is tested empirically, and it holds, validating the scaling laws: Chinchilla 70B outperforms Gopher 280B on a suite of benchmarks, as detailed in Section 4.2 in the paper.

Data Scaling

With scaling, the discussion usually centers on how to scale the model size w.r.t. increasing compute. This is because \(N\) and \(D\) are not independent: for a fixed \(C\), if you know \(N_{\text{opt}}\) you also know \(D_{\text{opt}}\)8. Fitting a similar power law, \(D_{\text{opt}} = BC^b\), the authors obtain \(b=0.51 \approx 0.5\)9, which one should see as an alternate view of the same finding above!

Figure 3 (right) from Hoffmann et al. (2022)

Doing it this way, and noticing that both \(a \approx 0.5\) and \(b \approx 0.5\), makes the 1:1 ratio between model size scaling and data scaling clear.

Generality

The analysis in the core body of the paper takes place using the MassiveText dataset, a proprietary dataset also used to train Gopher. To validate the generality of these findings on other datasets, in Appendix C they reproduce this scaling behavior on two subsets of MassiveText, C4 (a public dataset first introduced in Raffel et al. (2020)) and GitHub code with 4 values of \(C\). In both, they find the constant \(a\) linking \(C\) and \(N_{\text{opt}}\) to be \(\approx 0.5\):

Table A2 from Hoffmann et al. (2022)

It is important to note that these estimates are not highly precise: the power law is fitted on only 9 values of \(C\) in the main experiments10 (and only 4 values of \(C\) for the GitHub/C4 experiments in Appendix C). The 80% confidence interval for \(a\) using the IsoFLOP approach is \((0.462, 0.534)\), which is still rather wide! Moreover, while a power law fits well, we don’t know if the “true” functional form between \(C\) and \(N_{\text{opt}}\) is a power law11.

Yet, despite not being highly precise, the 1:1 scaling suggested by the results does outperform the previous 3:1 scaling: Chinchilla 70B is 4x smaller than Gopher 280B, but trained on 4x more data it outperforms the larger model, highlighting the importance of data. That the measurements of \(a\) and \(b\) replicate across three approaches (Section 3), and that the IsoFLOP approach replicates on two more datasets (C4 and GitHub, producing estimates of \(a\) and \(b\) that are each closer to \((0.5, 0.5)\) than \((0.73, 0.27)\)), provide further support that 1:1 scaling is an improvement generally.

Quadratic Compute

One intuitive (and important) conclusion from the 1:1 scaling of model size and data: if you want a compute-optimal model that’s 2x as large, you need to train it on 2x as many tokens. And since the FLOPs count is the product of both, this means you need 4x as much compute!

Continuing that reasoning, if you want a model 5x larger, you need 25x as much compute, and so on! \(N_{\text{opt}} \propto C^{0.5}\), rewritten differently, is \(C \propto N_{\text{opt}}^2\): you need to scale compute quadratically with model size. This is enormously expensive, and is the core reason model sizes peaked around early 2022 (pre-Chinchilla): we’re only just now doing training runs with \(C\) large enough that models of that size (500B+) are compute optimal, and future model size scaling will remain slower (compared to pre-Chinchilla) because of this quadratic factor.
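
Spelling this out, the quadratic relationship follows in one line from the cost approximation and 1:1 scaling:

\[C \approx 6ND, \quad D_{\text{opt}} \propto N_{\text{opt}} \;\Rightarrow\; C \propto N_{\text{opt}} \cdot N_{\text{opt}} = N_{\text{opt}}^2\]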

Chinchilla in practice

Chinchilla, again, was an impactful paper that revealed just how wasteful training runs up to that point had been (see Timbers (2023) for a historical overview). Its message is memorable and concise: scale data with model size in a 1:1 ratio.

This message does need to be interpreted with nuance! For instance, the paper comes with this table calculating the optimal number of params and tokens across a range of FLOPs values:

Table A3 from Hoffmann et al. (2022)

But strictly speaking, this is only optimal for the exact dataset + model architecture used to compute the coefficients of the scaling law, that is, \(N_{\text{opt}}=AC^a\) and \(D_{\text{opt}}=BC^b\). The argument Chinchilla makes is that \(a \approx 0.5\) and \(b \approx 0.5\) generally, across datasets; it does not make any claims as to what general values of \(A\) and \(B\) are, and they can vary from dataset to dataset!

For instance, the final Chinchilla 70B model is trained on 1.4T tokens. If we had an aggressively deduplicated version of the MassiveText dataset (such as (Abbas et al., 2023)), it is possible to have a scaling law experiment that yields 1.0T tokens as optimal, while also staying consistent with \(b \approx 0.5\), producing a plot that looks like the following:

Notice that the slopes are the same (that is, \(b =0.5\)) but the intercepts \(B\) are different

To interpret this, remember that the scaling law is fitted as \(D_{\text{opt}} = BC^b\). On a log-log plot, \(B\) acts as the intercept, while \(b\) is the slope. Chinchilla makes claims about the slope: that \(b \approx 0.5\) (and \(a \approx 0.5\) on the model size side). This means that once you’ve found a value of \((N_{\text{opt}}, D_{\text{opt}})\) at a small value of \(C\), Chinchilla provides the recipe for scaling up \(N\) and \(D\) from that exact data mixture.

But even a single \((N_{\text{opt}}, D_{\text{opt}})\) pair will be unknown to you for your exact data/architecture setup at the start of your experiments, so at the minimum you’ll want to perform ~3-5 small training runs to produce one IsoFLOP curve, giving you at least one \((N_{\text{opt}}, D_{\text{opt}})\) to extrapolate from.
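
Once you have that single measured pair, the ~0.5 exponents give the scaled-up configuration directly. A minimal sketch (the pilot-run values here are hypothetical):

```python
def scale_up(n_opt: float, d_opt: float, c_small: float, c_target: float,
             a: float = 0.5, b: float = 0.5) -> tuple[float, float]:
    """Extrapolate a measured (N_opt, D_opt) at budget c_small to budget c_target,
    assuming N_opt ~ C^a and D_opt ~ C^b (Chinchilla: a ~ b ~ 0.5)."""
    ratio = c_target / c_small
    return n_opt * ratio**a, d_opt * ratio**b

# Hypothetical pilot result: a 400M model on 8B tokens was optimal at ~1.9e19 FLOPs.
c_small = 6 * 400e6 * 8e9
n_big, d_big = scale_up(400e6, 8e9, c_small, c_target=1e22)
print(f"{n_big/1e9:.1f}B params on {d_big/1e9:.0f}B tokens")  # ~9.1B params on ~183B tokens
```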

Subsequent work (such as Dey et al. (2023), Appendix D) replicated MassiveText-like dynamics in \(a, b\)12 and in \(A, B\) (finding ~20 tokens per parameter to be optimal) on a different dataset, the Pile (Gao et al., 2020). This suggests a general rule of thumb (20 tokens/param) for decoder-only, natural language models13, one that has been recommended in other blog posts (such as Anthony et al. (2023)).
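
As a quick sanity check of that rule of thumb, here’s a sketch that takes the ~20 tokens/parameter heuristic as given and turns a parameter count into a token count and training FLOPs budget:

```python
def rule_of_thumb(n_params: float, tokens_per_param: float = 20.0):
    """Rough compute-optimal token count and training FLOPs for a decoder-only LM,
    using the ~20 tokens/parameter heuristic discussed above."""
    d_tokens = tokens_per_param * n_params
    return d_tokens, 6 * n_params * d_tokens

d, c = rule_of_thumb(7e9)                       # a 7B-parameter model
print(f"~{d/1e9:.0f}B tokens, ~{c:.1e} FLOPs")  # ~140B tokens, ~5.9e+21 FLOPs

d, c = rule_of_thumb(70e9)                      # Chinchilla itself: 70B params
print(f"~{d/1e12:.1f}T tokens")                 # ~1.4T tokens, as in the paper
```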

That said, it is important to recall the assumptions this rule-of-thumb is built on14, and to be willing to calculate a new IsoFLOP curve if any assumption is violated.

Compute optimal?

Chinchilla’s scaling laws are concerned with optimality under one definition of cost: the amount of FLOPs used in training. This translates to real-world, monetary cost only if you’re paying per unit of compute. In practice, if you’re a big lab, you likely already have a contract reserving accelerator capacity with one of the large cloud providers; the compute is already paid for, and the real cost is opportunity cost.

Chinchilla only tells you how to produce the “best” (lowest validation loss) model given a compute budget; the meta-calculus of how valuable each model is15, is an entirely separate (and very real!) concern that it cannot answer. Moreover, there are practical instances where you may want to “overtrain” a smaller model with more data (= higher training costs), so that it is easier to serve to end users, as we see next.

Inference Costs

In industry applications, much of the cost will often not be in training, but in inference, serving the model to end users.

Since transformer inference cost is linear in model size, a model that’s 3x smaller will take 3x fewer FLOPs for inference16. Suppose the compute-optimal model for an initial \(2.66\times 10^{21}\) FLOPs budget is \(N=2.8\text{B}\) params trained on \(D=156\text{B}\) tokens. We can always train a 1.5B parameter model for longer than compute-optimal to achieve the same performance17 as the 2.8B model. This “extra compute” we call the compute overhead. Just how large is it?

Chinchilla’s scaling laws also give us a way to quantify this! Instead of looking at IsoFLOP curves (where the FLOPs count is the quantity held constant along each curve) we can look at IsoLoss curves (where the loss is the quantity held constant along each curve). This specific \((N, D)\) produces a loss of 2.24; we can produce a full range of \((N, D)\) values with that same loss, as shown on the left.

For this analysis, we use the formula in equation 10, appendix D.3, which is \(L(N, D) = 1.69 + \frac{406.4}{N^{0.34}} + \frac{410.7}{D^{0.28}}\)

We can also plot the IsoLoss curve for a loss of 2.24, as on the right. As we see, the “cheapest” way to achieve that loss is through the compute-optimal 2.8B model (gold star). But we can also produce a model nearly half the size (1.5B, green star) if we’re willing to spend \(3.09\times 10^{21}\) FLOPs, or 16% more, instead. When is this worth it?
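
Here’s a minimal sketch of that calculation using the parametric loss fit above (the root-finding is my own; small rounding differences from the numbers quoted in the text, which use the paper’s more exact FLOPs accounting, are expected):

```python
from scipy.optimize import brentq

def chinchilla_loss(n_params: float, d_tokens: float) -> float:
    """Parametric loss fit from Hoffmann et al. (2022), as quoted above."""
    return 1.69 + 406.4 / n_params**0.34 + 410.7 / d_tokens**0.28

target = chinchilla_loss(2.8e9, 156e9)   # ~2.24, the compute-optimal point
print(f"target loss: {target:.3f}")

# IsoLoss: how many tokens does a 1.5B model need to reach that same loss?
d_small = brentq(lambda d: chinchilla_loss(1.5e9, d) - target, 1e9, 1e14)
c_small, c_opt = 6 * 1.5e9 * d_small, 6 * 2.8e9 * 156e9
print(f"1.5B model needs ~{d_small/1e9:.0f}B tokens")   # ~340B tokens
print(f"compute overhead: {c_small / c_opt - 1:.0%}")   # ~17%, vs ~16% in the text
```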

Heads up: we’re talking exclusively about how to produce a smaller model with the same validation loss. But as we’ll see later, loss is not the same as downstream performance!

Payoff

Short answer: calculate the minimum number of inference tokens \(D_{\text{inf}}\) at which the inference savings from serving the 1.5B “overtrained” model (instead of the 2.8B “optimal” one) exceed the excess compute spent in training. Approximating the per-token savings as the inference cost of the 1.5B model, \(2N\) with \(N = 1.5\text{B}\) (reasonable, since the 2.8B model is roughly twice its size), and setting the training overhead equal to the total savings:

\[ \begin{align*} \text{overhead} &= (2N)D_{\text{inf}}\\ \text{1.5B train FLOPs} - \text{2.8B train FLOPs} &= (2 \times 1.5 \times 10^9)D_{\text{inf}}\\ 3.09\times 10^{21} - 2.66\times 10^{21} &= (2 \times 1.5 \times 10^9)D_{\text{inf}}\\ D_{\text{inf}} &= \frac{3.09\times 10^{21} - 2.66\times 10^{21}}{2 \times 1.5 \times 10^9}\\ D_{\text{inf}} &\approx \text{140B}\\ \end{align*} \]

That is, if we’re serving more than 140B tokens, the cost savings from this ~45% smaller model will become worth it. There are a couple of things to note here:

  • The inference cost of \(D_{\text{inf}}\) tokens passed through a model of size \(N\) is \(\approx 2ND_{\text{inf}}\), not \(\approx 6ND_{\text{inf}}\). This is because we only need the cost of the forward pass, not the forward + backward pass (where the backward pass costs 2x the forward pass).
  • This also means every training token is 3x as expensive as every inference token; we need to serve at least 3 inference tokens for every extra training token for the cost to be worth it. Put more directly, overtraining only makes sense for models that will receive very high usage.
  • To put the 140B inference tokens into context, the number of training tokens \(D\) needed for a model of size 1.5B to achieve a loss of 2.24 (based on the loss formula above) is \(\approx \text{339B}\). This means we’ll only need to serve a fraction of the tokens needed to train the model for the compute overhead to be worth it.

Generally speaking, for models intended to be used in production, a compute overhead of up to ~100% will often be worth paying to obtain a model ~30% the size (see De Vries (2023) for this analysis).

Latency

Compute isn’t the only factor at play at inference time: latency is too! A model that’s 50% the size not only uses 50% the compute, but could also reduce the computation time by up to 50%18. If that is the difference between an output your users will wait for and a service that is not viable, latency is a constraint that needs to take priority over training a compute-optimal model.

Moreover, while Chinchilla’s scaling laws tell you that you can optimally use an increased compute budget by scaling your model size and data 1:1, you can also ignore them and just increase the data, keeping the model size fixed. If there is an upper bound on the size of your model (due to latency considerations or otherwise), you can always improve performance (subject to diminishing returns) by training on more tokens.

This, combined with trillion-token datasets and the ability to reuse data for up to 4 epochs (Muennighoff et al., 2023), means sub-100B parameter models (and likely even larger ones) are not data-constrained during pre-training, and can be trained far past compute-optimal (as is the case with model families such as Llama 2 (Touvron et al., 2023)) for maximal performance under compute/latency-constrained inference.

Loss \(\neq\) Performance

One last thing to note is that Chinchilla is concerned exclusively with minimizing loss on a validation set: it makes no direct claims about the actual capabilities of a model. With language models, that loss is the cross-entropy of the next token, most commonly reported as perplexity (Huyen, 2019). But perplexity only correlates with the behaviors we want; it is not itself the objective we care about.

Intuitively, perplexity measures how many “options” in its vocabulary, on average, a language model is effectively “choosing between” when generating the next token (lower is better).

This can lead to counterintuitive behaviors: for instance in the PaLM2 report (Anil et al., 2023), although a 9.5B model achieves a lower loss on \(C = 1 \times 10^{22}\) FLOPs, a 16.1B model trained with the same amount of compute (but higher loss) actually performs better on downstream evaluations.

It is always critical to remember that language modeling is a proxy objective for natural language understanding (NLU) capabilities. We’re still subject to Goodharting (Sohl-Dickstein, 2022) on whether this proxy objective (perplexity) optimizes for what we really want!

Conclusion

Understanding how Chinchilla’s scaling laws are derived not only helps us better understand the assumptions made, but also enables us to recalculate them given a substantial change in dataset, model architecture, or domain. Moreover, understanding Chinchilla’s definition of a compute-optimal model helps us decide when we might not want one, and might want to overtrain a smaller model instead.

Overall, Chinchilla is about much more than just training compute-optimal models: it’s about being able to make quantifiable tradeoffs between cost, model size and dataset size.

Acknowledgements

Deeply grateful to Jack Rae, Erich Elsen and Klara Kaleb for providing feedback on early drafts of this blogpost; it is substantially clearer and more comprehensive owing to their thoughtful recommendations. Much thanks also to Garrett Honke and Jon Deaton, with whom my many (many) conversations about language models have helped shape my own understanding. All errors are definitely my own!

Appendix A: Pre-Chinchilla Scaling

Kaplan et al. (2020) had first established scaling law behaviors in large neural language models two years prior to Chinchilla. This was an expansive work, covering many results that are very much worth learning about, such as:

  • Transformers asymptotically outperforming LSTMs (as \(N\) grows larger), and the per-token loss going down across the context (whereas LSTMs plateau after 100 tokens), in Figure 7.
  • Ablation between various model shapes (ratio between feedforward hidden dim and model dim, number of layers, attention head dimension) and model size, finding that for a fixed \(N\), these parameters affect loss only mildly.
  • Extending the literature on critical batch size (McCandlish et al., 2018) to Transformer language models (Section 5.1).
  • Early work observing additional compute can be traded off for smaller model sizes (Figure 12).
  • Observing a conspicuous lump at \(10^{-5}\) PF-days at the transition from 1-layer to 2-layer networks (Figure 13), which subsequent work from Olsson et al. (2022) attributed to the formation of induction heads (which one-layer attention networks cannot form).

This work also fitted scaling laws between compute \(C\), and model size \(N\) and number of tokens \(D\). However, the estimates in this paper were \(a \approx 0.73\) and \(b \approx 0.27\); that is, \(\log N\) needed to be scaled up ~3x faster than \(\log D\):

Figure 3 from Kaplan et al. (2020)

However, these estimates were derived from the “early-stopped” loss (and the authors explicitly state it as such). That is, if a 1B model was trained on 100B tokens, to estimate the loss of a 1B model trained on 50B tokens, they would use the intermediate loss (at 50B tokens processed) from the 100B tokens run.

As the Chinchilla authors subsequently pointed out, using the intermediate loss value from a longer run means the learning rate schedule has not fully decayed at that point (as that happens at the end of the run). As they show in Figure A1, using a schedule with an endpoint more than 25% beyond the measurement point19 leads to clear increases in the measured loss.

To correct for this, the Chinchilla authors trained separate models of the same model size \(N\) for each value of \(D\), making sure that the learning rate schedule fully finishes decaying when \(D\) tokens are processed. This corrected for the overestimated losses, yielding \(a \approx 0.5\) and \(b \approx 0.5\), and the now familiar 1:1 scaling ratio.

Appendix B: Deriving \(C \approx 6ND\)

A sketch for the \(C \approx 6ND\) approximation introduced in Kaplan et al. (2020) is as follows:

  1. Where \(N\) is the number of parameters, the total non-embedding FLOPs per token for the forward pass can be written out as20:

\[\begin{align*} C_{\text{forward}} &\approx 2 \cdot 2d_\text{model}n_\text{layer}(2d_\text{attn} + d_\text{ff}) + 2n_{\text{layer}}n_{\text{ctx}}d_{\text{attn}}\\ &= 2N + 2n_{\text{layer}}n_{\text{ctx}}d_{\text{attn}} \end{align*}\]

  2. \(C_{\text{forward}}\) can be described as having two terms: the first corresponding to model size, and the second corresponding to how increasing the context size \(n_{\text{ctx}}\) increases the number of FLOPs needed. For sufficiently large models with the context window sizes commonly used in training, \(2N\) is ~2 orders of magnitude larger than \(2n_{\text{layer}}n_{\text{ctx}}d_{\text{attn}}\). This allows simplifying \(C_{\text{forward}}\) to just the first term, \(C_{\text{forward}} \approx 2N\).

  3. Since this value is per token, for \(D\) tokens this is \(C_{\text{forward}} \approx 2ND\).

  4. Now, the backward pass takes 2x as much compute as the forward pass. This is because, at each layer, you need to compute the gradients for both the activations of the previous layer and the weights of that layer. Hence, \[\begin{align*} C &\approx 2ND + 4ND\\ &= 6ND \end{align*}\]

For instance, for GPT-3 175B (with \(n_{\text{layer}} = 96\), \(n_{\text{ctx}}=2048\) and \(d_{\text{attn}}=12288\)), ~98% of \(C_{\text{forward}}\) comes from the \(2N\) term.
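
A quick check of that claim, assuming the standard GPT-3 shapes \(d_\text{model} = d_\text{attn} = 12288\) and \(d_\text{ff} = 4d_\text{model}\):

```python
# Verify that the 2N term dominates C_forward for a GPT-3 sized model.
n_layer, n_ctx, d_model = 96, 2048, 12288
d_attn, d_ff = d_model, 4 * d_model              # standard GPT-3 shapes (assumption)

N = 2 * d_model * n_layer * (2 * d_attn + d_ff)  # ~1.74e11 non-embedding params
param_term = 2 * N
context_term = 2 * n_layer * n_ctx * d_attn      # ~4.8e9

share = param_term / (param_term + context_term)
print(f"N ~ {N:.2e}, 2N share of C_forward: {share:.1%}")  # ~98.6%
```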

References

Abbas, A., Tirumala, K., Simig, D., Ganguli, S., & Morcos, A. S. (2023). SemDeDup: Data-efficient learning at web-scale through semantic deduplication. https://arxiv.org/abs/2303.09540
Anil, R., Dai, A. M., Firat, O., Johnson, M., Lepikhin, D., Passos, A., Shakeri, S., Taropa, E., Bailey, P., Chen, Z., Chu, E., Clark, J. H., Shafey, L. E., Huang, Y., Meier-Hellstern, K., Mishra, G., Moreira, E., Omernick, M., Robinson, K., … Wu, Y. (2023). PaLM 2 technical report. https://arxiv.org/abs/2305.10403
Anthony, Q., Biderman, S., & Schoelkopf, H. (2023). Transformer math 101. https://blog.eleuther.ai/transformer-math/.
Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H. W., Sutton, C., Gehrmann, S., Schuh, P., Shi, K., Tsvyashchenko, S., Maynez, J., Rao, A., Barnes, P., Tay, Y., Shazeer, N., Prabhakaran, V., … Fiedel, N. (2022). PaLM: Scaling language modeling with pathways. https://arxiv.org/abs/2204.02311
De Vries, H. (2023). Go smol or go home. https://www.harmdevries.com/post/model-size-vs-compute-overhead/
Dey, N., Gosal, G., Chen, Z., Khachane, H., Marshall, W., Pathria, R., Tom, M., & Hestness, J. (2023). Cerebras-GPT: Open compute-optimal language models trained on the cerebras wafer-scale cluster. https://arxiv.org/abs/2304.03208
Gao, L., Biderman, S., Black, S., Golding, L., Hoppe, T., Foster, C., Phang, J., He, H., Thite, A., Nabeshima, N., Presser, S., & Leahy, C. (2020). The pile: An 800GB dataset of diverse text for language modeling. https://arxiv.org/abs/2101.00027
Hoffmann, J., Borgeaud, S., Mensch, A., Buchatskaya, E., Cai, T., Rutherford, E., Las Casas, D. de, Hendricks, L. A., Welbl, J., Clark, A., Hennigan, T., Noland, E., Millican, K., Driessche, G. van den, Damoc, B., Guy, A., Osindero, S., Simonyan, K., Elsen, E., … Sifre, L. (2022). Training compute-optimal large language models. https://arxiv.org/abs/2203.15556
Huyen, C. (2019). Evaluation metrics for language modeling. The Gradient.
Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., & Amodei, D. (2020). Scaling laws for neural language models. https://arxiv.org/abs/2001.08361
McCandlish, S., Kaplan, J., Amodei, D., & Team, O. D. (2018). An empirical model of large-batch training. https://arxiv.org/abs/1812.06162
Muennighoff, N., Rush, A. M., Barak, B., Scao, T. L., Piktus, A., Tazi, N., Pyysalo, S., Wolf, T., & Raffel, C. (2023). Scaling data-constrained language models. https://arxiv.org/abs/2305.16264
Olsson, C., Elhage, N., Nanda, N., Joseph, N., DasSarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., … Olah, C. (2022). In-context learning and induction heads. Transformer Circuits Thread.
Rae, J. W., Borgeaud, S., Cai, T., Millican, K., Hoffmann, J., Song, F., Aslanides, J., Henderson, S., Ring, R., Young, S., Rutherford, E., Hennigan, T., Menick, J., Cassirer, A., Powell, R., Driessche, G. van den, Hendricks, L. A., Rauh, M., Huang, P.-S., … Irving, G. (2022). Scaling language models: Methods, analysis & insights from training gopher. https://arxiv.org/abs/2112.11446
Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P. J. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140), 1–67. http://jmlr.org/papers/v21/20-074.html
Smith, S., Patwary, M., Norick, B., LeGresley, P., Rajbhandari, S., Casper, J., Liu, Z., Prabhumoye, S., Zerveas, G., Korthikanti, V., Zhang, E., Child, R., Aminabadi, R. Y., Bernauer, J., Song, X., Shoeybi, M., He, Y., Houston, M., Tiwary, S., & Catanzaro, B. (2022). Using DeepSpeed and megatron to train megatron-turing NLG 530B, a large-scale generative language model. https://arxiv.org/abs/2201.11990
Sohl-Dickstein, J. (2022). Too much efficiency makes everything worse: Overfitting and the strong version of Goodhart’s law. https://sohl-dickstein.github.io/2022/11/06/strong-Goodhart.html
Timbers, F. (2023). Five years of progress in GPTs. Artificial Fintelligence. https://finbarrtimbers.substack.com/p/five-years-of-progress-in-gpts
Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., Bikel, D., Blecher, L., Ferrer, C. C., Chen, M., Cucurull, G., Esiobu, D., Fernandes, J., Fu, J., Fu, W., … Scialom, T. (2023). Llama 2: Open foundation and fine-tuned chat models. https://arxiv.org/abs/2307.09288

Footnotes

  1. That is, \(\log N = 2.7\log D + K\), where K is some constant. Equivalently, \(N \propto D^{2.7}\). In concrete terms, for every 2.7 orders of magnitude increase in model size \(N\), we only need to increase dataset size \(D\) by one order of magnitude.↩︎

  2. That is, they needed to be trained on much more data for that size to be compute optimal.↩︎

  3. Said differently, for the total compute available to the teams then, they could’ve obtained a lower validation loss with a smaller model trained on more data.↩︎

  4. FLOPs/sec is often written as FLOPS (with a capital S for second), but this can be confusing, so I write it out explicitly here.↩︎

  5. In more depth, this is the FLOPs/sec used towards the computation of the training job itself (disregarding any re-computations such as activation checkpointing) divided by the peak FLOPs/sec of the hardware. Chowdhery et al. (2022) first introduced this, and this can be quickly calculated by running your job on the cluster for a few minutes (vs days/weeks for the full run).↩︎

  6. Again, you can always double the model size and use the same number of tokens as before, but you now need 2x the compute! Chinchilla’s analyses are about how to maximally use a fixed amount of compute.↩︎

  7. In that the FLOPs (\(C\)) is the quantity being held constant for each point on this curve.↩︎

  8. Intuitively, if you know the size of the model, and amount of compute used to train it, you can quickly reverse calculate how many tokens the model would need to be “trained on” to hit that compute budget.↩︎

  9. The actual 80% confidence interval for \(a\) is \((0.462,0.534)\), so it’s not unreasonable to round to 0.5 for simplicity.↩︎

  10. Understandably so: computing even one of the IsoFLOP curves means training multiple models with a fair amount of compute each, which isn’t cheap!↩︎

  11. Along those lines, there’s a neat GitHub implementation using PySR to directly regress scaling laws from data (without assuming a power-law fit first).↩︎

  12. 1:1 data:model scaling, as expected↩︎

  13. without needing to produce IsoFLOPs curves for each new dataset↩︎

  14. That the model is decoder-only, natural language, with a mixture similar to MassiveText/the Pile.↩︎

  15. These are questions such as “Given our compute reservations, is a 3B model that can be deployable in 6 weeks more valuable than a 20B model deployable in 3 months?”.↩︎

  16. And hence, properly optimized at scale, will be 3x cheaper.↩︎

  17. that is, the same validation loss↩︎

  18. Assuming we still have the same amount of hardware, and are not too bottlenecked on the generation of the output tokens.↩︎

  19. In other words, a schedule that has not fully decayed at the measurement point.↩︎

  20. Note this leaves out the compute used for biases, nonlinearities and layer norms, which are a tiny fraction of total compute.↩︎

Citation

BibTeX citation:
@online{shafkat2023,
  author = {Shafkat, Irhum},
  title = {Thoughts on {Chinchilla}},
  date = {2023-10-11},
  url = {https://irhum.github.io/blog/chinchilla/},
  langid = {en}
}
For attribution, please cite this work as:
Shafkat, I. (2023, October 11). Thoughts on Chinchilla. https://irhum.github.io/blog/chinchilla/