Batch Normalization and Layer Normalization
Transformers in Deep Learning - Part 8
Sibling notes: Positional Encoding in Transformers - The Deep Dive · Residual Connections in Transformers - The Deep Dive
Training deep neural networks is difficult because activations and gradients can become unstable as information passes through many layers.
Normalization helps by keeping values in a controlled range.
In Transformers, the most important normalization method is layer normalization.
To understand why, it helps to first understand normalization in general, then batch normalization, then layer normalization.
1. Why Normalization Matters
Suppose a model predicts house prices from two features:
| Feature | Typical scale |
|---|---|
| square feet | 1000 to 5000 |
| number of rooms | 1 to 10 |
These features live on very different scales.
If the model receives them directly, the large-scale feature can dominate optimization. The loss surface can become stretched, which makes gradient descent inefficient.
flowchart LR
U["unnormalized features"] --> E["elongated loss surface"]
E --> S["small learning rate needed"]
S --> T["slower training"]
Normalization rescales values so training is more stable.
The standard normalization formula is:

$$\hat{x} = \frac{x - \mu}{\sigma}$$

where:

| Symbol | Meaning |
|---|---|
| $x$ | original value |
| $\mu$ | mean |
| $\sigma$ | standard deviation |
| $\hat{x}$ | normalized value |
After normalization, the values are centered at 0 and have standard deviation 1, provided the mean and standard deviation were computed from the same data.
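As a minimal sketch of this formula, the snippet below standardizes the two house-price features. The house values are made up for illustration, and the population standard deviation (`statistics.pstdev`) is an assumed choice; plain Python is used only to keep the example dependency-free.

```python
from statistics import fmean, pstdev

def standardize(values):
    """Return (x - mean) / standard deviation for every value."""
    mu = fmean(values)
    sigma = pstdev(values)  # population standard deviation (assumed choice)
    return [(x - mu) / sigma for x in values]

# Hypothetical houses, invented for this example.
square_feet = [1000.0, 2000.0, 3000.0, 4000.0, 5000.0]
rooms = [2.0, 3.0, 5.0, 6.0, 9.0]

z_sqft = standardize(square_feet)
z_rooms = standardize(rooms)

print([round(z, 3) for z in z_sqft])   # [-1.414, -0.707, 0.0, 0.707, 1.414]
print([round(z, 3) for z in z_rooms])
```

After this step both features span a similar range, so neither dominates purely because of its units.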
2. Normalization and Training Stability
Normalization helps because it:
- reduces scale imbalance
- makes gradients more stable
- allows larger learning rates without destabilizing training
- helps activations avoid extreme values
- improves convergence speed
The goal is not just prettier data.
The goal is a smoother optimization problem.
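To make the "smoother optimization problem" concrete, here is a toy sketch rather than a real training loop. It runs gradient descent on a quadratic bowl whose curvature per direction stands in for feature scale. The function name `steps_to_converge`, the curvature values, and the learning rates are all assumptions chosen for illustration.

```python
def steps_to_converge(curvatures, lr, tol=1e-6, max_steps=100_000):
    """Gradient descent on f(w) = 0.5 * sum(c_i * w_i**2), starting at w = (1, ..., 1).

    Returns the number of steps until every |w_i| < tol.
    """
    w = [1.0] * len(curvatures)
    for step in range(1, max_steps + 1):
        # The gradient of f with respect to w_i is c_i * w_i.
        w = [wi - lr * ci * wi for wi, ci in zip(w, curvatures)]
        if all(abs(wi) < tol for wi in w):
            return step
    return max_steps

# Stretched bowl: one direction is 100x steeper than the other.
# Gradient descent is only stable when lr < 2 / max(curvature),
# so the learning rate must stay small.
stretched = steps_to_converge([100.0, 1.0], lr=0.019)

# Balanced bowl: equal curvature tolerates a much larger learning rate.
balanced = steps_to_converge([1.0, 1.0], lr=0.9)

print(stretched, balanced)
```

The stretched problem needs far more steps because the learning rate is capped by the steep direction while progress along the shallow direction stays slow.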
flowchart LR
N["normalized values"] --> C["better-conditioned loss surface"]
C --> G["more stable gradients"]
G --> F["faster convergence"]
3. Internal Distribution Shift
In a deep network, the input to each layer is the output of the previous layer.
During training, weights change after every update.
So the distribution of activations entering later layers also changes.
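This means an intermediate layer may receive inputs whose scale and distribution keep shifting throughout training, an effect often called internal covariate shift.
Each layer must keep adapting to a moving input distribution, which makes learning harder.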
Normalization inside the network helps reduce this problem by keeping intermediate activations controlled.
4. Batch Normalization
Batch normalization normalizes each feature using statistics computed across a mini-batch.
Imagine a batch with:
- 3 examples
- 5 features per example
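The activation matrix is:

$$X = \begin{bmatrix} x_{11} & x_{12} & x_{13} & x_{14} & x_{15} \\ x_{21} & x_{22} & x_{23} & x_{24} & x_{25} \\ x_{31} & x_{32} & x_{33} & x_{34} & x_{35} \end{bmatrix}$$

Each row is one example and each column is one feature.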
Batch normalization works feature by feature.
For feature 1, it computes the mean and standard deviation across the 3 batch examples.
For feature 2, it does the same.
And so on.
flowchart TD
B["batch examples"] --> F1["normalize feature 1 across batch"]
B --> F2["normalize feature 2 across batch"]
B --> F3["normalize feature 3 across batch"]
So batch norm asks:
For this feature, how does each example compare to the batch?
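Mathematically:

$$\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sigma_j}$$

where $x_{ij}$ is the value of feature $j$ in example $i$, and $\mu_j$ and $\sigma_j$ are computed across the batch for feature $j$.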
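The feature-by-feature procedure can be sketched as follows for a batch of 3 examples with 5 features. This is a training-time illustration only: it omits the running statistics that real batch norm layers keep for inference, the `eps` constant is an assumed addition for numerical stability, and the batch values are invented.

```python
from statistics import fmean, pvariance

def batch_norm(batch, eps=1e-5):
    """Normalize each feature (column) with statistics taken across the batch (rows)."""
    columns = list(zip(*batch))                      # one tuple per feature
    means = [fmean(col) for col in columns]
    stds = [(pvariance(col) + eps) ** 0.5 for col in columns]
    return [
        [(x - m) / s for x, m, s in zip(row, means, stds)]
        for row in batch
    ]

# 3 examples x 5 features, values made up for illustration.
batch = [
    [1.0, 200.0, 0.5, 10.0, -3.0],
    [2.0, 400.0, 0.1, 30.0, -1.0],
    [3.0, 900.0, 0.9, 20.0, -2.0],
]

normalized = batch_norm(batch)
for row in normalized:
    print([round(v, 3) for v in row])
```

Every column of `normalized` now has mean close to 0 and standard deviation close to 1, regardless of the original scale of that feature.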
5. Scale and Shift
After normalization, models usually apply a learned scale and shift:

$$y = \gamma \hat{x} + \beta$$

where $\hat{x}$ is the normalized value and:

| Parameter | Role |
|---|---|
| $\gamma$ | learned scale |
| $\beta$ | learned shift |
At first, this can look like it cancels normalization.
But it gives the model flexibility.
Some layers may benefit from perfectly normalized values. Other layers may benefit from a different scale or offset.
So normalization stabilizes the data, while $\gamma$ and $\beta$ let the model learn the best final shape.
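A short sketch of why the scale and shift add flexibility instead of cancelling normalization: with one setting they leave the normalized values untouched, and with another they recover the original values exactly. The scalar `gamma` and `beta` here are a simplification; in practice they are typically per-feature parameters learned by gradient descent.

```python
from statistics import fmean, pstdev

def normalize(values):
    """Return the normalized values together with the statistics used."""
    mu, sigma = fmean(values), pstdev(values)
    return [(x - mu) / sigma for x in values], mu, sigma

def scale_shift(x_hat, gamma, beta):
    """Apply y = gamma * x_hat + beta elementwise (scalar gamma and beta for simplicity)."""
    return [gamma * v + beta for v in x_hat]

values = [2.0, 4.0, 6.0, 8.0]
x_hat, mu, sigma = normalize(values)

# gamma = 1 and beta = 0 keep the normalized values as they are.
identity = scale_shift(x_hat, gamma=1.0, beta=0.0)

# gamma = sigma and beta = mu recover the original values, so the layer
# can represent anything between "fully normalized" and "unchanged".
restored = scale_shift(x_hat, gamma=sigma, beta=mu)

print(identity)
print(restored)
```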
6. Why Batch Normalization Is Awkward for Transformers
Transformers usually process tensors shaped like:

$$(\text{batch size},\ \text{sequence length},\ \text{embedding dimension})$$

For example:

$$(2,\ 3,\ 512)$$

This means:
- batch size = 2
- sequence length = 3
- embedding dimension = 512
Batch normalization would compute statistics across the batch dimension, and often across token positions depending on implementation.
That is awkward for language models because:
- sentences have variable lengths
- padding tokens can distort batch statistics if not handled carefully
- batch statistics change with batch composition
- generation often happens one token at a time
- the same token should be normalized consistently regardless of which other examples appear in the batch
Batch normalization is not impossible in sequence models, but it is not the natural default for Transformers.
Layer normalization fits the Transformer shape better.
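The batch-dependence issue can be seen in a tiny sketch: the same activation value is normalized to different results depending on which other examples share its batch. The batches are invented for illustration, and `batch_norm_feature` is a hypothetical helper handling a single feature.

```python
from statistics import fmean, pstdev

def batch_norm_feature(column):
    """Normalize one feature across the examples in a batch."""
    mu, sigma = fmean(column), pstdev(column)
    return [(x - mu) / sigma for x in column]

# The same activation (2.0) appears first in two different batches.
batch_a = [2.0, 1.0, 3.0]
batch_b = [2.0, 10.0, 30.0]

in_a = batch_norm_feature(batch_a)[0]
in_b = batch_norm_feature(batch_b)[0]

print(round(in_a, 3))   # 0.0
print(round(in_b, 3))   # negative, roughly -1
```

The input did not change, only its neighbors did, yet its normalized value moved by about one standard deviation.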
7. Layer Normalization
Layer normalization normalizes across the features of one data point.
In Transformers, that usually means:
Normalize each token representation across its hidden dimensions.
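If one token representation is:

$$x = [x_1, x_2, \dots, x_d]$$

layer norm computes:

$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i, \qquad \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

Then:

$$\hat{x}_i = \frac{x_i - \mu}{\sigma}$$

Finally:

$$y_i = \gamma_i \hat{x}_i + \beta_i$$

with one learned $\gamma_i$ and $\beta_i$ per feature.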
Layer norm asks:
For this token, how does each feature compare to the other features in the same token vector?
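These steps can be sketched for a single token vector as below. The token values are invented, `gamma` and `beta` are set to their usual starting values of 1 and 0 instead of being learned, and `eps` is an assumed small constant added to the variance for numerical stability.

```python
from statistics import fmean, pvariance

def layer_norm(token, gamma, beta, eps=1e-5):
    """Normalize one token vector across its own features, then scale and shift."""
    mu = fmean(token)
    var = pvariance(token)
    x_hat = [(x - mu) / (var + eps) ** 0.5 for x in token]
    return [g * v + b for g, v, b in zip(gamma, x_hat, beta)]

# A toy 6-dimensional token representation (hypothetical values).
token = [0.5, -1.2, 3.0, 0.1, 2.6, -0.4]
d = len(token)

out = layer_norm(token, gamma=[1.0] * d, beta=[0.0] * d)
print([round(v, 3) for v in out])
```

No other token or example is involved: the mean and variance come entirely from this one vector.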
8. Batch Norm vs Layer Norm
The main difference is the direction of normalization.
| Method | Normalizes across | Depends on batch? | Common fit |
|---|---|---|---|
| Batch normalization | examples for each feature | yes | CNNs and many feed-forward networks |
| Layer normalization | features within each example/token | no | Transformers and language models |
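For a Transformer tensor:

$$X \in \mathbb{R}^{B \times T \times d}$$

with batch size $B$, sequence length $T$, and embedding dimension $d$, layer normalization is applied independently to each token vector:

$$x_{b,t} \in \mathbb{R}^{d}, \qquad b = 1, \dots, B, \quad t = 1, \dots, T$$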
So each token gets its own mean and variance across the hidden dimension.
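As a rough sketch of the token-wise application, the snippet below builds a random nested list shaped (2, 3, 512) and normalizes every token vector separately. It then repeats the computation with only the first sequence to check that the result does not depend on the rest of the batch. The learned scale and shift are omitted for brevity (equivalent to scale 1 and shift 0), and the random inputs are placeholders, not real embeddings.

```python
import random
from statistics import fmean, pvariance

def layer_norm(token, eps=1e-5):
    """Normalize one token vector using its own mean and variance."""
    mu, var = fmean(token), pvariance(token)
    return [(x - mu) / (var + eps) ** 0.5 for x in token]

def layer_norm_tensor(x):
    """Apply layer norm to every token of a nested list shaped (batch, seq, dim)."""
    return [[layer_norm(token) for token in sequence] for sequence in x]

random.seed(0)
batch_size, seq_len, dim = 2, 3, 512
x = [[[random.gauss(0.0, 3.0) for _ in range(dim)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]

full = layer_norm_tensor(x)
alone = layer_norm_tensor([x[0]])   # first sequence only, batch size 1

# Each token uses only its own statistics, so dropping the other
# sequence from the batch leaves the result unchanged.
print(full[0] == alone[0])   # True
```

The same property is what makes layer norm convenient for one-token-at-a-time generation, where the batch may contain a single sequence.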
flowchart TD
X["token vector: 512 features"] --> MU["mean across 512 features"]
X --> VAR["variance across 512 features"]
MU --> LN["normalize token"]
VAR --> LN
LN --> SS["scale and shift with gamma, beta"]
9. Layer Norm in Transformers
A Transformer block contains sublayers such as:
- multi-head attention
- feed-forward network
Layer normalization is used around these sublayers to stabilize training.
In the original Transformer, the structure was commonly written as:

$$y = \text{LayerNorm}(x + \text{Sublayer}(x))$$

This is often called post-norm, because layer norm happens after the residual addition.
Many modern Transformers use pre-norm:

$$y = x + \text{Sublayer}(\text{LayerNorm}(x))$$
Both designs use the same core idea:
Keep activations stable while allowing deep stacks of Transformer blocks to train.
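The two placements can be compared with a toy sketch. `toy_sublayer` is a hypothetical stand-in for attention or a feed-forward network, and the layer norm here omits the learned scale and shift, so this only illustrates where the normalization sits relative to the residual addition.

```python
from statistics import fmean, pvariance

def layer_norm(x, eps=1e-5):
    """Normalize a vector across its own features (no learned scale or shift)."""
    mu, var = fmean(x), pvariance(x)
    return [(v - mu) / (var + eps) ** 0.5 for v in x]

def toy_sublayer(x):
    """Hypothetical stand-in for attention or a feed-forward network."""
    return [2.0 * v + 1.0 for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def post_norm_block(x, sublayer):
    # Normalize after the residual addition.
    return layer_norm(add(x, sublayer(x)))

def pre_norm_block(x, sublayer):
    # Normalize only the sublayer input; the residual path carries x unchanged.
    return add(x, sublayer(layer_norm(x)))

x = [10.0, -4.0, 7.0, 1.0]
post = post_norm_block(x, toy_sublayer)
pre = pre_norm_block(x, toy_sublayer)

print([round(v, 3) for v in post])
print([round(v, 3) for v in pre])
```

In the post-norm output the whole result is normalized, while in the pre-norm output the original `x` passes through the residual path without being rescaled.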
10. Why Layer Norm Helps Transformers
Layer normalization helps because it:
- does not depend on other examples in the batch
- works naturally with variable-length sequences
- works during autoregressive generation
- stabilizes hidden states across deep layers
- reduces the risk of exploding or vanishing activations
- pairs well with residual connections
It is a small operation, but it has a large effect on trainability.
Without normalization, very deep Transformer stacks become much harder to optimize.
11. Common Confusions
Is normalization only for input features?
No. Deep networks also need stable intermediate activations.

Are batch normalization and layer normalization the same thing?
No. Batch norm normalizes across examples for each feature. Layer norm normalizes across features for each example or token.

Is batch normalization impossible to use in Transformers?
Not exactly. It can be engineered, but it is awkward for variable-length sequence modeling and batch-dependent generation. Layer norm is the standard choice.

Do the learned scale and shift undo the normalization?
No. They let the model choose the useful scale and offset after stabilization.

Do Transformers use layer norm only because of padding?
Padding is one practical issue, but the deeper reason is that layer norm is batch-independent and fits token-wise hidden states naturally.
12. One-Line Interview Answer
Batch normalization normalizes each feature using statistics across a mini-batch, while layer normalization normalizes each example or token across its hidden features. Transformers use layer normalization because it is independent of batch composition, works naturally with variable-length sequences and generation, and stabilizes deep Transformer training.
13. Final Intuition
Batch normalization compares one feature across many examples.
Layer normalization compares many features inside one token representation.
For language models, the second view is usually the better fit.
Each token can be stabilized on its own, no matter what other sentences are in the batch.
That is why layer normalization became a core part of Transformer architecture.