Neural Language Models

We look at language models parametrized by neural networks, and how they’re capable of near transfer: generalizing to sequences similar to (but not exactly the same as) those in their training sets.

November 14, 2022

At their core, language models (LMs) are simply functions that assign a probability to a snippet of text. A common variant is the autoregressive LM: a function that produces a probability for the “next token” conditioned on some already written text, that is, \(p(\text{token} \mid \text{context})\). Such models can be immensely powerful, as many tasks can be rewritten as “next token prediction” problems, for instance:

If we had an infinite amount of text data, in which every possible sentence appeared in proportion to its “true probability”, estimating these probabilities would be straightforward.

Unfortunately, we do not have an infinite corpus. Even with a massive text corpus, we likely won’t observe most valid sequences of, say, 100 words even once, since the number of possible sequences grows exponentially with their length.

This post focuses on how neural networks allow us to build language models that produce “good” next token probabilities for sequences that are “similar” to (but not exactly the same as) sequences in the training set, as first introduced in Bengio et al. (2003). Borrowing language from human learning, this is akin to near transfer (Schunk, 2011, p. 320), where there is substantial overlap between the “training” and “test” contexts.

Levels of Analysis

Within any sufficiently complicated technical system, there are multiple levels that interact with each other but are also distinct from each other. Disambiguating between them provides a clean mental model to reason about them. I find it helpful to apply Marr’s levels of analysis (Marr & Poggio, 1976) in such settings. Applied here, they are:

  • Computation: This is the abstract computation any language model performs, regardless of whether it’s using a lookup table, a transformer or a bio-inspired computer. For an autoregressive language model specifically, it is the estimation of the conditional probability of the next token given the preceding tokens.

  • Representation: There are multiple distinct ways to parametrize a language model. In this post we’ll explore and compare two: explicitly parametrizing each conditional probability, and using a neural network. We also need to decide how to represent raw text as individual tokens (is a single word a token? a single character?), and we’ll explore this too.

  • Implementation: Once we’ve decided on a representation for our functions and data, we need to implement them in code to run them on hardware. This is the Python/JAX code we use in this article, and in turn all the kernels it executes to carry out the computation.

This breakdown is subjective, and strongly dependent on the problem you’re looking at. If we were instead optimizing raw PTX instructions, the Python/JAX code might define the computation, the PTX the representation, and the architecture-specific GPU hardware operations the implementation. One person’s implementation level can easily be someone else’s computation level.

This blog post focuses primarily on the computational level: on what a language model is. There are two distinct parts:

  • In Part A, we take a deep dive into what it means to estimate a discrete probability distribution, and the combinatorial explosion in the number of outcomes that makes this challenging.
  • In Part B, we step between the computational and representational levels, and learn that a language model parametrized by a neural network stores next token probabilities implicitly in its weights.

Part A: Discrete Distributions

Maximum Likelihood Estimation

Suppose we receive a coin, and we wish to find out if it’s biased or not. We run 100 tosses, and 62 of them come up heads, so we figure it’s biased at 62% heads, 38% tails.

What just happened here? To rewind, the scenario looks like this:

  • Support: We have a random variable \(X\), with exactly two possible1 outcomes: heads or tails.
  • Parametrization: \(X\) is a discrete random variable, and each of its outcomes has a true probability value, \(p(0)\) and \(p(1)\). We don’t know what these values are, so we parametrize a two-valued discrete probability distribution, with \(p_\theta(1) = \theta\) and \(p_\theta(0) = 1 - \theta\).
  • Sampling: Although we don’t know the distribution, we can generate samples from it by tossing the coin repeatedly. We gather a “dataset” of 62 heads and 38 tails.
  • Estimation: We can then use these samples to estimate the value of the parameter \(\theta\). Specifically, we can find the value of \(\theta\) that maximizes the likelihood of the outcome we observed.

Given a value of \(\theta\), we can compute the probability we’d see a given outcome:

  • The probability of observing one heads is \(\theta\).
  • Since each observation is independent, the probability of observing one heads and two tails is \(\theta (1 - \theta)(1-\theta) = \theta(1-\theta)^2\).
  • For the overall observation here, of 62 heads and 38 tails, we have \(\theta\theta\cdots(1-\theta)(1-\theta) = \theta^{62}(1-\theta)^{38}\).

Independence between individual observations is a key assumption; if they’re not independent, \(p(\text{obs}) = p(\text{obs}_1)p(\text{obs}_2)\cdots\) doesn’t hold.

However, \(\theta\) is a variable whose value we need to estimate. One way to do so is to view the above expression as a function of \(\theta\), \(L(\theta) = \theta^{62} (1 - \theta)^{38}\), known as the likelihood function. Finding the value of \(\theta\) that maximizes the likelihood function is a process aptly called maximum likelihood estimation, or MLE. This is the value of \(\theta\) that, when used in \(p_\theta(\cdot)\), assigns the highest2 probability to our observation.

Note that the likelihood function as is will produce extremely tiny values; for numerical stability, we usually minimize the negative log-likelihood, which here is:

\[\begin{align} -\ln L(\theta) &= -\ln(\theta^{62} (1 - \theta)^{38})\\ &= -62 \ln (\theta) - 38 \ln(1 - \theta)\\ \end{align}\]

Visualizing both, we see that they are maximized and minimized (respectively) at the same value, \(\theta=0.62\), matching our earlier back-of-the-envelope estimate:

import numpy as np
import matplotlib.pyplot as plt

def prob(p):
    return p**62 * (1-p)**38

def log_prob(p):
    return -62*np.log(p) - 38*np.log(1-p)

fig, ax = plt.subplots(figsize=(9, 4), ncols=2)
xs = np.linspace(0.3, 0.9, 300)
ax[0].plot(xs, prob(xs))
ax[0].set_xlabel('Value of theta')
ax[0].set_ylabel('Likelihood')
ax[0].set_xlim(0.3, 0.9)
ax[0].plot([0.62], [prob(0.62)], marker='x', markersize=10, color="royalblue")

ax[1].plot(xs, log_prob(xs))
ax[1].set_xlabel('Value of theta')
ax[1].set_ylabel('Negative Log Likelihood')
ax[1].set_xlim(0.3, 0.9)
ax[1].plot([0.62], [log_prob(0.62)], marker='x', markersize=10, color="royalblue")
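The minimum at \(\theta = 0.62\) can also be derived analytically. As a short derivation (ours, not part of the original plot code), setting the derivative of the negative log-likelihood to zero recovers the proportion of heads:

```latex
\[\begin{align}
\frac{d}{d\theta}\left(-\ln L(\theta)\right) &= -\frac{62}{\theta} + \frac{38}{1-\theta} = 0\\
38\,\theta &= 62\,(1-\theta)\\
\hat{\theta} &= \frac{62}{62 + 38} = 0.62
\end{align}\]
```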

Now suppose we learn the true value of \(p(1) = 0.6\). Note that, given our observations, \(\theta = 0.6\) is in fact less likely than our estimate of \(p_\theta(1) = \theta = 0.62\):

def log_prob(p):
    return -62*np.log(p) - 38*np.log(1-p)

fig, ax = plt.subplots(figsize=(9, 4), ncols=1)
xs = np.linspace(0.55, 0.65, 300)

ax.plot(xs, log_prob(xs))
ax.set_xlabel('Value of theta')
ax.set_ylabel('Negative Log Likelihood')
ax.set_xlim(0.55, 0.65)
ax.plot([0.62], [log_prob(0.62)], marker='x', markersize=10, color="royalblue")
ax.plot([0.6], [log_prob(0.6)], marker='x', markersize=10, color="red")

This is because we estimated \(\theta\) with the value that maximizes the probability of the observed samples. In the limit of an infinite number of samples, the likelihood function will be maximized at a value of \(\theta\) that produces the true probabilities. But in practice, there will be variance in the number of heads in 100 samples (we may have 61 heads, or 58, or…); and in turn, variance in our estimates.
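To make this variance concrete, here is a small simulation sketch (not from the original post). It repeats the 100-toss experiment many times with a coin whose true heads probability is 0.6, and looks at the spread of the resulting MLEs. The seed and the number of repetitions are arbitrary choices.

```python
import numpy as np

# Assumption for illustration: the true heads probability is 0.6, as stated above.
true_p = 0.6
n_tosses = 100
n_experiments = 10_000

rng = np.random.default_rng(seed=0)
# Each entry is the number of heads in one run of 100 independent tosses.
heads = rng.binomial(n_tosses, true_p, size=n_experiments)
# The MLE for each run is the observed proportion of heads.
estimates = heads / n_tosses

print(f"mean of estimates: {estimates.mean():.3f}")
print(f"std of estimates:  {estimates.std():.3f}")
```

The estimates center on 0.6, but individual runs routinely land several percentage points away. Theory gives a standard deviation of \(\sqrt{0.6 \cdot 0.4 / 100} \approx 0.049\).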

Discrete distributions as vectors

We can write the probability mass function of the “coin toss” random variable in the following vector. Each entry simply stores the probability for one of the possible outcomes:

In general, for a discrete random variable with \(V\) possible outcomes, there are \(V-1\) free parameters. This is because all the probabilities must sum to \(1\):

  • For \(V=2\), as in the case above, we have the probability for one outcome be \(\theta\). The other must then be \(1 - \theta\) for both to sum to \(1\).
  • For \(V=3\), we can have probabilities for two of the outcomes specified by \(\theta_1\) and \(\theta_2\). The third outcome then must have the probability \(1 - \theta_1 - \theta_2\).
  • For \(V=n\), we have the probabilities for the first \(n-1\) outcomes specified by \(\theta_i\), and the final one by \(1 - \sum_{i=1}^{n-1} \theta_i\).

There are potentially detrimental consequences when we “tie” the estimates of the probabilities of different outcomes, as we see next.

Constrained estimates

Suppose we have a discrete random variable with three possible outcomes (\(V=3\)). The vector encodes these three outputs for the true probability mass function \(p(\cdot)\)3:

As before, we don’t know \(p(\cdot)\), so we approximate it with \(p_\theta(\cdot)\). To find the best value of the parameters \(\theta\), we use maximum likelihood estimation. We define two parametrizations \(p_\theta(\cdot)\):

  • one that is free with \(V-1 = 2\) parameters (\(\theta_1\) and \(\theta_2\)). That is, we can independently update \(p_\theta(a)\) without affecting \(p_\theta(b)\).4
  • another that is constrained to only have one parameter (\(\theta_c\)). Note that \(p_\theta(a)\) and \(p_\theta(b)\) are now tied; updating one will also change the other.

As with the true \(p(\cdot)\), we can also write the outcomes of these approximations \(p_\theta(\cdot)\) in a vector:

Suppose we draw 1000 samples, receiving 38 a’s, 7 b’s and 955 c’s. For each parametrization, we find the values of the parameters that maximize the likelihood of receiving this outcome:

def log_prob(theta0, theta1):
    return -(38*np.log(theta0) + 7*np.log(theta1) + 955*np.log(1-theta0-theta1))

def log_prob_constrained(theta):
    return -(38*np.log(theta) + 7*np.log(theta**2) + 955*np.log(1-theta-theta**2))

theta0_s = np.linspace(0.02, 0.06, 100)
theta1_s = np.linspace(0.001, 0.02, 100)
mesh_b0, mesh_b1 = np.meshgrid(theta0_s, theta1_s)
mesh_y = log_prob(mesh_b0.reshape(-1), mesh_b1.reshape(-1)).reshape(100, 100)

fig, ax = plt.subplots(figsize=(9, 4), ncols=2)
contour = ax[0].contourf(mesh_b0, mesh_b1, mesh_y, levels=30, cmap="Blues")
contour_line = ax[0].contour(mesh_b0, mesh_b1, mesh_y, levels=[203.3, 207.19], cmap="Blues_r", vmin=202, vmax=220)
ax[0].clabel(contour_line, inline=True, fontsize=10)
ax[0].plot([0.038], [0.007], marker='x', markersize=10, color="royalblue")
ax[0].set_yticks(np.arange(0.0, 0.02, 0.003))
ax[0].set_ylim(0.001, 0.02)
ax[0].set_title("(Free) Lowest NLL = {:.2f}".format(log_prob(0.038, 0.007)))

xs = np.linspace(0.03, 0.06, 300)
ax[1].plot(xs, log_prob_constrained(xs))
ax[1].plot([0.047], [log_prob_constrained(0.047)], marker='x', markersize=10, color="royalblue")
ax[1].set_ylabel("Negative Log Likelihood")
ax[1].set_title("(Constrained) Lowest NLL = {:.2f}".format(log_prob_constrained(0.047)))
ax[1].set_xlim(0.03, 0.06)
ax[1].set_yticks(np.arange(204, 215, 2))
ax[1].set_ylim(204, 214.3)

On the left, in the “free” parametrization, the minimum is reached at \(\theta_1=0.038, \theta_2=0.007\), just as would be expected by estimating from the proportions directly (\(\theta_1=\frac{38}{1000}, \theta_2=\frac{7}{1000}\)). In general, this is true for any number of outcomes \(V\); the maximum likelihood estimate for each probability \(\theta_i\) is the proportion of times said outcome appears in the entire dataset.

On the right, however, with only one free parameter under the above parametrization, the minimum is reached at \(\theta_c=0.047\). This corresponds to \(p_\theta(a)=0.047, p_\theta(b)=0.047^2\approx0.002\). Not only is this a bad estimate of the true probabilities (where \(p(a)=0.03\)), it is even far from the estimate reached by the free model (where \(p_\theta(a)=0.038\)).

The lowest NLL reached here is \(207.19\). By drawing the contour at this NLL value on the free parametrization’s plot (left), we can find parameter pairs (\(\theta_1, \theta_2\)) with a lower NLL. With the free parametrization, \(\theta_1\) could be anywhere between 0.025 (2.5%) and 0.055 (5.5%) and, for a suitable \(\theta_2\), still give the observation a higher likelihood than the best value of \(\theta_c\).
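As a cross-check on the plot, the constrained minimum can also be found in closed form. This sketch uses our own derivation, not the original's. Since \(7\ln\theta_c^2 = 14\ln\theta_c\), the NLL is \(-(52\ln\theta_c + 955\ln(1-\theta_c-\theta_c^2))\). Setting its derivative to zero gives the quadratic \(1962\theta_c^2 + 1007\theta_c - 52 = 0\):

```python
import numpy as np

# Counts from the example: 38 a's, 7 b's, 955 c's.
n_a, n_b, n_c = 38, 7, 955

# With p(a)=t, p(b)=t**2, p(c)=1-t-t**2, the stationarity condition
#   (n_a + 2*n_b) * (1 - t - t**2) = n_c * t * (1 + 2*t)
# rearranges to A*t**2 + B*t + C = 0 with:
A = (n_a + 2 * n_b) + 2 * n_c
B = (n_a + 2 * n_b) + n_c
C = -(n_a + 2 * n_b)

# The product of the roots is C/A < 0, so exactly one root is positive;
# keep it, since probabilities must be positive.
roots = np.roots([A, B, C])
theta_c = roots[roots > 0][0]
print(f"theta_c = {theta_c:.4f}, p(b) = {theta_c**2:.4f}")
```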

A better constraint

However, suppose we knew more about the problem: this is actually the annual failure rate of a widget, with a, b and c meaning “big gear failure”, “small gear failure” and “no failure”. Moreover, prior physical knowledge of the system informs us that the big gear should fail at 3x the rate of the smaller one, that is: \(p(a)=3p(b)\). This inspires the following parametrization, where this constraint holds for all values of \(\theta_c\):

By setting \(p_\theta(a)=\theta_c\) and \(p_\theta(b)=\frac{\theta_c}{3}\), we guarantee the constraint \(p_\theta(a)=3p_\theta(b)\) is met for all values of \(\theta_c\). As before, we find the value of \(\theta_c\) with the lowest NLL:

def log_prob_constrained(theta):
    return -(38*np.log(theta) + 7*np.log(theta**2) + 955*np.log(1-theta-theta**2))

def log_prob_constrained2(theta):
    return -(38*np.log(theta) + 7*np.log(1/3*theta) + 955*np.log(1-4/3*theta))

fig, ax = plt.subplots(figsize=(9, 4), ncols=1)
xs = np.linspace(0.03, 0.06, 300)

ax.plot(xs, log_prob_constrained(xs))
ax.plot(xs, log_prob_constrained2(xs))
ax.set_xlabel('Value of theta_c')
ax.set_ylabel('Negative Log Likelihood')
ax.set_xlim(0.03, 0.06)
ax.plot([0.047], [log_prob_constrained(0.047)], marker='x', markersize=10, color="royalblue", label="old min NLL={:.2f}".format(log_prob_constrained(0.047)))
ax.plot([0.0338], [log_prob_constrained2(0.0338)], marker='x', markersize=10, color="orange", label="new min NLL={:.2f}".format(log_prob_constrained2(0.0338)))
ax.legend()

The optimal value of \(\theta_c\) is then \(\theta_c\approx0.034\), corresponding to \(p_\theta(a)=0.034\) and \(p_\theta(b)=\frac{0.034}{3}\approx0.011\).
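The grid search agrees with a closed form. In this short derivation (ours, not the original's), the NLL depends on the a and b counts only through their sum, \(38 + 7 = 45\):

```latex
\[\begin{align}
-\ln L(\theta_c) &= -38\ln\theta_c - 7\ln\tfrac{\theta_c}{3} - 955\ln\left(1 - \tfrac{4}{3}\theta_c\right)\\
\frac{d}{d\theta_c}\left(-\ln L(\theta_c)\right) &= -\frac{45}{\theta_c} + \frac{955\cdot\frac{4}{3}}{1-\frac{4}{3}\theta_c} = 0\\
45\left(1-\tfrac{4}{3}\theta_c\right) &= 955\cdot\tfrac{4}{3}\theta_c\\
\tfrac{4}{3}\hat\theta_c &= \frac{38+7}{1000} = 0.045\\
\hat\theta_c &= 0.03375
\end{align}\]
```

That is, the combined failure rate \(p_\theta(a)+p_\theta(b) = \frac{4}{3}\theta_c\) is estimated by the combined proportion \(45/1000\). That rate is then split 3:1 between the two gears.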

Not only is this new constrained estimate \((0.034)\) better than the previous one, the probabilities are closer to the true ones \((0.03)\) than even with the most flexible model \((0.038)\). This is because we’ve correctly incorporated our prior knowledge that \(p_\theta(a)=3p_\theta(b)\). To summarize:

  • A parametrization for a discrete distribution with \(V\) outcomes that is guaranteed to be capable of converging to the true distribution (given an infinite amount of data) has \(V-1\) parameters: one per outcome, minus one because the probabilities must sum to 1.
  • Bad constraints will permanently bias the model; even with an infinite amount of data, its ability to represent the correct probabilities will be limited (in the first constrained example, we impose \(p_\theta(b) = (p_\theta(a))^2\), even though this is false).
  • However, with a truthful constraint (here, \(p_\theta(a)=3p_\theta(b)\)), the estimates still converge to the true parameters with an infinite amount of data. Moreover, they converge faster, since the constraint reduces the impact of variance in the samples.
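A quick simulation sketch can illustrate the last point. It assumes the true widget probabilities are \(p(a)=0.03\), \(p(b)=0.01\) and \(p(c)=0.96\); here \(p(b)\) follows from \(p(a)=0.03\) and the stated 3:1 ratio. The sketch repeatedly draws batches of 1000 samples. For each batch it compares the free estimate of \(p(a)\) (the proportion of a's) with the truthfully constrained one, which works out to \(\frac{3}{4}\) of the combined proportion of a's and b's.

```python
import numpy as np

# Assumed true probabilities for (a, b, c); p(a) = 3 * p(b) holds.
true_probs = [0.03, 0.01, 0.96]
n_samples = 1000
n_trials = 5000

rng = np.random.default_rng(seed=0)
# Each row holds the (a, b, c) counts from one batch of 1000 samples.
counts = rng.multinomial(n_samples, true_probs, size=n_trials)
n_a, n_b = counts[:, 0], counts[:, 1]

# Free parametrization: the MLE of p(a) is the proportion of a's.
free_est = n_a / n_samples
# Constrained parametrization p(a) = 3 p(b): the MLE of p(a) is 3/4 of the
# combined proportion of a's and b's.
constrained_est = 0.75 * (n_a + n_b) / n_samples

print(f"free:        mean={free_est.mean():.4f} std={free_est.std():.4f}")
print(f"constrained: mean={constrained_est.mean():.4f} std={constrained_est.std():.4f}")
```

Both estimators center on 0.03, but the constrained one has a smaller spread because it pools the a and b observations. In theory the standard deviations are about 0.0047 for the constrained estimator and 0.0054 for the free one.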

Joint Distributions as tensors

Let’s now look at a situation with two random variables. Suppose we now have the following setup:

  • I have two time slots each day, each of which I can fill up with either reading or hiking.
  • The activity done in the first and second time slots are \(X_1\) and \(X_2\) respectively.

Then, let the joint distribution5 of \(X_1\) and \(X_2\) be the following:

Note that these variables are not independent: for instance, if we learn hiking is the second activity of the day, we can deduce hiking could not have been the first activity that day, as \(p(\text{hiking}_1, \text{hiking}_2) = 0\). Estimation works the same as previously: if 21 of 200 observed days are “read first, then hike”, then \(p_\theta(\text{read}_1, \text{hiking}_2) = \frac{21}{200}\).

Moreover, we note that each of the \(X\)’s has the same \(V=2\) outcomes: reading or hiking. We could then convert the “vector” above into the following matrix:

In general, with \(m\) random variables, each with \(V\) outcomes individually, there are \(V^m\) total possible joint outcomes. Here, this is \(2^2 = 4\) outcomes. One can imagine storing these joint outcomes in a tensor with \(m\) axes, each with \(V\) dimensions. Here we store these 4 outcomes in a \(2\times 2\) matrix; for 4 variables with 3 outcomes each we could use a \(3\times3\times3\times3\) tensor (with 81 total “outcomes”).
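Estimating the joint table from data is the same counting procedure, just into an array with one axis per variable. Here is a minimal sketch with \(m=2\), using a handful of hypothetical (made-up) days encoded as 0 = reading and 1 = hiking:

```python
import numpy as np

V = 2  # outcomes per variable: 0 = reading, 1 = hiking
# Hypothetical observations of (X1, X2), one pair per day.
days = np.array([[0, 1], [0, 0], [1, 0], [0, 1], [1, 0], [0, 0]])

# One axis per variable, V entries per axis: a V x V count table.
joint_counts = np.zeros((V, V))
# np.add.at accumulates correctly even when the same cell is hit repeatedly.
np.add.at(joint_counts, (days[:, 0], days[:, 1]), 1)

# MLE: each joint probability is the proportion of days with that pair.
joint_probs = joint_counts / joint_counts.sum()
print(joint_probs)
```

For \(m\) variables, the same approach works with an array of shape `(V,) * m`, which has \(V^m\) cells.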

Note that this grows exponentially in the number of variables \(m\). With \(V=10\) and \(m=10\), we have \(10^{10}\) (ten billion) possible outcomes, each of which we need to observe (multiple times) to estimate its probability. Compared to the previous \(V - 1\) parameters for the distribution of a single random variable, we now have \(V^m - 1\) parameters for the joint distribution of \(m\) random variables.

This exponential growth (and in turn, the number of samples required) makes directly estimating probabilities for anything but the simplest of joint distributions intractable. Maybe conditional probabilities can help?

Conditional Distributions as vectors

Conditional distributions also fit neatly into this tensor form. The chain rule in probability tells us that any joint probability can be decomposed into the product of conditional probabilities. Extending the previous example:

\[p(\text{hiking}_1, \text{reading}_2) = \textcolor{NavyBlue}{p(\text{reading}_2 \mid \text{hiking}_1)}\textcolor{Orange}{p(\text{hiking}_1)}\]

Here, the joint probability of “hiking then reading” is the product of the marginal probability of hiking first and the conditional probability of reading after hiking. We can represent this factorization in vector form as follows:

To analyze this factorization:

  • We’ve initially stored the joint probabilities \(p_\theta(\cdot, \cdot)\) in a matrix form. This has \(V^m = 2^2 = 4\) outcomes, and so \(V^m - 1 = 3\) parameters.
  • We can factorize that matrix into a vector \(p_\theta(\cdot)\) representing the marginal probability, and another vector \(p_\theta(\cdot \mid \text{hiking}_1)\) representing the conditional probability.
  • Each of these two vectors has \(V=2\) entries, and so \(V - 1 = 1\) parameter. With 2 vectors, we have 2 parameters total.

It appears the factorized version has fewer parameters (2 vs 3) than the original. There’s just one problem: it’s incomplete! We’re only looking at the conditional probability table for where \(X_1 = \text{hiking}_1\); there’s another one for \(X_1 = \text{read}_1\):

With three vectors (one marginal, two conditional), we have 3 parameters, the same as the original joint probability matrix. This unfortunately is true in general as well; for \(m\) discrete random variables (each with \(V\) outcomes):

  • the conditional probability vector of \(p(\cdot \mid x_{1:k-1})\) has \(V\) entries (\(V-1\) parameters), as \(X_k\) only has \(V\) outcomes.
  • However, there are \(V^{k-1}\) such tables, one for every possible combination of prior values \(x_{1:k-1} = x_1, ..., x_{k-1}\).
  • The total number of parameters for a variable conditioned on \(k-1\) “already observed” outcomes then is \(V^{k-1}(V-1)\).
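Going from the joint matrix to the factorized form amounts to summing and normalizing. The sketch below uses a hypothetical joint table with \(p(\text{hiking}_1, \text{hiking}_2) = 0\); the values are illustrative, not the original's. It checks that the marginal and the conditional vectors reproduce the joint table:

```python
import numpy as np

# Hypothetical joint table: rows index X1, columns index X2 (0 = reading, 1 = hiking).
joint = np.array([[0.25, 0.35],
                  [0.40, 0.00]])

# Marginal p(x1): sum out X2.
marginal = joint.sum(axis=1)
# One conditional vector p(. | x1) per value of x1 (one per row).
# This assumes every marginal entry is nonzero.
conditionals = joint / marginal[:, None]

# Chain rule: p(x1, x2) = p(x2 | x1) p(x1).
reconstructed = marginal[:, None] * conditionals
print(marginal)        # one vector with V entries
print(conditionals)    # V vectors with V entries each
```

Counting free parameters gives the same total as the joint matrix. The marginal contributes \(V-1 = 1\), and each of the \(V = 2\) conditional rows contributes \(V-1 = 1\), for 3 in total, matching \(V^2 - 1 = 3\).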

General Case

The general chain rule for a joint distribution with \(m\) variables is:

\[p(x_1, x_2,..., x_m) = p(x_m \mid x_{1:m-1})p(x_{m-1} \mid x_{1:m-2})...p(x_2\mid x_1)p(x_1)\]

Using the result in the previous section, let’s compute the number of free parameters in the distributions in the product above:

  • For the marginal \(p(\cdot)\), there are \(V-1\) parameters.
  • For \(p(\cdot \mid x_1)\), conditioned on one outcome there are \(V(V-1)\) parameters.
  • For \(p(\cdot \mid x_{1:m-1})\), conditioned on \(m-1\) outcomes there are \(V^{m-1}(V-1)\) parameters.

Summing up the number of parameters for all \(m\) distributions in the product, we have:

\[\begin{align} N &= (V - 1) + V(V-1) + V^2(V-1)+...+V^{m-1}(V-1)\\ &= (V-1)\sum_{k=0}^{m-1} V^k\\ &= (V-1)\frac{V^m - 1}{V - 1}\\ &= V^m - 1 \end{align}\]

We’re back at \(V^m - 1\) parameters, same as in the joint distribution case. Breaking the joint probability matrix into conditional probability vectors changes the parametrization, but doesn’t change the number of parameters.

Independence as constraints

Earlier, with the widgets example, we used prior knowledge to reduce the number of parameters needed. Now suppose we knew that each of the \(X_i\) was a coin toss (with a 60% probability of heads, as before), and we had a sequence of \(m=30\) coin tosses. There are \(2^{30}\) (\(\approx\) a billion) possible sequences of coin tosses here, and you’d need to observe each of these billion sequences multiple times to estimate its proportion of the total.

But it’s silly to store a separate probability for each unique sequence (e.g. \(\text{HTH...T}\)): we know the coin tosses are independent from each other, and identically distributed. We could simplify the chain rule from before as follows:

\[\begin{align} p_{X_1, X_2,...,X_m}(x_1, x_2,..., x_m) &= p_{X_m}(x_m \mid x_{1:m-1})p_{X_{m-1}}(x_{m-1} \mid x_{1:m-2})...p_{X_2}(x_2\mid x_1)p_{X_1}(x_1)\\ &=p_{X_m}(x_m)p_{X_{m-1}}(x_{m-1})...p_{X_2}(x_2)p_{X_1}(x_1)\\ &=p(x_m)p(x_{m-1})...p(x_2)p(x_1)\\ \end{align}\]

Note that here we explicitly subscript each \(p\) with the random variable it describes. This makes it clear where we drop the subscripts in the third line, which we can do because all variables are identically distributed (and hence have the same \(p\)).

This simplifies things greatly:

  • Previously, each \(p(x_k \mid x_{1:k-1})\) required \(V^{k-1}(V-1)\) parameters to represent: each conditional probability vector has \(V-1\) parameters, and there’s \(V^{k-1}\) such vectors as there are \(V^{k-1}\) combinations of \(k-1\) previous outcomes.
  • But you don’t actually need a separate table for each combination of past outcomes: independence means \(p(x_k \mid x_{1:k-1}) = p(x_k)\). Each of the \(V^{k-1}\) vectors is the same vector with \(V=2\) outcomes, so there’s only \(V-1=1\) parameter here.
  • Moreover, each of the \(X_i\)’s are identically distributed (since it’s the same coin), that is, they all share that \(1\) parameter.

Overall, we see that as the coin tosses are independent, one parameter is sufficient to express the entire joint probability for any of the \(2^{30}\) possible sequences. In general, knowing (conditional) independence between variables helps greatly: if (some) of the previous variables don’t change the probability, you don’t need a separate conditional probability vector for each of those previous combinations of outcomes. Here, we cut \(V^m-1\) free parameters to \(V-1=2-1=1\) parameter.
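With independence and identical distribution, the probability of any of the \(2^{30}\) sequences needs only \(\theta\). Here is a minimal sketch, assuming \(\theta = 0.6\) as above and a hypothetical randomly sampled sequence:

```python
import numpy as np

theta = 0.6  # the single shared parameter: probability of heads
rng = np.random.default_rng(seed=0)
# A hypothetical sequence of m = 30 tosses (1 = heads, 0 = tails).
seq = rng.binomial(1, theta, size=30)

# Product of identical marginals, computed in log space for stability.
log_p = np.sum(np.where(seq == 1, np.log(theta), np.log(1 - theta)))

# Equivalent closed form: only the number of heads matters.
n_heads = int(seq.sum())
log_p_closed = n_heads * np.log(theta) + (30 - n_heads) * np.log(1 - theta)
print(n_heads, np.exp(log_p))
```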

Part B: Language Models

We now have a detailed understanding of the joint probability distribution of \(m\) discrete random variables. Let’s now apply that knowledge, by looking at the scenario specific to language modeling:

  • Support: We have a sequence of random variables \((W_1, W_2,...,W_m)\), where each variable represents a “word”, more commonly called a token. Each token can be one of \(V\) possible values, where \(V\) is the size of the “vocabulary”.
  • Parametrization: There is a true probability distribution the samples in our datasets are generated from. For example, the sequence (‘I’, ‘like’, ‘cats’) is sampled with probability \(p(\text{'I', 'like', 'cats'})\). However, we do not know \(p\), so we use an approximation \(p_\theta\).
  • Sampling: We assume any text datasets we have to be samples from this true distribution.
  • Estimation: To estimate the value of the parameter \(\theta\), we find values that maximize the likelihood of our dataset.

In brief, a language model is simply a function that assigns probabilities to length-\(m\) sequences of text. To reinforce this, let’s look at an actual dataset.


We use the dataset from Karpathy (2015), which is a 4.6MB file containing all the works of Shakespeare. Here are the first 75 characters in the dataset:

import requests

url = ''
text = requests.get(url).text
print(repr(text[:75]))
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, s'


Before building a language model, we must choose how to represent our raw text as individual “tokens” in a “sequence”. This requires concretely defining what the vocabulary for each token is:

  • Word-level: If each entire English word is a token, then the size of the vocabulary \(V\) grows quickly. This dataset alone has over 60,000 unique words.
  • Character-level: If each individual character is a token, \(V\) is small (here, there are only 67 unique characters, including whitespace characters like \n). However, for a given amount of information, \(m\) needs to be much larger, since the same text spans many more tokens. For instance, “the weather in Berlin” is 4 words, but 21 characters.
  • Subword-level: The most common way to build language models today is to use subword tokenization, such as using the SentencePiece tokenizer (Kudo & Richardson, 2018). These keep small words intact, but break up longer words into smaller “subwords”.
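The trade-off between vocabulary size and sequence length shows up even on the example phrase above. This sketch uses naive whitespace splitting as a stand-in for word-level tokenization. That is an assumption; real word tokenizers also handle punctuation.

```python
phrase = "the weather in Berlin"

word_tokens = phrase.split()   # naive word-level tokenization
char_tokens = list(phrase)     # character-level tokenization (spaces included)

print(len(word_tokens), word_tokens)
print(len(char_tokens))
```

Fewer, larger tokens mean a shorter \(m\) but a larger \(V\), and vice versa.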

For this article we’ll be tokenizing at the character level: it’s straightforward, and will allow us to directly visualize the probability matrices. Breaking up the first 75 characters into individual tokens, we have:

print([c for c in text[:75]])
['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't', 'i', 'z', 'e', 'n', ':', '\n', 'B', 'e', 'f', 'o', 'r', 'e', ' ', 'w', 'e', ' ', 'p', 'r', 'o', 'c', 'e', 'e', 'd', ' ', 'a', 'n', 'y', ' ', 'f', 'u', 'r', 't', 'h', 'e', 'r', ',', ' ', 'h', 'e', 'a', 'r', ' ', 'm', 'e', ' ', 's', 'p', 'e', 'a', 'k', '.', '\n', '\n', 'A', 'l', 'l', ':', '\n', 'S', 'p', 'e', 'a', 'k', ',', ' ', 's']

Then, converting each unique token into a unique integer, we have:

vocab = sorted(set(text))
char2idx = {c: i for (i, c) in enumerate(vocab)}
tokens = [char2idx[c] for c in text]
print(tokens[:75])

[18, 49, 58, 59, 60, 1, 15, 49, 60, 49, 66, 45, 54, 10, 0, 14, 45, 46, 55, 58, 45, 1, 63, 45, 1, 56, 58, 55, 43, 45, 45, 44, 1, 41, 54, 65, 1, 46, 61, 58, 60, 48, 45, 58, 6, 1, 48, 45, 41, 58, 1, 53, 45, 1, 59, 56, 45, 41, 51, 8, 0, 0, 13, 52, 52, 10, 0, 31, 56, 45, 41, 51, 6, 1, 59]

Note that our input pipeline looks like this:

\[\text{raw text} \rightarrow \text{tokens} \rightarrow \text{integers}\]
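Since `char2idx` is one-to-one, this pipeline can be run in reverse to turn model outputs (integers) back into text. Here is a minimal sketch on a short stand-in string. The names `idx2char`, `encode` and `decode` are illustrative, and the article builds its vocabulary from the full corpus instead.

```python
# Stand-in text; the article builds `vocab` from the full Shakespeare corpus.
text = "Speak, speak."

vocab = sorted(set(text))
char2idx = {c: i for (i, c) in enumerate(vocab)}
idx2char = {i: c for (c, i) in char2idx.items()}  # hypothetical inverse mapping

def encode(s):
    # raw text -> tokens (characters) -> integers
    return [char2idx[c] for c in s]

def decode(ids):
    # integers -> tokens (characters) -> raw text
    return "".join(idx2char[i] for i in ids)

ids = encode(text)
print(ids)
print(decode(ids))
```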

Now suppose we’ve set out to build a language model with context length \(m=1024\); that is, a model that can assign a probability to every character sequence of length 1024. Even here, we run into combinatorial explosion: with \(V=67\) characters and context length \(m=1024\), there are \(67^{1024}\) possible sequences.
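To get a feel for the size of this number, Python’s arbitrary-precision integers can compute it exactly. Here is a quick sketch that counts its decimal digits in two ways:

```python
import math

V, m = 67, 1024
n_sequences = V ** m  # exact, thanks to Python's big integers

# Number of decimal digits: directly, and via base-10 logarithms.
print(len(str(n_sequences)))
print(math.floor(m * math.log10(V)) + 1)
```

Both lines print 1870: the number of possible sequences has 1870 digits.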

Most of those sequences would have zero probability (on a smaller scale, “big city”6 is a valid sequence of length 8, while “xcsazmad” is not), but even among the valid sequences, you’d need a truly staggering number of samples to even observe every valid sequence. Since this is infeasible, let’s try simplifying.

Chain Rule

Using the chain rule as discussed earlier, we can represent any joint probability as a product of conditional probabilities. With \(m=1024\), we have:

\[\begin{align} p(w_1, w_2,..., w_{1024}) &= p(w_{1024} \mid w_{1:1023})p(w_{1023} \mid w_{1:1022})...p(w_2\mid w_1)p(w_1)\\ &= \prod_{i=1}^{1024} p(w_i \mid w_{1:i-1})\\ \end{align}\]

In this case, it means iteratively computing the probability of each token conditioned on all the previous tokens. On its own, this is of little help. As before, although we’re “splitting up” the joint probability tensor, there are still \(V^{1024} - 1\) free parameters in total, since each “next token prediction” would need a unique conditional probability vector for each combination of previous tokens \(w_{1:i-1}\).


What if we pretended, just like coin tosses, each token was independent of its previous tokens? This is not true: We likely have \(p(\text{'h'}\mid \text{'t'})\) greater than \(p(\text{'h'})\), as “th” appears in “the”, one of the most common words in the English language. Knowing prior tokens definitely changes the probability of the next token.

But continuing forward with this naive assumption, we write:

\[\begin{align} p(w_1, w_2,..., w_{1024}) &= p(w_{1024} \mid w_{1:1023})p(w_{1023} \mid w_{1:1022})...p(w_2\mid w_1)p(w_1)\\ &\approx p(w_{1024})p(w_{1023})...p(w_2)p(w_1) \end{align}\]

That is, the joint distribution is the product of the marginal distributions. Since each marginal \(p(\cdot)\) has \(V=67\) outcomes, each has \(66\) parameters. We make another naive assumption: the \(W_i\)’s are identically distributed7. This means these 66 parameters are shared across all 1024 terms in the product above.

Note that with these two assumptions, we’re no longer estimating the probability of length \(m=1024\) sequences, but rather of single tokens (\(m=1\)). We then assume we can approximate the probability of a longer sequence with the product of these individual probabilities.


With the two naive assumptions above (independence and sharing across timesteps), we have 66 parameters to estimate. We know the maximum likelihood estimate here from Part A: the proportion of times each outcome occurred, out of the total.

Let’s first split the corpus (of 4,573,338 tokens) into a training set (first 4 million) and a validation set (remaining 573,338).

train_tokens, valid_tokens = tokens[:4000000], tokens[4000000:]

Then, to estimate the parameters, we count the number of times the token appears and divide by the total8:

# Create "initial" counts
counts = np.zeros(shape=(len(vocab),)) + 0.1

# Loop over tokens in dataset
for token in train_tokens:
    counts[token] += 1

# Normalize to 1
params = counts / counts.sum()
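The counting loop can also be vectorized. Here is a sketch using `np.bincount`, an alternative to (not part of) the original code, checked against the loop on toy token ids:

```python
import numpy as np

def unigram_params(tokens, vocab_size, pseudocount=0.1):
    # np.bincount counts occurrences of each integer id; minlength makes sure
    # ids that never appear still get a (zero) count.
    counts = np.bincount(tokens, minlength=vocab_size).astype(float)
    counts += pseudocount  # same 0.1 "initial" count as the loop above
    return counts / counts.sum()

# Toy example with a 5-token vocabulary (ids 0..4); id 4 never appears.
toy_tokens = [0, 1, 1, 2, 2, 2, 3]
params_vec = unigram_params(toy_tokens, vocab_size=5)

# Loop version for comparison.
loop_counts = np.zeros(5) + 0.1
for t in toy_tokens:
    loop_counts[t] += 1
print(params_vec)
```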

We can then visualize this one-variable probability distribution, in the following bar graph:

idxs = np.argsort(params)[::-1]
fig, ax = plt.subplots(figsize=(9, 4))
ax.bar([vocab[idx].replace("\n", "\\n") for idx in idxs], params[idxs])
ax.set_ylabel("Probability of outcome")

As we see, the most “likely” token is ' ', followed by e, with the lowest probability one being $.


How good a model of language is this? Even though it is very constrained, we can see it has correctly “learned” a qualitative aspect of English: that e is the most commonly occurring letter. But how do we know these frequencies are reliable, and not due to idiosyncrasies in the first 4 million tokens9?

One metric commonly used for language models is perplexity, which works as follows:

  • For each token, compute the log probability of that token under the model.
  • Take the mean of these log probabilities across all tokens.
  • Exponentiate the negative of this mean.

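In code, we have10:

def perplexity_unigram(probs_vec, tokens, start=7):
    # Log probability of each token under the model, skipping the first `start` tokens
    log_probs = [np.log(probs_vec[token]) for token in tokens[start:]]
    # Exponentiate the negative mean log probability
    return np.exp(-np.mean(log_probs))

Note that perplexity equals \(V\) (here, 67) for a uniform model, which assigns every token probability \(\frac{1}{V}\). A model that does better than chance stays below this, although a model that assigns the observed tokens less probability than uniform can exceed it. Perplexity has a minimum of 1, when the model assigns a probability of 1 to every token; that is, it perfectly predicts the sequence11.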
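As a sanity check of the definition, a uniform model should have perplexity exactly \(V\), and a model that always puts all its mass on the observed token should reach 1. The sketch below restates `perplexity_unigram` so it runs on its own, and uses hypothetical random token ids:

```python
import numpy as np

def perplexity_unigram(probs_vec, tokens, start=7):
    log_probs = [np.log(probs_vec[token]) for token in tokens[start:]]
    return np.exp(-np.mean(log_probs))

V = 67
uniform = np.full(V, 1.0 / V)
rng = np.random.default_rng(seed=0)
fake_tokens = rng.integers(0, V, size=1000)  # hypothetical token ids

# Every token gets probability 1/V, so perplexity is V (up to float error).
print(perplexity_unigram(uniform, fake_tokens))

# All probability mass on the token that actually occurs: perplexity 1.
one_token = np.zeros(V)
one_token[5] = 1.0
print(perplexity_unigram(one_token, [5] * 20))
```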

Computing the perplexity over the training and validation sets, we have train_ppl=27.53 and valid_ppl=27.14. One way to interpret this value is that, on average, the model is as uncertain as if it were choosing uniformly between 27.14 outcomes (out of 67) at every step of the validation sequence (lower is better).


Now that we’ve estimated the parameters, we can sample from the model. Let’s generate a new sequence of length 30:

sequence = ''

for _ in range(30):
    sequence += np.random.choice(vocab, p=params)

'dweIdi ta rru mSd if  \nt aeoe '

As we can see, this isn’t very English-like: after all, the model treats every token as a 67-way coin toss, with no regard for the previous tokens. Let’s make this more realistic.


Instead of assuming each token is independent, let’s assume that a token \(w_i\) is conditionally independent of the earlier tokens \(w_{i-2}, w_{i-3}, \ldots\), given the token \(w_{i-1}\). This means that once we know the immediately previous token, also knowing earlier tokens will not change the probability. We have the approximation:

\[\begin{align} p(w_1, w_2,..., w_{1024}) &= p(w_{1024} \mid w_{1:1023})p(w_{1023} \mid w_{1:1022})...p(w_2\mid w_1)p(w_1)\\ &\approx p(w_{1024} \mid w_{1023})p(w_{1023} \mid w_{1022})...p(w_2\mid w_{1})p(w_1) \end{align}\]

Again, this is a faulty approximation. \(p(\text{'e'}\mid \text{'h'}, \text{'t'})\) is greater than \(p(\text{'e'}\mid \text{'h'})\), because knowing that the previous two letters are th gives us much stronger confidence in completing the word the than knowing only the previous letter h. But it is still better than assuming complete independence between tokens.

For each conditional probability term, there are \(V(V-1)\) free parameters, because we need one probability vector (with \(V-1\) free entries) for each possible “prior” token. If we share these parameters across timesteps as before12, then there are \(V(V-1)\) parameters in total to estimate in order to compute the conditional probabilities. With \(V=67\), this is 4422 free parameters.


Estimation with maximum likelihood remains similar. To estimate \(p_\theta(b \mid a) = \theta_{a,b}\), we find all the cases where \(a\) occurs, and then compute the proportion of them that are followed by \(b\). In practice, we create a counts matrix, and normalize it so that each row sums to 1:

import numpy as np

# Create "initial" counts: a 0.1 pseudo-count for every (previous, next) pair
counts = np.zeros((len(vocab), len(vocab))) + 0.1

# Compute the counts matrix: rows index the previous token, columns the next one
for i in range(1, len(train_tokens)):
    prev_token = train_tokens[i-1]
    current_token = train_tokens[i]
    counts[prev_token, current_token] += 1

# Normalize each row to get proportions that sum to 1
params = counts / counts.sum(axis=-1, keepdims=True)

Note that we add a 0.1 count to each “transition” pair. Without it, a pair \((a, b)\) that doesn’t appear in the training sequence would have \(p_\theta(b \mid a) = 0\). If that pair then appeared in the validation sequence, we would have \(p_\theta(b \mid a) = 0 \Rightarrow -\ln p_\theta(b \mid a) = \infty\), and in turn a sequence perplexity of \(\infty\). Adding the pseudo-count ensures a small probability is assigned to every possible transition.
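The counting loop can also be written without a Python loop. Here is a minimal vectorized sketch (`bigram_params` is a hypothetical helper name) that uses `np.add.at`. A plain `counts[prev, nxt] += 1` with index arrays would count a repeated pair only once, whereas `np.add.at` accumulates every occurrence.

```python
import numpy as np

def bigram_params(tokens, vocab_size, pseudo_count=0.1):
    """Smoothed maximum-likelihood estimates of p(next | previous)."""
    tokens = np.asarray(tokens)
    counts = np.full((vocab_size, vocab_size), pseudo_count)
    # Unbuffered add: repeated (previous, next) pairs are each counted
    np.add.at(counts, (tokens[:-1], tokens[1:]), 1)
    return counts / counts.sum(axis=-1, keepdims=True)

# Toy sequence over a 3-token vocabulary: pairs (0,1), (1,0), (0,1), (1,2)
toy_params = bigram_params([0, 1, 0, 1, 2], vocab_size=3)
print(toy_params.round(3))
```

Row 2 never appears as a “previous” token in the toy sequence. Its row is therefore uniform, which is exactly the role of the pseudo-count.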

When visualized, the matrix of estimated conditional probabilities is as follows. Note that each row sums to 1.

from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(1, 1)
fig.add_trace(
    go.Heatmap(
        z=params,  # rows: current token, columns: next token
        x=[v.replace("\n", "\\n") for v in vocab],
        y=[v.replace("\n", "\\n") for v in vocab],
        coloraxis="coloraxis",
    ),
    row=1, col=1,
)

fig.update_xaxes(title_text="Next Token", tickmode="linear", tickangle=0)
fig.update_yaxes(title_text="Current Token", tickmode="linear")
fig.update_layout(coloraxis={"colorscale": "Greys"})

We can glean patterns from this matrix: for instance, the estimated probability of a newline following a newline \(p_\theta('\backslash\text{n}' \mid '\backslash\text{n}')\) is 0.187.


Evaluating, we have train_ppl=11.87, and valid_ppl=11.95, which more than halves the perplexity values of the unigram model. This makes sense: our estimates for the next token should be better when we account for the previous token.
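The section only shows the unigram version of the perplexity computation. The bigram numbers follow the same recipe, except that each token is scored conditioned on its predecessor. Here is a minimal sketch, with `perplexity_bigram` as a hypothetical name and the same `start=7` offset from footnote 10. The toy inputs are also hypothetical.

```python
import numpy as np

def perplexity_bigram(params, tokens, start=7):
    """exp of the mean of -ln p(w_i | w_{i-1}) over tokens[start:]."""
    tokens = np.asarray(tokens)
    prev, nxt = tokens[start - 1:-1], tokens[start:]  # aligned (previous, next) pairs
    log_probs = np.log(params[prev, nxt])
    return np.exp(-log_probs.mean())

V = 4
toks = [0, 1, 2, 3] * 5  # hypothetical cyclic sequence

uniform = np.full((V, V), 1 / V)
print(perplexity_bigram(uniform, toks))  # approximately 4

# A model that knows the cycle predicts every next token with probability 1
cycle = np.zeros((V, V))
cycle[np.arange(V), (np.arange(V) + 1) % V] = 1.0
print(perplexity_bigram(cycle, toks))  # 1.0
```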


Sampling is also straightforward: given a starting token \(w_{i-1}\), we sample the next token using the conditional probability vector corresponding to \(p(\cdot \mid W_{i-1} = w_{i-1})\), as follows:

sequence = 'e'

for i in range(1, 30):
    prior_token = char2idx[sequence[i-1]]
    conditional_prob = params[prior_token, :]
    sequence += np.random.choice(vocab, p=conditional_prob)

"eanghitheyen'somillllomp th d "

The text now feels a bit more plausible, but not by much: it’s challenging to form full words when you’re constrained to ignore all tokens other than the one just before.

Bigrams: Neural Networks

The computation we derived above was to represent the joint probability \(p(w_1,...,w_{1024})\) as the product of conditional probabilities \(p(w_{i} \mid w_{i-1})\). The representation we chose was to have each \(p(b \mid a) = \theta_{a, b}\)13. This resulted in a \(67 \times 67\) matrix14.

But we only need an individual parameter for each conditional probability in \(p_\theta\), if we want to be able to update a probability without altering the others. If we’re okay with a constrained parametrization, we can use just about anything: a Fourier series, a sequence of piecewise linear functions, etc. Since neural networks are universal function approximators15 (Hornik, 1991), we could represent \(p_\theta\) using a neural network; that is \(p_\theta(b \mid a) = f_\theta(a)[b]\). In detail:

  • Input: The neural network takes in the token \(a\) as input, and returns a vector of probabilities corresponding to \(p_\theta(\cdot \mid a)\).
  • Output: We then index into this vector with \(b\), getting the probability \(p_\theta(b \mid a)\).

There’s a problem here: while the output (a probability) is continuous, the input \(a\) (a token) is a discrete integer. We can’t use backpropagation here, since we can’t differentiate through \(a\). Or can we?


One approach is to use embedding vectors: we associate a real-valued vector with each token in the vocabulary, and use that vector as the input to the neural network.

This converts the continuous nature of neural networks from a problem to a feature. Recall that with continuous functions, small changes in the input result in small changes in output.

Suppose the next-token probabilities (the outputs) for two tokens \(u\) and \(v\) are similar, that is, \(p(\cdot \mid u) \approx p(\cdot \mid v)\). Then during training, the embedding vectors for \(u\) and \(v\) (the inputs) can be optimized to be close to each other (again, similar inputs \(\rightarrow\) similar outputs). With explicit conditional probability vectors, a human would have to define a hard constraint such as \(\theta_{u, \cdot} = \theta_{v, \cdot}\). Instead, the network can learn these associations directly from data.

Our input pipeline now converts a raw string into a sequence of vectors:

\[\text{raw text} \rightarrow \text{tokens} \rightarrow \text{integers} \rightarrow \text{embedding vectors}\]
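The last arrow is just a table lookup. Here is a minimal NumPy sketch with a randomly initialized table and hypothetical token ids. It shows that indexing the table is the same as multiplying one-hot vectors by it. The discrete token only selects a row, and the output is linear in the table’s entries, so gradients flow into the embeddings even though we never differentiate with respect to the token itself.

```python
import numpy as np

vocab_size, features = 67, 16
rng = np.random.default_rng(0)
# One row per token; random here, learned during training in practice
embedding_table = rng.normal(size=(vocab_size, features))

token_ids = np.array([3, 17, 3])        # hypothetical token ids
vectors = embedding_table[token_ids]    # shape (3, 16)

# The same lookup written as a matrix product with one-hot vectors
one_hot = np.eye(vocab_size)[token_ids] # shape (3, 67)
print(np.allclose(one_hot @ embedding_table, vectors))  # True
```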


Training (or learning) really is just a concise way of saying “solving an optimization problem”. Specifically, the following problem:

\[\min_\theta [-\ln L(\theta)]\]

That is, we wish to:

  • find the values of the parameters \(\theta\) of the neural network
  • that minimize the negative log-likelihood of our observed training sequence16
  • or equivalently, maximize the probability of our observed training sequence
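Consider the bigram model on a training sequence \(w_1, \ldots, w_n\), dropping the \(p(w_1)\) term, which the conditional probabilities don’t affect. Written out, the objective is a sum of per-token terms. Dividing by the number of predicted tokens connects it to perplexity (computed here over all predicted tokens, ignoring the `start` offset from footnote 10):

\[\begin{align} -\ln L(\theta) &= -\sum_{i=2}^{n} \ln p_\theta(w_i \mid w_{i-1}) = -\sum_{i=2}^{n} \ln f_\theta(w_{i-1})[w_i] \\ \text{PPL}(\theta) &= \exp\left(-\frac{1}{n-1}\sum_{i=2}^{n} \ln p_\theta(w_i \mid w_{i-1})\right) = \exp\left(\frac{-\ln L(\theta)}{n-1}\right) \end{align}\]

Since \(\exp\) is increasing, on a fixed training sequence minimizing the negative log-likelihood and minimizing training perplexity are the same problem.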

Previously, we “knew” the closed-form, optimal solution (calculate the proportions from the counts). Here, we have a considerably more complex parametrization with no closed-form solution17. After all, instead of each \(p_\theta(b \mid a)\) being represented by a separate parameter \(\theta_{a, b}\), the network’s parameters are shared across all the conditional probabilities.

We instead use gradient descent to iteratively estimate values of \(\theta\) with lower negative log-likelihood. We set up a simple neural network with 2 hidden layers as follows:

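import flax.linen as nn

class LanguageModel(nn.Module):
    num_embeddings: int = 67
    features: int = 16

    # Required in Flax linen to define submodules inline in __call__
    @nn.compact
    def __call__(self, x):
        # x: integer token ids of shape (batch, hist_size)
        embed = nn.Embed(self.num_embeddings, self.features)
        # Get the embedding vectors for each input token
        x = embed(x)
        # Concatenate the history's embeddings into a single vector
        batch_dim, hist_size, features = x.shape
        x = x.reshape(batch_dim, hist_size * features)
        # Apply two hidden layers
        x = nn.Dense(self.features * 4)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.features * 4)(x)
        x = nn.gelu(x)
        # Project back to the embedding dimension, then take dot products with
        # every embedding vector (tied weights) to get next-token logits
        x = nn.Dense(self.features)(x)
        x = embed.attend(x)

        return x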
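To make the data flow concrete, here is a NumPy sketch of the same forward pass. It is not the Flax code. The parameter names, the initialization, and the weight layout are assumptions, and GELU uses the tanh approximation (the default in `jax.nn.gelu`). The last step mirrors `embed.attend`, which takes dot products between the hidden vector and every row of the embedding table.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def init_params(rng, vocab_size=67, features=16, hist_size=1):
    """Hypothetical parameter layout mirroring LanguageModel's layers."""
    def dense(n_in, n_out):
        return {"W": rng.normal(scale=n_in ** -0.5, size=(n_in, n_out)), "b": np.zeros(n_out)}
    return {
        "embed": rng.normal(size=(vocab_size, features)),
        "dense1": dense(hist_size * features, features * 4),
        "dense2": dense(features * 4, features * 4),
        "dense3": dense(features * 4, features),
    }

def forward(params, x):
    """x: int array (batch, hist_size) -> next-token probabilities (batch, vocab)."""
    h = params["embed"][x]                    # (batch, hist_size, features)
    h = h.reshape(h.shape[0], -1)             # concatenate the history's embeddings
    h = gelu(h @ params["dense1"]["W"] + params["dense1"]["b"])
    h = gelu(h @ params["dense2"]["W"] + params["dense2"]["b"])
    h = h @ params["dense3"]["W"] + params["dense3"]["b"]  # back to embedding size
    logits = h @ params["embed"].T            # tied output weights, like attend
    logits -= logits.max(axis=-1, keepdims=True)  # for a numerically stable softmax
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
params = init_params(rng)
probs = forward(params, np.array([[5], [12]]))
print(probs.shape)  # (2, 67)
```

Passing `hist_size=7` to `init_params` and a `(batch, 7)` input gives the 8-gram variant used later. Only the first dense layer changes size.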

We optimize for 50,000 steps, and compute the perplexity (over the entire train and validation sequences) every 1000 steps. The training notebook is here:

import numpy as np
import matplotlib.pyplot as plt

bigram_nn = np.load('assets/values.npz')
steps = np.arange(1000, 50001, 1000)  # perplexity is logged every 1000 steps
fig, ax = plt.subplots(figsize=(9,4))
ax.plot(steps, bigram_nn["train_ppls"], label="Train (final PPL={:.2f})".format(bigram_nn["train_ppls"][-1]))
ax.plot(steps, bigram_nn["val_ppls"], label="Validation (final PPL={:.2f})".format(bigram_nn["val_ppls"][-1]))
ax.set_xlim(0, 50000)
ax.set_ylim(12.0, 13.2)
ax.legend()


First, let’s look at the outputs. Recall that the network parametrizes \(p_\theta(b \mid a) = f_\theta(a)[b]\). We can compute all the conditional probability vectors \(p_\theta(\cdot \mid a) = f_\theta(a)\) by passing in all 67 unique values of \(a\). Doing so, and concatenating these vectors into a matrix, we have a matrix quite similar to that from the previous section:

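labels = [v.replace("\n", "\\n") for v in vocab]
fig = make_subplots(1, 2, subplot_titles=("Free", "Constrained (Neural Network)"))
# Left: the explicit conditional probabilities estimated from counts
fig.add_trace(go.Heatmap(z=params, x=labels, y=labels, coloraxis="coloraxis"), row=1, col=1)
# Right: f_theta(a) for all 67 values of a, stacked into a 67 x 67 matrix.
# `nn_probs` is a hypothetical name; this matrix is computed in the training notebook.
fig.add_trace(go.Heatmap(z=nn_probs, x=labels, y=labels, coloraxis="coloraxis"), row=1, col=2)

fig.update_xaxes(title_text="Next Token")
fig.update_yaxes(title_text="Current Token")
fig.update_layout(coloraxis_showscale=False, coloraxis={"colorscale": "Greys"})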

The similarity arises because, despite their different parametrizations, both models have the same objective: produce values of \(p_\theta(b \mid a)\) for pairs of tokens \(a, b\) that minimize the negative log-likelihood of the training sequence. If 10 of the 100 appearances of e are followed by o, then the maximum likelihood estimate is \(p_\theta(\text{'o'} \mid \text{'e'}) = \frac{10}{100}\), regardless of whether it’s parametrized with individual parameters or a neural network.

Let’s also look at the embedding vectors associated with each token after training is completed. Since these are 16-dimensional vectors, we project them down to two dimensions using PCA:

import plotly.express as px
proj_embeds = bigram_nn["proj_embeds"]
fig = px.scatter(x=proj_embeds[:, 0], y=proj_embeds[:, 1], text=[v.replace("\n", "\\n") for v in vocab], width=700, height=700)
fig.update_traces(textposition='top center')

There’s substantial structure here: the embeddings for the capital letters and the lowercase letters form distinct groupings. Conceptually, this makes sense. The next-token distributions (the outputs) for all capital-letter tokens (the inputs) are likely to place high probability on lowercase letters, more so than the distributions for lowercase letters do. Since their outputs share that similarity, their input embeddings should be similar (but not identical) too.
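The two-dimensional coordinates plotted above (`proj_embeds`) come from the training notebook, which isn’t shown here. Here is one way such a projection can be computed: a minimal PCA sketch via the SVD of the centered embedding matrix. It uses a random stand-in for the trained 67 × 16 embedding table. A library implementation such as scikit-learn’s may flip the sign of either axis.

```python
import numpy as np

def pca_2d(embeds):
    """Project the rows of `embeds` onto their top-2 principal components."""
    centered = embeds - embeds.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
embeds = rng.normal(size=(67, 16))  # stand-in for the trained embedding table
proj = pca_2d(embeds)
print(proj.shape)  # (67, 2)
```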


The neural network in the previous section had 7,360 free parameters, whereas the matrix of conditional probabilities had only 4,422. So far, the neural net consumes more memory and has a higher validation perplexity than the explicit parametrization. But now, instead of conditioning on the previous token, let’s condition on the previous seven:

\[\begin{align} p(x_1, x_2,..., x_{1024}) &= p(x_{1024} \mid x_{1023}, ..., x_{1})p(x_{1023} \mid x_{1022}, ..., x_{1})...\\ &\approx p(x_{1024} \mid x_{1023}, ..., x_{1017})p(x_{1023} \mid x_{1022}, ..., x_{1016})...\\ \end{align}\]

We continue the “shared across timesteps” assumption. That is, the next-token probability depends only on the previous 7 tokens, and not on the token’s position in the sequence. Even then, conditioning on 7 tokens, we’d have \(67^7(67-1) \approx 4 \times 10^{14}\) parameters. Since most combinations of the previous 7 tokens are invalid, we could use a sparse representation, but we’d still need a huge number of samples to accurately estimate the next-token probabilities.

But we could also just parametrize this with a neural network, as \(p_\theta(x_i \mid x_{i-1},...,x_{i-7}) =\) \(f_\theta(x_{i-1},...,x_{i-7})[x_i]\). The explicit table grows exponentially with the number of conditioning tokens \(m\) (here, \(m=7\)), but a neural network’s number of parameters need not. In fact, RNNs and Transformers can have a constant number of parameters, independent of \(m\)18.
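The gap in scale is easy to check with a little arithmetic. Here is a minimal sketch, writing m for the number of context tokens. It assumes the layer sizes of the `LanguageModel` class above: 16-dimensional embeddings, hidden layers of width 4 × 16 = 64, a bias in every dense layer, and tied output embeddings that add no parameters.

```python
def table_params(V, m):
    """Free parameters of an explicit table p(next | previous m tokens)."""
    return V ** m * (V - 1)

def mlp_params(m, V=67, features=16, hidden=64):
    """Parameters of the feedforward LanguageModel with m context tokens."""
    embed = V * features                      # shared by input lookup and tied output
    dense1 = m * features * hidden + hidden   # the only layer whose size depends on m
    dense2 = hidden * hidden + hidden
    dense3 = hidden * features + features
    return embed + dense1 + dense2 + dense3

for m in (1, 2, 7):
    print(f"m={m}: table={table_params(67, m):.3g}, network={mlp_params(m)}")
```

For m = 1 this gives 4,422 table parameters vs 7,360 network parameters. For m = 7 it gives about \(4 \times 10^{14}\) vs 13,504, matching the figures in the text and footnote 18.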

Training the network from the previous section (but modified to accept 7 inputs vs 1 input, notebook here), and optimizing, we have:

eightgram_nn = np.load('assets/values_8gram.npz')
steps = np.arange(1000, 50001, 1000)  # perplexity is logged every 1000 steps
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(steps, eightgram_nn["train_ppls"], label="Train (final PPL={:.2f})".format(eightgram_nn["train_ppls"][-1]))
ax.plot(steps, eightgram_nn["val_ppls"], label="Validation (final PPL={:.2f})".format(eightgram_nn["val_ppls"][-1]))
ax.set_xlim(0, 50000)
ax.legend()

The train perplexity shows a substantial improvement: 5.80 for this constrained 8-gram model vs 11.87 for the “free” bigram model19. The learned function implicitly encodes the next-token conditional probability vectors for all \(67^7\) possible sequences of prior tokens. Most of this space of \(67^7\) sequences is nonsensical (like “gsdaksx”). But for the sequences in its training set, it is able to exploit shared structure between similar sequences20 to predict similar conditional probabilities.

Constraints as Generalization

Earlier in the widgets example, we saw that having \(p_\theta(a)=\theta\) and \(p_\theta(b)=\theta^2\) effectively tied their probability estimates together. This has a key consequence: If in a larger sample we saw a lower proportion of outcomes \(a\), we couldn’t decrease \(p_\theta(a)\) without also decreasing \(p_\theta(b)\).

This behavior carries over when we use a neural network. Consider the following:

  • We train a bigram model (parametrized by a neural network), with sequences tokenized at the word-level.
  • During training, the model learns an embedding vector for "dog" that is very close to the one for "cat".
  • A training minibatch contains the token pair ("dog", "sad"), and the optimization process attempts to increase \(p_\theta(\text{'sad'} \mid \text{'dog'})\).

Then \(p_\theta(\text{'sad'} \mid \text{'cat'})\) will also increase, because "dog" and "cat" have very close embedding vectors. In a way, this is a feature. The network is able to exploit the fact that "dog" and "cat" are “similar”: if the former is followed by "sad", the latter is also likely to be followed by "sad". It generalizes this update to another animal.
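This effect can be reproduced in a toy setting. The sketch below uses a hypothetical four-word vocabulary and a simpler model than the one in this post: an embedding lookup followed by a single output matrix `W` and a softmax. The embedding for "cat" is deliberately placed right next to "dog". One gradient step on the pair ("dog", "sad") then also raises \(p_\theta(\text{'sad'} \mid \text{'cat'})\).

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

vocab = ["dog", "cat", "sad", "happy"]           # hypothetical toy vocabulary
idx = {w: i for i, w in enumerate(vocab)}
rng = np.random.default_rng(0)
dim = 8

E = rng.normal(size=(len(vocab), dim))           # embedding table
E[idx["cat"]] = E[idx["dog"]] + 0.01 * rng.normal(size=dim)  # "cat" sits next to "dog"
W = 0.1 * rng.normal(size=(dim, len(vocab)))     # output projection

def next_token_probs(E, W, token):
    """p(. | token) under the toy model softmax(E[token] @ W)."""
    return softmax(E[idx[token]] @ W)

def sgd_step(E, W, a, b, lr=0.1):
    """One gradient-descent step on -ln p(b | a), updating only W."""
    p = next_token_probs(E, W, a)
    target = np.zeros(len(vocab))
    target[idx[b]] = 1.0
    grad_W = np.outer(E[idx[a]], p - target)     # gradient of -ln softmax(e_a @ W)[b]
    return W - lr * grad_W

before = next_token_probs(E, W, "cat")[idx["sad"]]
W_new = sgd_step(E, W, "dog", "sad")
after = next_token_probs(E, W_new, "cat")[idx["sad"]]
print(f"p(sad | cat): {before:.3f} -> {after:.3f}")
```

Only `W` is updated here, so the change for "cat" comes entirely through its embedding’s dot product with that of "dog". The logit for "sad" rises by lr × (e_cat · e_dog) × (1 − p(sad | dog)), while every other logit falls. If the two embeddings had a negative dot product, the same step would lower p(sad | cat) instead.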

This generalization behavior is pointed out by Bengio et al. (2003). In their example, a well-trained model should assign similar probabilities to the sentences “The cat is walking in the bedroom” and “A dog was running in a room”. Even if it sees only one of them during training, at test time the words in the other will have similar embedding vectors, and in turn the model will produce similar outputs.

But this generalization only holds for a well-trained model, trained on a corpus large enough that the embeddings have the correct separation. Suppose that in the “true” corpus, our dogs are sad and our cats are happy. The optimization process will only separate the embeddings sufficiently if we have both pairs ("dog", "sad") and ("cat", "happy") in our dataset. Without a requirement to predict different conditional probabilities \(p_\theta(\text{'sad'} \mid \cdot)\) for each, we might get unintended generalization. And this is just one pair of embeddings.

An Emergent Property

Zooming in between the computational21 and representational22 levels here, we see a macro-level property arise: inputs with similar “meaning” have similar next-token predictions. At no point do we hand-specify the representation of each token. Simply tuning the embeddings and weights over a sufficiently large corpus results in the “clustering” of tokens23. Referring back to the bigram model, at the micro scale we do not explicitly optimize the embeddings for uppercase letters to “group with other uppercase letters”; it emerges organically.

This emergent property, that token sequences with similar “meaning” have similar input embeddings, allows a neural network to effectively compress the conditional probability matrix into its weights. The network takes up less memory and achieves near transfer.


Scaling and Generalization

At their core, large language models (LLMs) are not fundamentally different from the models we look at here: they too are functions that produce a conditional probability for the next token, conditioned on tokens already observed. Their functions have considerably more capacity. The largest neural network we look at has 13,504 parameters and a context window of exactly 7 tokens. The largest GPT-3 model from Brown et al. (2020) has 175B parameters and a context window of up to 2048 tokens.

But in keeping with the spirit of emergent properties (Anderson, 1972), it’s not merely that these language models are larger; they’re also different. While they obey the “near transfer” properties described here (similar input embeddings \(\rightarrow\) similar outputs), they also exhibit a different kind of generalization called in-context learning (ICL) (Min & Xie, 2022). With ICL, the prompt itself can be composed of a few “examples” (e.g. a few translation input-output pairs), that dramatically improve the model’s next-token prediction capabilities. This is hard to explain with just the “similar embedding vectors” property alone.

Our current understanding is that this behavior arises from “induction heads” (Olsson et al., 2022), which emerge organically during training in the Transformer networks (Vaswani et al., 2017) that parametrize these language models. One way to view this is as a higher-order emergent property24 that exists at the level of substructures in the network parameters (and not just the input embeddings, as before). ICL also appears to depend on both the use of Transformers and the data distribution of natural language (Chan et al., 2022). In summary, our understanding of language models parametrized by neural networks is necessary, but not sufficient, to understand LLMs, as there are emergent properties specific to the use of the Transformer architecture.

Data Distribution

In language modeling, we assume our data is drawn i.i.d. from a true distribution \(p\). That is, the proportion of a snippet of text in an infinite-sized corpus should be \(p(\text{text})\). But what our corpus actually contains is better described as samples from \(p(\text{text} \mid \text{time})\), for different values of \(\text{time}\).

This has consequences. If the context is The top song on the Billboard Hot 100 is, then \(p_\theta(\cdot \mid \text{context})\) can only return the conditional probabilities that minimized negative log-likelihood on the training data. But in the real world, the next token for this context changes on a weekly basis!

This problem is particularly exacerbated with fact-heavy completions. Recall that language modeling doesn’t distinguish between the core structure of language, and one-off facts: they’re all just tokens. The “facts” are stored in the weights of the model, no different than the conditional probabilities for the next token of any input sequence. Two potential ways to improve on this:

  • Edit model weights: The recent work ROME (Meng, Bau, et al., 2022) introduced a method called “Causal Tracing” to find the subset of model weights that most influence the “next token” conditional probabilities for the “fact tokens” in a sample sentence, along with an efficient editing mechanism for those weights. Follow-up work then scales this to editing thousands of “stored facts” at once (Meng, Sen Sharma, et al., 2022).

  • Retrieval-based LMs: Recent models such as RETRO (Borgeaud et al., 2022) and ATLAS (Izacard et al., 2022) augment a large transformer model with access to a database. Other works such as Lazaridou et al. (2022) directly use tools like search engines to add tokens to their input prompt. The overall effect is a model that can learn to produce factual evidence by copying from retrieved data, rather than by storing facts in its weights. Then, if the fact source is updated, the model can simply copy the new information into its generated output.

Non-autoregressive models

The autoregressive, “generate a single token at a time, left to right” method we explore here is just one way to build a language model. A few recent papers that explore other strategies are25:

  • The SUNDAE model introduced in Savinov et al. (2022) trains an autoencoder that iteratively denoises a snippet of text. Of note here is Table 4, where they show results on code-generation experiments, and how SUNDAE can account for context tokens both before and after the token to be generated.

  • The SED model introduced in Strudel et al. (2022) uses diffusion directly on the embedding vectors of tokens. Instead of generating left-to-right, this allows them to iteratively refine an entire snippet of text, at each step.

  • Fill-in-the-middle models explored in Bavarian et al. (2022) restructure the training data from (prefix, middle, suffix) \(\rightarrow\) (prefix, suffix, middle). This allows a standard left-to-right architecture to be repurposed to “infill” text.
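The data transformation in the last bullet is simple enough to sketch. Here is a character-level version that splits a document at two random points and reorders it. The string markers `<PRE>`, `<SUF>`, and `<MID>` are placeholders for whatever special tokens a real implementation reserves, and `fim_transform` is a hypothetical helper name. A model trained left-to-right on such strings learns to generate the middle after seeing both the prefix and the suffix.

```python
import random

def fim_transform(doc, rng, pre="<PRE>", suf="<SUF>", mid="<MID>"):
    """Reorder (prefix, middle, suffix) as (prefix, suffix, middle) with markers."""
    # Two distinct split points, so the middle is never empty
    i, j = sorted(rng.sample(range(len(doc) + 1), 2))
    prefix, middle, suffix = doc[:i], doc[i:j], doc[j:]
    return pre + prefix + suf + suffix + mid + middle

doc = "def add(a, b):\n    return a + b\n"
print(repr(fim_transform(doc, random.Random(0))))
```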

Overall, by leveraging shared structure, language models can store conditional probability tables for an exponential number of possible inputs. Many active lines of research continue to explore how these models work, and how they can be made better. And even today, because so many tasks can be rephrased as token prediction problems, they’re beginning to power a large range of practical products, and their impact will only increase as they improve.


Anderson, P. W. (1972). More is different. Science, 177(4047), 393–396.
Bavarian, M., Jun, H., Tezak, N., Schulman, J., McLeavey, C., Tworek, J., & Chen, M. (2022). Efficient training of language models to fill in the middle. arXiv.
Bengio, Y., Ducharme, R., Vincent, P., & Janvin, C. (2003). A neural probabilistic language model. J. Mach. Learn. Res., 3, 1137–1155.
Borgeaud, S., Mensch, A., Hoffmann, J., Cai, T., Rutherford, E., Millican, K., Van Den Driessche, G. B., Lespiau, J.-B., Damoc, B., Clark, A., De Las Casas, D., Guy, A., Menick, J., Ring, R., Hennigan, T., Huang, S., Maggiore, L., Jones, C., Cassirer, A., … Sifre, L. (2022). Improving language models by retrieving from trillions of tokens. In K. Chaudhuri, S. Jegelka, L. Song, C. Szepesvari, G. Niu, & S. Sabato (Eds.), Proceedings of the 39th international conference on machine learning (Vol. 162, pp. 2206–2240). PMLR.
Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., … Amodei, D. (2020). Language models are few-shot learners. arXiv.
Chan, S. C. Y., Santoro, A., Lampinen, A. K., Wang, J. X., Singh, A., Richemond, P. H., McClelland, J., & Hill, F. (2022). Data distributional properties drive emergent in-context learning in transformers. arXiv.
Hornik, K. (1991). Approximation capabilities of multilayer feedforward networks. Neural Networks, 4(2), 251–257.
Izacard, G., Lewis, P., Lomeli, M., Hosseini, L., Petroni, F., Schick, T., Dwivedi-Yu, J., Joulin, A., Riedel, S., & Grave, E. (2022). Few-shot learning with retrieval augmented language models. arXiv.
Karpathy, A. (2015). The unreasonable effectiveness of recurrent neural networks.
Kudo, T., & Richardson, J. (2018). SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, 66–71.
Lazaridou, A., Gribovskaya, E., Stokowiec, W., & Grigorev, N. (2022). Internet-augmented language models through few-shot prompting for open-domain question answering. arXiv.
Marr, D., & Poggio, T. (1976). From understanding computation to understanding neural circuitry. Massachusetts Institute of Technology.
Meng, K., Bau, D., Andonian, A., & Belinkov, Y. (2022). Locating and editing factual associations in GPT. Advances in Neural Information Processing Systems, 35.
Meng, K., Sen Sharma, A., Andonian, A., Belinkov, Y., & Bau, D. (2022). Mass editing memory in a transformer. arXiv Preprint arXiv:2210.07229.
Mikolov, T., Chen, K., Corrado, G., & Dean, J. (2013). Efficient estimation of word representations in vector space. arXiv.
Min, S., & Xie, S. M. (2022). How does in-context learning work? A framework for understanding the differences from traditional supervised learning.
Olsson, C., Elhage, N., Nanda, N., Joseph, N., DasSarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., … Olah, C. (2022). In-context learning and induction heads. Transformer Circuits Thread.
Savinov, N., Chung, J., Binkowski, M., Elsen, E., & Oord, A. van den. (2022). Step-unrolled denoising autoencoders for text generation. International Conference on Learning Representations.
Schunk, D. H. (2011). Learning theories (6th ed.). Pearson.
Strudel, R., Tallec, C., Altché, F., Du, Y., Ganin, Y., Mensch, A., Grathwohl, W., Savinov, N., Dieleman, S., Sifre, L., & Leblond, R. (2022). Self-conditioned embedding diffusion for text generation. arXiv.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, & R. Garnett (Eds.), Advances in neural information processing systems (Vol. 30). Curran Associates, Inc.


  1. Possible meaning “non-zero probability”↩︎

  2. compared to other values of \(\theta\)↩︎

  3. For example, \(p(a) = 0.03\)↩︎

  4. We do keep the baseline constraint that the probability of all three outcomes must sum to 1.↩︎

  5. That is, the function that stores the probabilities of specific outcomes for both \(X_1\) and \(X_2\), compared to all other pairs of outcomes.↩︎

  6. is a character too, so there’s 8 in total.↩︎

  7. This carries on the coin logic: the coin doesn’t care if it’s the 1st or the 100th toss, the probability for the outcome heads remains the same.

    Likewise, \(X_{562}\) having the same distribution as \(X_1\) means the probability for an outcome (e.g. \(p(\text{'t'})\)) is the same at both timesteps.↩︎

  8. Note an implementation detail: we add a 0.1 as a “starting count” for each of the 67 outcomes. We’ll get to this when we cover bigrams next.↩︎

  9. For instance, if ? appears 10x more often in the training set than in the overall corpus, the maximum likelihood estimate \(p_\theta(\text{'?'})\) will be 10x larger than \(p(\text{'?'})\).↩︎

  10. Note that we begin computing perplexity starting at the 8th token, for both the training and validation sequences. Later in the article we’ll build a model estimating probabilities conditioned on the 7 previous tokens; adjusting now means the perplexity comparisons across all models remain comparable.↩︎

  11. 1 is a computational minimum. In practice, language itself has a minimum perplexity larger than 1, and even the best model we’ll ever build won’t go lower.

    Analogously, even with a perfect estimate of \(\theta=0.6\) for the biased coin earlier, we cannot perfectly predict the sequence of heads and tails in a series of tosses; there is randomness inherent to the process itself.↩︎

  12. That is, if \(x_a = x_b\), and \(x_{a-1} = x_{b-1}\), then \(p(x_a \mid x_{a-1}) =\) \(p(x_b \mid x_{b-1})\).

    This is a sensible assumption, as \(p(\text{'e'}\mid \text{'h'})\) should be the same regardless of whether e is the 2nd or 285th token: all the information needed to predict it is (assumedly) in the prior token h.↩︎

  13. Subject to \(\sum_b \theta_{a, b} = 1\)↩︎

  14. We could have technically used a \(67 \times 66\) matrix. For each of the 67 conditional probability vectors, once we know the first 66 probabilities, we can compute the 67th by subtracting the sum from 1. Here it’s just easier to have the extra column.↩︎

  15. In the case of infinite width; but even in finite cases, given sufficient width and depth they’re effective.↩︎

  16. Note that in practice, we’re solving for \(\min_\theta [-\ln L(\theta) + R(\theta)]\), where \(R(\theta)\) is some regularization term (such as the squared sum of the weights \(\lambda \sum_i \theta_i^2\)).↩︎

  17. And many local optima, not just one global optimum.↩︎

  18. The naive feedforward network we use here has its number of parameters increase linearly with \(m\). Here, moving from 1 to 7 inputs, we have an increase from 7,360 to 13,504 params.↩︎

  19. And 11.87 is (up to the effect of the 0.1 pseudo-counts) the lower bound on training perplexity possible under the bigram assumption, since the estimates for each conditional probability were at the global minimum.↩︎

  20. Just as the bigram model did by placing all capital letters close to each other in input space.↩︎

  21. minimize negative log-likelihood↩︎

  22. the embedding vectors + network weights↩︎

  23. And this property was what drove key earlier methods such as Word2Vec (Mikolov et al., 2013).↩︎

  24. Existing in between the computation and representation levels.↩︎

  25. And I note this is hardly exhaustive.↩︎