LoRA (Hu et al., 2021) is a now-popular alternative to the full finetuning of Large Language Models (LLMs): instead of tuning the billions of weights of the full model, we add small “adapter” weight matrices that modify the original weight matrices, and tune those instead.
This blogpost dives deeper into a curious behavior: although LoRA is commonly seen as a drop-in for full finetuning, its interaction with weight decay means it solves a different optimization problem than full finetuning, namely one where the solution weights are regularized towards the frozen base model.
This means that, given increasingly more resources (even equalling those of full finetuning), LoRA does not increasingly better approximate full finetuning, because its objective function is implicitly different from that of full finetuning. Depending on the use case, this can be seen as either a bug or a feature, but it is something practitioners should explicitly account for.
Recap: Finetuning
With LLMs, we typically finetune an initial model (that is “good” on a wide range of text-to-text tasks) to boost performance on a specific task of interest (e.g. generating database queries from natural language). We do this in a two-step process:
- First, create a finetuning training dataset $\mathcal{D}$, which contains pairs of inputs $x$ and targets $y$.1
- Second, optimize the weights of the initial model such that our finetuning training dataset $\mathcal{D}$ becomes more “probable”. The idea here is that a model that is more likely to generate the correct $y$’s for the $x$’s in our training set will generalize, and also be more likely to generate good $y$’s on new $x$’s.
Full Finetuning
Full finetuning means we tune all the weights of the model. For a model such as GPT-3 175B (Brown et al., 2020), this means giving our optimization algorithm 175 billion numbers it can “dial” up and down as needed to make our finetuning training data more “probable”. Let’s dig a bit deeper, and more concretely define what we mean by weights here.
Each layer in a Transformer is primarily made of two components: a multihead attention network, followed by a feedforward network. This means the bulk of the “weights” that make up each layer are stored in six matrices2: the four attention projection matrices ($W_Q$, $W_K$, $W_V$, $W_O$) and the two feedforward matrices ($W_{\text{up}}$, $W_{\text{down}}$).
In full finetuning, every single weight in these six matrices, across every layer, is a tunable parameter. Collecting them all into $\theta$, we optimize them to minimize the negative log likelihood of the finetuning data3:

$$\mathcal{L}_{\text{NLL}}(\theta) = - \sum_{(x, y) \in \mathcal{D}} \log p_\theta(y \mid x)$$
Now, directly doing gradient descent this way would quickly lead to overfitting4, so we usually regularize the problem. With LLMs, the regularization tool of choice is usually weight decay. Specifically, when using vanilla SGD5, weight decay is equivalent to having a term in the loss equal to the squared sum of the weights (the factor of $\frac{1}{2}$ is just a convention that keeps the gradient clean):

$$\mathcal{L}_{\text{reg}}(\theta) = \frac{1}{2} \sum_i \theta_i^2$$
Hence, the overall objective now is as follows (where $\lambda$ controls the strength of the regularization):

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{NLL}}(\theta) + \lambda \mathcal{L}_{\text{reg}}(\theta) = \mathcal{L}_{\text{NLL}}(\theta) + \frac{\lambda}{2} \sum_i \theta_i^2$$
Differentiating this objective to get the gradient, we notice the gradient has two distinct terms6: the first corresponding to minimizing the negative log likelihood as before, and a new second term that is simply a scaled copy of the weights:

$$\nabla_\theta \mathcal{L}(\theta) = \nabla_\theta \mathcal{L}_{\text{NLL}}(\theta) + \lambda \theta$$
Which means the gradient descent update for the regularized problem now looks like:

$$\theta_{t+1} = \theta_t - \eta \left( \nabla_\theta \mathcal{L}_{\text{NLL}}(\theta_t) + \lambda \theta_t \right)$$
In summary, adding a squared sum of weights loss is equivalent to subtracting a scaled version of each weight at each gradient descent step. This shifts the minima towards where the weights are closer to zero7.
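As a quick sanity check, here is a minimal JAX sketch (with toy, made-up values) confirming that the gradient of the squared-sum term is exactly a scaled copy of the weights:

```python
import jax
import jax.numpy as jnp

# Toy values, assumed purely for illustration: the gradient of the
# squared-sum regularizer (lam / 2) * sum(w**2) is exactly lam * w,
# i.e. the term subtracted away at each gradient descent step.
lam = 0.1
w = jnp.array([1.0, -2.0, 3.0])

reg = lambda w: 0.5 * lam * jnp.sum(w ** 2)
print(jax.grad(reg)(w))  # [ 0.1 -0.2  0.3]
print(lam * w)           # [ 0.1 -0.2  0.3]
```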
Full finetuning is highly flexible, but also extremely memory intensive: you generally need at least 3x the memory8 required for the model itself, to account for its gradients and optimizer state. This was not an issue when models were a few hundred million parameters in size, but at the scale of modern LLMs, with tens to hundreds of billions of weights, it quickly becomes prohibitive.
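For a rough sense of scale, here is a back-of-envelope sketch (assuming 175B parameters stored as 32-bit floats and Adam's two extra state tensors per parameter; real setups with mixed precision and sharding will differ):

```python
# Back-of-envelope only; assumed values for illustration.
n_params = 175e9           # GPT-3 175B
bytes_per_param = 4        # float32
weights = n_params * bytes_per_param
grads = weights            # one gradient value per weight
adam_state = 2 * weights   # first- and second-moment estimates
print(weights / 1e12)                          # ~0.7 TB for the model alone
print((weights + grads + adam_state) / 1e12)   # ~2.8 TB in total
```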
LoRA finetuning
LoRA (Low-Rank Adaptation) finetuning takes a different approach: instead of tuning the massive weight matrices of an LLM directly, we use a pair of small adapter matrices for each weight matrix we want to tune, of the following form:

$$W = W_{\text{init}} + AB$$

That is, for each initial, frozen weight matrix $W_{\text{init}} \in \mathbb{R}^{d \times k}$, we add a pair of adapter matrices $A \in \mathbb{R}^{d \times r}$ and $B \in \mathbb{R}^{r \times k}$, and only these adapters are trained.

With LoRA, the rank $r$ is chosen to be far smaller than $d$ and $k$, so each adapted matrix contributes only $r(d + k)$ trainable parameters instead of $d \cdot k$.
This is less than 0.1% of the original number of parameters; the added overhead of storing 3 variants of these values (weights, gradients and optimizer states) is negligible compared to the memory used by the model itself.
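To make that concrete, a quick count (assuming a GPT-3-scale square matrix with $d = k = 12288$ and a rank of $r = 4$, both values picked for illustration):

```python
# Illustrative values only: d = k matches GPT-3 175B's model width,
# r = 4 is a typical small LoRA rank.
d, k, r = 12288, 12288, 4
full_params = d * k           # entries in the original weight matrix
lora_params = r * (d + k)     # entries in the A and B adapters
print(full_params, lora_params, lora_params / full_params)
# 150994944 98304 ~0.00065 (i.e. well under 0.1%)
```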
Moreover, since the initial weights are “shared” across all the finetuning runs, at inference time we only need to load one copy of the initial model to be shared across many finetuned versions, with inference for each task using their own per-task adapter matrices. This makes having a “per-task” tuned LLM in an application not only viable, but easy.
The Interaction
Now that we’ve covered what LoRA is, we can begin to discuss how it interacts with weight decay to produce a feature/bug. Since only the adapter matrices $A$ and $B$ are trainable, weight decay acts on them directly, not on the full adapted matrix9 $W = W_{\text{init}} + AB$. The weight decay updates in LoRA therefore look like:

$$A_{t+1} = A_t - \eta \left( \nabla_A \mathcal{L}_{\text{NLL}} + \lambda A_t \right), \qquad B_{t+1} = B_t - \eta \left( \nabla_B \mathcal{L}_{\text{NLL}} + \lambda B_t \right)$$
Let’s contrast this with the formulation in full finetuning:
- In full finetuning, we have a decay term of $\lambda W_t$ applied to $W$ itself, in that the weight decays to 0 directly.
- However, in LoRA, because $A$ and $B$ decay to 0, in effect we have $W = W_{\text{init}} + AB \to W_{\text{init}}$ instead.
This means LoRA solutions are biased towards the original frozen weight matrices, unlike in full finetuning, where they’re biased towards zero. And this behavior does not go away with increasing the LoRA rank $r$: no matter how large $r$ gets, weight decay still pulls $A$ and $B$ (and hence the product $AB$) towards zero, so the adapted matrix $W$ is pulled towards $W_{\text{init}}$ rather than towards zero, as the toy sketch below illustrates.
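A minimal JAX sketch (with toy, made-up matrices, and the data-fitting term switched off entirely to isolate the decay behavior):

```python
import jax.numpy as jnp

# With only weight decay acting on A and B, their product AB shrinks to
# zero, so the effective weight W_init + AB drifts back to W_init
# rather than to zero.
W_init = jnp.array([[1.0, 2.0],
                    [3.0, 4.0]])
A = jnp.array([[0.5], [0.5]])   # shape (2, 1), i.e. rank r = 1
B = jnp.array([[1.0, -1.0]])    # shape (1, 2)
eta, lam = 0.1, 1.0

for _ in range(500):
    A = A - eta * lam * A
    B = B - eta * lam * B

print(W_init + A @ B)   # ~[[1. 2.] [3. 4.]], i.e. back at W_init
```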
A fix
If we wanted the full adapted matrix to go towards zero (as would happen in full finetuning), we’d need a regularization term where the entire adapted weight matrix goes to zero, as follows:

$$\mathcal{L}_{\text{reg}} = \frac{1}{2} \left\lVert W_{\text{init}} + AB \right\rVert_F^2$$
This is actually straightforward to derive, and yields a pair of update equations that can be implemented much like standard weight decay. First, start at the core definition of weight decay, which involves calculating the gradient of the regularization term w.r.t. each trainable weight:

$$A_{t+1} = A_t - \eta \left( \nabla_A \mathcal{L}_{\text{NLL}} + \lambda \nabla_A \mathcal{L}_{\text{reg}} \right), \qquad B_{t+1} = B_t - \eta \left( \nabla_B \mathcal{L}_{\text{NLL}} + \lambda \nabla_B \mathcal{L}_{\text{reg}} \right)$$
Second, compute the gradient of $\mathcal{L}_{\text{reg}} = \frac{1}{2} \lVert W_{\text{init}} + AB \rVert_F^2$ with respect to $A$ and $B$:

$$\nabla_A \mathcal{L}_{\text{reg}} = (W_{\text{init}} + AB) B^\top, \qquad \nabla_B \mathcal{L}_{\text{reg}} = A^\top (W_{\text{init}} + AB)$$
Inserting these back into the definition of weight decay, we get the following concrete update equations for $A$ and $B$:

$$A_{t+1} = A_t - \eta \left( \nabla_A \mathcal{L}_{\text{NLL}} + \lambda \, (W_{\text{init}} + A_t B_t) B_t^\top \right)$$

$$B_{t+1} = B_t - \eta \left( \nabla_B \mathcal{L}_{\text{NLL}} + \lambda \, A_t^\top (W_{\text{init}} + A_t B_t) \right)$$
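As a sanity check on the derivation, a small JAX sketch (toy shapes and random values, assumed only for illustration) comparing the autodiff gradient against the closed form above:

```python
import jax
import jax.numpy as jnp

# Random toy matrices: W_init is (d, k), A is (d, r), B is (r, k).
key = jax.random.PRNGKey(0)
kW, kA, kB = jax.random.split(key, 3)
W_init = jax.random.normal(kW, (6, 5))
A = jax.random.normal(kA, (6, 2))
B = jax.random.normal(kB, (2, 5))

# The regularizer (1/2) * ||W_init + A @ B||_F^2.
reg = lambda A, B: 0.5 * jnp.sum((W_init + A @ B) ** 2)

dA, dB = jax.grad(reg, argnums=(0, 1))(A, B)
print(jnp.allclose(dA, (W_init + A @ B) @ B.T, atol=1e-5))   # True
print(jnp.allclose(dB, A.T @ (W_init + A @ B), atol=1e-5))   # True
```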
In code
This is what the standard formulation of weight decay in the Optax (Babuschkin et al., 2020) library looks like. It’s quite clean: add a `weight_decay`-scaled copy of each parameter’s value (`p`) to its current update (`g`)10.
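Roughly, the relevant transform looks like the following condensed sketch (not the verbatim library source, but the same core update as `optax.add_decayed_weights`):

```python
import jax
import optax

def add_decayed_weights_sketch(weight_decay):
  def init_fn(params):
    del params
    return optax.EmptyState()   # the transform itself is stateless

  def update_fn(updates, state, params):
    # Add weight_decay * p to each parameter's update g.
    updates = jax.tree_util.tree_map(
        lambda g, p: g + weight_decay * p, updates, params)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)
```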
To modify this to implement the math we just described above takes some extra code, mostly in extracting the `W_init`, `A` and `B` matrices11. The core logic is just the two lines that compute the corrected `decay_term` for `kernelA` and `kernelB`.
```python
def update_fn(updates, state, params):
  def per_param_update_fn(path, update, param):
    param_name = path[-1].key
    # If current parameter is an adapter matrix.
    if param_name in ['kernelA', 'kernelB']:
      # Get the params dict for the layer as a whole.
      layer_params = params
      for dict_key in path[:-1]:
        layer_params = layer_params[dict_key.key]
      # Extract the initial weight matrix and adapter matrices.
      W_init = layer_params['kernel']
      A = layer_params['kernelA']
      B = layer_params['kernelB']
      # Compute the corrected decay term.
      if param_name == 'kernelA':
        decay_term = (W_init + A@B)@B.T
      else:
        decay_term = A.T@(W_init + A@B)
    # If current parameter is *not* an adapter matrix, use
    # default version of weight decay.
    else:
      decay_term = param
    return update + weight_decay * decay_term

  if params is None:
    raise ValueError(base.NO_PARAMS_MSG)
  updates = jax.tree_util.tree_map_with_path(
      per_param_update_fn, updates, params)
  return updates, state
```
Conclusion
In summary, LoRA has a different implicit objective than full finetuning, but it’s also easy to correct if desired. That’s it, really!
To my knowledge, there isn’t literature documenting this interaction of LoRA with weight decay in depth. Conjecturing purely from first principles12, I’d argue the default behavior is both a feature and a bug, depending on the amount of data: when there are very few training points, it is a feature, because it regularizes the updated model to stay close to the initial, “generally-capable” one. However, it’s a bug when given large amounts of data, as the optimization process is less able to stray far from the base weights, even when doing so would aid end-task performance.
That said, as neat as the math is, empirical results are the only truth here. With so many free parameters, it may well turn out that, in practice, solutions regularized to be close to $W_{\text{init}}$ are just as good as those regularized towards zero.
Appendix A: Momentum and Weight Decay
One odd thing you’ve likely noticed is that I spent a substantial amount of time explicitly working out the gradient for the regularizer term $\mathcal{L}_{\text{reg}}$ by hand, rather than simply adding it to the loss and letting automatic differentiation take care of it. The reason: with the momentum-based optimizers we actually use, adding an $L_2$ penalty to the loss is not equivalent to weight decay.
The AdamW paper13 is a solid, in-depth read to understand why this is the case, but in brief: to do weight decay we want to subtract away a scaled version of the parameter’s value at the current timestep. However, adding an $L_2$ penalty to the loss means that decay term is folded into the gradient and then passed through the optimizer’s momentum and adaptive rescaling, so what ultimately gets subtracted is no longer a simple scaled copy of the current parameter value.
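To see the non-equivalence directly, here is a small Optax sketch (a toy parameter and a stand-in quadratic loss, values assumed purely for illustration) comparing Adam on a loss with an $L_2$ term against AdamW with decoupled weight decay:

```python
import jax
import jax.numpy as jnp
import optax

w = jnp.array([2.0])
lam, lr = 0.1, 0.01

loss = lambda w: jnp.sum(w ** 2)                           # stand-in for the NLL term
l2_loss = lambda w: loss(w) + 0.5 * lam * jnp.sum(w ** 2)  # L2 folded into the loss

adam, adamw = optax.adam(lr), optax.adamw(lr, weight_decay=lam)
adam_state, adamw_state = adam.init(w), adamw.init(w)

# First update with the L2 term inside the loss (Adam)...
u1, _ = adam.update(jax.grad(l2_loss)(w), adam_state, w)
# ...versus decoupled weight decay applied outside the adaptive scaling (AdamW).
u2, _ = adamw.update(jax.grad(loss)(w), adamw_state, w)
print(u1, u2)   # the two proposed updates differ
```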
The way modern optimization libraries such as Optax implement AdamW is by first implementing Adam’s transformation of the gradient as a separate subroutine, `scale_by_adam`, that does exactly this. It:

- takes in the NLL loss gradient $g_t$,
- as well as the past optimizer states (the moment estimates $m_{t-1}$ and $v_{t-1}$),
- and returns a “transformed” gradient, an update $u_t$; that is, $u_t = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$.
From there on out, the weight decay looks just like it did before, but swapping in the transformed update $u_t$ in place of the raw gradient:

$$\theta_{t+1} = \theta_t - \eta \left( u_t + \lambda \theta_t \right)$$
Which means a version of our corrected (decays to 0) LoRA update that is compatible with AdamW looks like:

$$A_{t+1} = A_t - \eta \left( u_t^A + \lambda \, (W_{\text{init}} + A_t B_t) B_t^\top \right)$$

$$B_{t+1} = B_t - \eta \left( u_t^B + \lambda \, A_t^\top (W_{\text{init}} + A_t B_t) \right)$$

(where $u_t^A$ and $u_t^B$ are the Adam-transformed gradients for $A$ and $B$.)
The code snippet above (implementing the decay-to-0 LoRA) is actually already compatible with AdamW in Optax. This very nice behavior comes mostly for free from the fact that AdamW in Optax is already a decomposed chain of three operators: `scale_by_adam`, `add_decayed_weights`, and a final scaling by the learning rate. All we’d need to do is create a new optimizer where we swap out `transform.add_decayed_weights` for our custom version, and we’d be set.
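For instance, a sketch of that swap (with `add_lora_decayed_weights` as a hypothetical name for our custom transform from the previous section, wrapped up as an Optax `GradientTransformation`):

```python
import optax

def adamw_with_lora_decay(learning_rate, b1=0.9, b2=0.999, eps=1e-8,
                          weight_decay=1e-4):
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      add_lora_decayed_weights(weight_decay),  # our custom (hypothetical) decay transform
      optax.scale(-learning_rate),             # the usual learning rate step
  )
```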
References
Footnotes
1. In the database query example, the $x$’s can be strings in English, and the $y$’s are then strings corresponding to the query translated from English into the query schema.↩︎
2. Note that if you use a GLU-variant activation (Shazeer, 2020), then you add in a 7th “gating” weight matrix.↩︎
3. This is the precise mathematical definition of what we just described: a function whose minimization makes our finetuning training data $\mathcal{D}$ “more likely” to be generated.↩︎
4. In that the weights would all be optimized to perfectly repeat $y$ for any $x$ in the finetuning training set, at the expense of performing much worse on any $x$ not in the training set.↩︎
5. Stochastic Gradient Descent, a.k.a. the core workhorse of deep learning. In practice we use more sophisticated momentum-based methods, whose impact is described in Appendix A.↩︎
6. This directly stems from the fact that the gradient of a sum (here, the two terms are NLL and regularization) equals the sum of the gradients (of each term).↩︎
7. To reason why this is true, notice that the larger the weight, the “more” of it is subtracted away from itself.↩︎
8. Assuming you’re using an optimizer with some form of momentum (vanilla SGD doesn’t need an optimizer state). It goes up to 4x for Adam, as it has two states: an exponential moving average of both the gradient means, and the gradient-squared means.↩︎
9. We’re dropping the subscripts, as the derivation is identical for all the weight matrices.↩︎
10. We add terms to the update, as the subtraction of the update happens at the very end.↩︎
11. This exact formulation assumes the adapters are defined inside the same layer as the original matrix; that is, the params dict looks like `{'params': {'kernel': ..., 'kernelA': ..., 'kernelB': ...}}`. The actual implementation will depend on how the LoRA adapters have been defined (even though the underlying math will remain the same).↩︎
12. Which, in classic deep learning fashion, could turn out to be wholly incorrect.↩︎
13. Which pointed out this non-equivalency, and produced a version of Adam that “decoupled” weight decay.↩︎
Citation
@online{shafkat2023,
author = {Shafkat, Irhum},
title = {LoRA and {Weight} {Decay}},
date = {2023-09-27},
url = {https://irhum.github.io/blog/lorawd/},
langid = {en}
}