Welcome to the next chapter! This is a really important topic for LLMs, and even for DL in general. BatchNorm, alongside other innovative normalization techniques, is a must-know in Deep Learning. Let's jump right into the topic, as there is a lot to talk about.
A note here: this is still part of Andrej Karpathy's Makemore series, so a big shout-out and much respect to him.
Also, I assume that you already have the code for the previous model. We won't code everything all over again or make huge modifications; we will just take the previous model and make some small changes. For the sake of convenience, I won't show you the previous code (it is in the previous chapters, I'm lazy), only the small parts of it that need to be modified.
Problems with the Previous Model
When scrutinizing a newly invented technique, we should really think about the problems that it solves. Normalization has been linked with a wide range of problems in neural networks. So let's see why people need to normalize things when playing with neural networks. Put another way: what goes wrong if we let everything in the net behave freely?
The Logits
The way we initialize the weights and biases of the output layer is, frankly, really bad. First let's look at our first iteration and see the loss that we get:
0/ 200000: 27.8817
That is TOO FAR away from where the loss should end up. To make it more alarming, let's look at the graph of the loss over the training iterations:
Look familiar? That is a hockey-stick graph, but for the loss it shows a disaster. We have a way-too-high loss right from the beginning, and the model spends thousands of iterations just bringing that unfathomable loss down. Can we prevent that, right from the beginning?
Well, we can bring a little bit of maths into the initialization process to make the model's very first step easier. Maybe we can even make training more efficient and arrive at a better loss. Researchers think like that too: when initializing the parameters, they usually start with some expectation of how the model should behave at the start.
What do we expect from the first step of the model, in the most naive way? Well, we may think that the model should treat each character equally, right? Every character has the same probability, and then training can drive some of them up or down to make the model stronger. As there are 27 characters, each one in the output layer should be assigned a probability of 1/27. Now let's see the loss given that expectation, which is just the negative log of 1/27 (no magic number or 'aha' moment here, it's fairly straightforward):
>>>-torch.tensor(1/27.0).log()
tensor(3.2958)
It's a GREAT loss to start! So why did our model start so poorly? Think about it.
It turns out that, during initialization, when we randomly sample from the Gaussian distribution, some characters inevitably get much higher logits than others, even though they don't deserve it. The model is then confidently wrong about many examples, which produces a huge loss. So it's not a good idea to let that bias affect us in the first step. We want every character to be judged equally, so that the initial loss is small.
Let's see a simple example:
import torch
logits = torch.tensor([0.0,0.0,0.0,0.0])
probs = torch.softmax(logits,dim=0)
loss = -probs[2].log()
probs,loss
(tensor([0.2500, 0.2500, 0.2500, 0.2500]), tensor(1.3863))
It's good, but what if we introduce some extreme values?
logits = torch.tensor([-10.0,0.0,0.0,15.0])
probs = torch.softmax(logits,dim=0)
loss = -probs[2].log()
probs,loss
(tensor([1.3888e-11, 3.0590e-07, 3.0590e-07, 1.0000e+00]), tensor(15.0000))
The loss increases significantly! So the problem here is the extreme values. In other words (we should really translate our language into math-like terms so that we can modify the code), we don't want some values to be too far away from the average. To put it more precisely, we want the spread of the values, measured by the standard deviation, to be small and under control, for example around 1.
Also, the mean should be around zero; this is useful later and we will discuss it. If you have read enough ML/DL or statistics books, you might be familiar with the "mean 0, std 1" kind of stuff. Researchers love that thing, and it has a beautiful name: the Unit Gaussian distribution (also called the standard normal distribution). That's exactly what we do when we normalize our data: it's really about making the data look like a Unit Gaussian.
Now let's turn back to our code: what should we do with the weights and biases to make the values less extreme?
First, the output bias should be set to zero, as discussed earlier. Second, the output weights should be smaller, because we don't want some numbers to 'explode' after the matrix multiplication; we want to keep the logits small. To this end, we multiply the W2 matrix by 0.01, which makes sense.
You may wonder: why don't we set the weight matrix to exactly 0? Even though we said that everything should be equal, we won't take that to the extreme.
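To connect the 4-class toy example back to our 27-character output layer, here is a small sketch. It uses made-up random targets, and a hypothetical logit scale of 10 stands in for an unscaled output layer. It compares all-equal logits with large random ones:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 27
targets = torch.randint(0, num_classes, (1000,))   # made-up targets

# All-equal logits -> uniform softmax -> loss should be log(27)
uniform_logits = torch.zeros(1000, num_classes)
uniform_loss = F.cross_entropy(uniform_logits, targets)

# Large random logits (hypothetical scale of 10): confidently wrong guesses
random_logits = torch.randn(1000, num_classes) * 10
random_loss = F.cross_entropy(random_logits, targets)

print(f"uniform logits: {uniform_loss.item():.4f}  (log 27 = {math.log(27):.4f})")
print(f"random logits:  {random_loss.item():.4f}")
```

The uniform case lands exactly on the 3.2958 we computed by hand, while the extreme logits blow the loss up, just like in the 4-class example.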
"It's not good to go extreme, my friend"
Researchers didn't say it like that, but this is what they mean. We don't set weight matrices to all zeros at initialization, because we want a little bit of randomness (entropy) in there to break the symmetry between neurons, so that they can learn different things.
I won't go deep into this; just think of it as avoiding the extreme on the other end.
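To see the most extreme version concretely, here is a toy sketch in which every weight matrix starts at exactly zero. The sizes and data are hypothetical. The hidden layer then outputs all zeros, and both weight matrices receive exactly zero gradient:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 30)            # hypothetical batch: 32 examples, 30 input features
y = torch.randint(0, 27, (32,))    # hypothetical targets over 27 characters

W1 = torch.zeros(30, 200, requires_grad=True)
b1 = torch.zeros(200, requires_grad=True)
W2 = torch.zeros(200, 27, requires_grad=True)
b2 = torch.zeros(27, requires_grad=True)

h = torch.tanh(x @ W1 + b1)        # all zeros
logits = h @ W2 + b2               # all zeros -> uniform softmax
loss = F.cross_entropy(logits, y)
loss.backward()

print(loss.item())                                              # log(27), about 3.2958
print(W1.grad.abs().max().item(), W2.grad.abs().max().item())  # both exactly 0.0
print(b2.grad.abs().max().item())                               # only the output bias gets a signal
```

Because h stays zero as long as W1 and b1 are zero, the weight matrices never get a gradient and never move. A small random initialization avoids this.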
Now let's modify our code
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0
Notice the * 0.01 and the * 0 added to the last two tensors (those of the output layer). Let's look at our loss now:
0/ 200000: 3.3221
Perfect! Now look at the graph

No hockey stick! And maybe our model even performs better, as it doesn't have to waste thousands of initial iterations. But we need to move to another problem in our model.
Saturated Tanh & Vanishing Gradient
The logits are not the only thing that went wrong. Let's look at the hidden states, which are simply the tanh of the hidden layer's pre-activations (embcat @ W1 + b1):
plt.hist(h.view(-1).tolist(),50)
Whoa, what happened? The values are polarized towards the two ends: -1.00 and 1.00. Something has clearly gone wrong here. So, like good detectives, we trace one step back to the hidden states before the tanh.
And for the last piece of the puzzle, let me remind you of the tanh() function:
tanh() is a squashing function: it squashes every input into the range between -1.0 and 1.0. Looking at the graph of the hidden states before activation, we know exactly what the problem is. The pre-activations contain large numbers, and once those numbers go through tanh(), they end up at one of the two ends. This is called saturation: values are pushed towards the flat tails of the function.
And how exactly does that affect our model? Think about the gradient. The derivative of tanh() is 1 - t^2, where t is the output of tanh(). So for every output close to 1 or -1, the local gradient is close to zero, and whatever gradient flows in from above gets multiplied by that near-zero number. After backpropagating through many layers of neurons, the gradient can really shrink towards zero. This is called the vanishing gradient problem. When almost no gradient is passed down, the parameters get updated really slowly, which hurts training.
To see what I mean, let's look at which hidden activations have an absolute value larger than 0.99:
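A quick way to check the 1 - t^2 formula, and to see how small the gradient gets in the tails, is to let autograd compute it for a few hand-picked inputs:

```python
import torch

x = torch.tensor([0.0, 0.5, 2.0, 5.0], requires_grad=True)
t = torch.tanh(x)
t.sum().backward()           # d(sum tanh)/dx_i = tanh'(x_i)

formula = 1 - t.detach() ** 2
print(x.grad)                # autograd's gradient
print(formula)               # the 1 - t^2 formula: same numbers
```

At x = 0 the gradient is exactly 1. At x = 5, where tanh is already about 0.9999, it has dropped below 0.001: that is saturation in numbers.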
plt.figure(figsize = (20,10))
plt.imshow(h.abs()>0.99,cmap = 'gray',interpolation = 'nearest')
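The white cells are where tanh() is saturated, so the gradient flowing through them is nearly zero. If an entire column were white, that neuron would be dead: it would be saturated for every example and would never learn. You can see that our model is largely saturated.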
So far we've been chilling with our model because it is just a small one. But oftentimes we will have to deal with neural nets with tons of layers, and the problem can accumulate really fast.
Now let's fix this: We should bring everything close to zero, and make the variance small so that the tanh() does not push anything to the ends (you can see why this is a good choice by again looking at the graph of the tanh()). In doing so, we're normalizing the hidden states before activation:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.01
b1 = torch.randn(n_hidden, generator=g) * 0.01
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0
parameters = [C, W1, b1, W2, b2] # b1 is used in the forward pass, so it has to be trained too
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True
This is the same as what we did last time, except that we multiply b1 by 0.01 rather than 0. This is, again, an attempt to "introduce some entropy to the model". Now let's run it and look at our hidden states, after and before activation respectively.
# And we should also see if there are some dead neurons
plt.figure(figsize = (20,10))
plt.imshow(h.abs()>0.99,cmap = 'gray',interpolation = 'nearest')

Everything looks great! The tanh() doesn't misbehave like before, because its inputs are now in a reasonable range, and we don't see any saturated values in the last plot. (In practice, we would scale W1 up a little, so that tanh() is not almost purely linear and a few values can reach the tails. Removing saturation entirely is, again, going too extreme.)
As for the loss, we shouldn't expect much improvement here, because we're dealing with a fairly small, shallow MLP.
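Since the plots are just pictures, here is a small numeric sketch of the same before/after comparison. It uses random stand-ins for embcat (hypothetically 32 examples and 30 features) rather than our real data. It counts how many tanh outputs are saturated with the original unscaled init versus the scaled-down one:

```python
import torch

torch.manual_seed(0)
embcat = torch.randn(32, 30)   # hypothetical stand-in for the concatenated embeddings
W1 = torch.randn(30, 200)
b1 = torch.randn(200)

def saturation_stats(h):
    saturated = h.abs() > 0.99
    frac = saturated.float().mean().item()    # share of saturated activations
    dead = saturated.all(dim=0).sum().item()  # neurons saturated for every example
    return frac, dead

stats_before = saturation_stats(torch.tanh(embcat @ W1 + b1))                # original init
stats_after = saturation_stats(torch.tanh(embcat @ (W1 * 0.01) + b1 * 0.01))  # scaled-down init
print("before (fraction saturated, dead neurons):", stats_before)
print("after  (fraction saturated, dead neurons):", stats_after)
```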
Initialization
A small example
You might ask me: "How do you know those magic 0.01 numbers that the weights are multiplied by?" Actually, that is just an arbitrary small value (and yet it worked!). But there is a lot of research around this, and some of it is really interesting.
Let's consider a small set of data:
import torch
import matplotlib.pyplot as plt
x = torch.randn(1000,10)
w = torch.randn(10,200)
y = x @ w
print(x.mean(),x.std())
print(y.mean(),y.std())
plt.figure(figsize = (20,10))
plt.subplot(121)
plt.hist(x.view(-1).tolist(),50,density = True)
plt.subplot(122)
plt.hist(y.view(-1).tolist(),50,density = True)
tensor(-0.0027) tensor(1.0000)
tensor(0.0031) tensor(3.1318)
You can see that while x is Unit Gaussian, things get messed up after multiplying by w. Specifically, the mean is fine (it's around zero), but the standard deviation spells disaster. This is no accident. Each element of y is a sum of 10 products of independent unit Gaussians, so its variance is about 10 and its standard deviation is about sqrt(10) ≈ 3.16, which matches the 3.13 we got.
Remember, our goal is to make everything "Unit-Gaussian-like". So the real task here is to find a scale for the matrix w such that the Unit Gaussian is preserved after the matrix multiplication. Let's try this: divide the weights by sqrt(fan_in), where fan_in, by convention, is the number of inputs to the layer (here, 10).
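The reasoning behind that std of roughly 3.16 fits in a few lines. Assume the inputs x_i and the weights w_ij are all independent with zero mean (an idealization, but a good one at initialization), and let n be the fan_in:

```latex
\begin{aligned}
y_j &= \sum_{i=1}^{n} x_i\, w_{ij} \\
\operatorname{Var}(y_j) &= \sum_{i=1}^{n} \operatorname{Var}(x_i\, w_{ij})
  = \sum_{i=1}^{n} \operatorname{Var}(x_i)\,\operatorname{Var}(w_{ij})
  = n\,\operatorname{Var}(x)\,\operatorname{Var}(w)
\end{aligned}
```

With Var(x) = Var(w) = 1 and n = 10, we get Var(y) = 10, so std(y) = sqrt(10). To keep Var(y) = Var(x), we need Var(w) = 1/n, that is, a weight std of 1/sqrt(n).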
x = torch.randn(1000,10)
# Divide the weights by sqrt(fan_in), i.e. multiply them by 1/sqrt(fan_in)
w = torch.randn(10,200) / 10 ** 0.5
y = x @ w
...
The output looks beautiful! But how?
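Here is a small sketch checking that the trick isn't specific to fan_in = 10. For a few arbitrarily chosen fan-in sizes, the unscaled weights give a std of roughly sqrt(fan_in), while dividing by sqrt(fan_in) keeps it close to 1:

```python
import torch

torch.manual_seed(0)
stds = {}
for fan_in in (10, 100, 1000):
    x = torch.randn(1000, fan_in)                         # unit Gaussian inputs
    w_naive = torch.randn(fan_in, 200)                    # plain randn weights
    w_scaled = torch.randn(fan_in, 200) / fan_in ** 0.5   # divided by sqrt(fan_in)
    stds[fan_in] = ((x @ w_naive).std().item(), (x @ w_scaled).std().item())
    print(f"fan_in={fan_in:5d}  naive std={stds[fan_in][0]:7.3f}  scaled std={stds[fan_in][1]:.3f}")
```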
Kaiming_Init
One of the initialization strategies that has been proposed is Kaiming initialization, also known as He initialization. PyTorch implements it as torch.nn.init.kaiming_normal_.

This function has a nonlinearity argument, which specifies the activation function that we would like to use. Each activation function has its own gain, which is simply a scalar that compensates for the "squashing" the activation function applies. The recommended gain for our tanh() is 5/3.

There is also an argument called mode, which can be 'fan_in' or 'fan_out'. 'fan_in' preserves the scale of the activations in the forward pass, while 'fan_out' preserves the scale of the gradients in the backward pass. In our case, we use mode='fan_in', which is the default. The standard deviation is then calculated as std = gain / sqrt(fan), where fan is fan_in or fan_out depending on the mode.
You can read the whole paper, "Delving Deep into Rectifiers" (He et al., 2015), here.
Now let's change our W1 matrix:
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5
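If you prefer to let PyTorch compute the numbers, here is a sketch using torch.nn.init.calculate_gain and torch.nn.init.kaiming_normal_. The sizes are hypothetical: 30 inputs (for example n_embd = 10 times a block_size of 3) and 200 hidden units. One catch: PyTorch assumes a 2-D weight is stored as (out_features, in_features) and reads fan_in from its second dimension. Our W1 is stored as (in, out), so the sketch creates the PyTorch-shaped tensor and transposes it:

```python
import torch
from torch.nn import init

torch.manual_seed(0)
n_in, n_hidden = 30, 200   # hypothetical: e.g. n_embd = 10 times block_size = 3

gain = init.calculate_gain('tanh')   # 5/3 for tanh
target_std = gain / n_in ** 0.5      # std = gain / sqrt(fan_in)

# Manual version, same formula as the W1 line above
W1_manual = torch.randn(n_in, n_hidden) * target_std

# PyTorch version: create an (out_features, in_features) tensor so that
# mode='fan_in' reads fan_in = 30, then transpose to our (in, out) layout
W1_torch = torch.empty(n_hidden, n_in)
init.kaiming_normal_(W1_torch, mode='fan_in', nonlinearity='tanh')
W1_torch = W1_torch.T

print(gain, target_std)
print(W1_manual.std().item(), W1_torch.std().item())
```

Both versions should land close to gain / sqrt(30), about 0.304.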
And that's basically it! We solved some of the problems of our neural net. Note that all of these fixes boil down to scaling the weights down (and even zeroing out the bias) to keep the data closer to Unit Gaussian, which is crucial for keeping training stable.
Batch Normalization
Now let's talk about today's topic, right? We went such a long way to get here. BatchNorm is one of the modern innovations that address many of our previous problems, and the idea behind it is fairly simple.
If we want the hidden states to be Unit Gaussian, why bother initializing and tweaking the weight and bias matrices? Why don't we just SIMPLY NORMALIZE them? And that's how a brilliant idea came into being.
You can see the paper ("Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", Ioffe & Szegedy, 2015) here.
And this is the implementation of it:
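For reference, here is the batch-normalizing transform as written in the paper's Algorithm 1. It applies to a mini-batch of m values x_1 to x_m of a single feature:

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i && \text{(1) mini-batch mean} \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 && \text{(2) mini-batch variance} \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} && \text{(3) normalize} \\
y_i &= \gamma\,\hat{x}_i + \beta && \text{(4) scale and shift}
\end{aligned}
```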
It's just the usual normalization procedure: subtract the mean, and divide by the standard deviation (the square root of the variance). Note that in the dividing step (line 3), there is an epsilon term in the denominator. It is added to avoid dividing by zero when the variance is zero. A common value is epsilon = 1e-5.
Another important thing to note is the last line: we have two parameters called gamma and beta, which are for "scale and shift". What are they for?
Well, we don't want to force our data to be exactly Unit Gaussian. We would like the neural net to be able to move the distribution around a bit, making it more diffuse or sharper as needed. These parameters are learned during training, so the model itself decides how to scale and shift the data. In fact, if it needs to, the network can even undo the normalization entirely, by learning a gamma equal to the standard deviation and a beta equal to the mean. So it's a bit of flexibility added back to the model.
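Putting the four lines together, here is a minimal standalone sketch of the BatchNorm forward pass in PyTorch. The function name batchnorm_forward and the test data are made up for illustration. It uses the biased variance, as in the paper's formula; the training loop in this post uses .std(), which is the unbiased version, and for a batch of 32 the difference is small. The last part checks that gamma and beta really can undo the normalization:

```python
import torch

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: (batch, features); statistics are computed over the batch dimension
    mean = x.mean(0, keepdim=True)                # line 1: mini-batch mean
    var = x.var(0, keepdim=True, unbiased=False)  # line 2: biased variance, as in the paper
    x_hat = (x - mean) / torch.sqrt(var + eps)    # line 3: normalize (eps avoids division by zero)
    return gamma * x_hat + beta                   # line 4: scale and shift

torch.manual_seed(0)
x = torch.randn(32, 200) * 5 + 3   # hypothetical pre-activations: mean 3, std 5
gamma = torch.ones(1, 200)         # same init as bngain
beta = torch.zeros(1, 200)         # same init as bnbias

out = batchnorm_forward(x, gamma, beta)
print(out.mean().item(), out.std().item())   # roughly 0 and 1

# With gamma = sqrt(var + eps) and beta = mean, the layer gives back x itself
gamma_id = torch.sqrt(x.var(0, keepdim=True, unbiased=False) + 1e-5)
beta_id = x.mean(0, keepdim=True)
recovered = torch.allclose(batchnorm_forward(x, gamma_id, beta_id), x, atol=1e-4)
print(recovered)   # True
```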
# We call those parameters "gain" and "bias", "gain"
# is like the "weight", but it's not a matrix
# In fact, both have shape (1, n_hidden), i.e. a single row
# At initialization, we expect to keep things unchanged
# So everything is multiplied by 1 and added with zero
bngain = torch.ones((1,n_hidden))
bnbias = torch.zeros((1,n_hidden))
Also, remember to include these things in the parameters, as those are trainable:
# Include in the parameters
parameters = [C, W1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True
Now let's modify our model a bit: Before applying the tanh() function, we normalize the hidden states:
# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []
for i in range(max_steps):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
    # forward pass
    emb = C[Xb] # embed the characters into vectors
    embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
    # Linear layer
    hpreact = embcat @ W1 + b1 # hidden layer pre-activation
    # Batch Norm
    hpreact = bngain * (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True) + bnbias
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    logits = h @ W2 + b2 # output layer
    loss = F.cross_entropy(logits, Yb) # loss function
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
    for p in parameters:
        p.data += -lr * p.grad
    # track stats
    if i % 10000 == 0: # print every once in a while
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())
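Notice the BatchNorm layer sitting between the Linear layer and the non-linearity.
BatchNorm must also be applied at evaluation time (on the validation and test sets). Otherwise, the output layer would receive hidden states on a completely different scale from the ones it was trained on: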
@torch.no_grad()
def split_loss(split):
    x,y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    emb = C[x] # (N, block_size, n_embd)
    embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
    hpreact = embcat @ W1 + b1
    # Adding the BatchNorm layer here
    hpreact = bngain * (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True) + bnbias
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())
split_loss('train')
split_loss('val')
In fact, when we're dealing with many hidden layers, we typically sprinkle BatchNorm throughout the whole net. It goes right after each linear (or convolutional) layer, before its non-linearity.
So far we've been praising this technique a lot. But does it live up to the hype?
Some issues with BatchNorm
First just look at the name of the technique: BatchNorm, and yes, we normalize our data, but we do that in batches. What is the problem with that?
Our batches are sampled at random, and we pass our data through the network in groups. In our previous model, this was harmless. The group an example belonged to didn't affect how it was processed; each example was treated independently. So even though we were using batches, each example's output depended only on that example. But look at what happens when we apply BatchNorm. We normalize the data using the statistics of its batch, so each example's hidden states now depend on the other examples in the same batch.
That is rather ugly, as it's very unnatural for examples to be coupled like that, especially since the groups are random. But here's the big moment: it turns out that, in practice, this coupling actually acts as a regularizer.
Because the batches are random, the hidden states of each example jitter a little depending on which other examples happen to be in its batch, which acts as a kind of noise. Hence, the model becomes more robust to small shifts and is pushed to learn features rather than memorize the training examples, and that is regularization. (Honestly, I'm not entirely clear about this myself, so you can search on Google or ask ChatGPT for a better explanation.)
Over time, numerous normalization techniques (like LayerNorm) have been developed to replace BatchNorm, mainly because people don't like this "grouping bug". But let's just stick with this amazing technique in this blog post.
Problems with Inference/Evaluation
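Here is a tiny sketch of that coupling. The very same example (random data, hypothetical sizes) is normalized inside two different random batches, and it comes out different each time:

```python
import torch

def bn(x, eps=1e-5):
    # normalize each feature over the batch dimension (gain = 1, bias = 0)
    mean = x.mean(0, keepdim=True)
    var = x.var(0, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
example = torch.randn(1, 5)                          # one hypothetical example, 5 features
batch_a = torch.cat([example, torch.randn(31, 5)])   # the example plus 31 random batch-mates
batch_b = torch.cat([example, torch.randn(31, 5)])   # same example, different batch-mates

out_a = bn(batch_a)[0]
out_b = bn(batch_b)[0]
print(out_a)
print(out_b)   # differs from out_a even though the input example is identical
```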
This is a hard-to-grasp concept, and it took me hours to understand. I will do my best to give you a good sense of what this is.
Remember that in evaluation mode we used hpreact.mean and hpreact.std? But what should our model be doing when making a prediction? It should make a prediction based on that input alone, independently, right? Yet our model once again does something rather unnatural: it takes the rest of the batch into account when predicting a single instance.
So, to preserve the independence of the examples during evaluation, we should normalize with a population mean and std instead. One way to implement this is to wait until training has finished and pass the whole training set through the network once more. This time, we compute the mean and std over the entire population.
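The problem gets even worse with a batch of just one example. In this sketch, random data stands in for hpreact. The unbiased std that the code above uses is undefined for a single sample, and PyTorch returns nan. Even with a biased variance plus epsilon, subtracting the batch mean leaves all zeros, so every input would be mapped to the same output:

```python
import torch

torch.manual_seed(0)
hpreact = torch.randn(1, 200)   # hypothetical pre-activations for a single example

# Unbiased std (what hpreact.std(0) computes) is undefined for one sample: nan
std_unbiased = hpreact.std(0, keepdim=True)

# A biased variance with epsilon doesn't help either: x - mean is exactly zero
mean = hpreact.mean(0, keepdim=True)
var_biased = hpreact.var(0, keepdim=True, unbiased=False)
normalized = (hpreact - mean) / torch.sqrt(var_biased + 1e-5)

print(torch.isnan(std_unbiased).all().item())   # True
print(normalized.abs().max().item())            # 0.0
```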
# calibrate the batch norm at the end of training
# Use torch.no_grad() because we won't backpropagate through this; it saves a lot of memory
with torch.no_grad():
    # pass the training set through
    emb = C[Xtr]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1 # must match the training forward pass (drop "+ b1" once b1 is removed)
    # measure the mean/std over the entire training set
    bnmean = hpreact.mean(0, keepdim=True)
    bnstd = hpreact.std(0, keepdim=True)
And then in the evaluation mode, we replace the mean and std of the batch with the population mean and std:
hpreact = bngain * (hpreact - bnmean) / bnstd + bnbias
That would solve the problem!
But researchers don't like that. Running the whole dataset through the network again just to find the mean and std is rather laborious. In practice, there is a nicer way to estimate those values during training itself.
The technique here is to keep an exponential moving average (EMA) of the mean and std throughout training. We update these running values every time a batch arrives. How much of the old value is kept at each update is controlled by the momentum hyperparameter, which is the beta in the formula below:
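One common way to write this update, with beta as the decay factor (close to 1) and the statistics of the batch B_t seen at step t, is:

```latex
\hat{\mu}_t = \beta\,\hat{\mu}_{t-1} + (1 - \beta)\,\mu_{B_t},
\qquad
\hat{\sigma}_t = \beta\,\hat{\sigma}_{t-1} + (1 - \beta)\,\sigma_{B_t}
```

A larger beta gives a smoother but slower-moving estimate. Heads-up: PyTorch's nn.BatchNorm1d uses the opposite convention, running = (1 - momentum) * running + momentum * batch, with a default momentum of 0.1. So its momentum corresponds to 1 - beta here. PyTorch also tracks a running variance rather than a running std.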
Implementing the Running Mean and Standard Deviation
Now let's turn back to our code for the (nearly) last optimization. When training our model, we have two jobs to do:
- We need to calculate the mean and std of each batch and normalize the values with them, just like we did when implementing BatchNorm; this is for training.
- We also have to update two values during this process, which are later used for testing (or inference): bnmean_running and bnstd_running, calculated by the exponential moving average. Remember that these are not parameters, so update them inside torch.no_grad().
With everything prepared, let's jump right into the code. First, we initialize the running mean and std.
# Initialize the running mean and std
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))
Now let's change the BatchNorm layer a bit:
# BatchNorm layer
# -------------------------------------------------------------
# This is for normalization
bnmeani = hpreact.mean(0, keepdim=True)
bnstdi = hpreact.std(0, keepdim=True)
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
# This is later used for testing, notice the torch.no_grad()
# We use a decay factor (beta) of 0.999; in PyTorch's BatchNorm convention this corresponds to momentum=0.001
with torch.no_grad():
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
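To get a feel for how well this EMA tracks the true statistics, here is a toy simulation using the same 0.999 / 0.001 update. It models one hypothetical feature whose batches come from a Gaussian with mean 2 and std 3:

```python
import torch

torch.manual_seed(0)
true_mean, true_std = 2.0, 3.0
running_mean = torch.zeros(1)   # same init as bnmean_running
running_std = torch.ones(1)     # same init as bnstd_running

for step in range(10000):
    batch = torch.randn(32) * true_std + true_mean   # hypothetical batch for one feature
    running_mean = 0.999 * running_mean + 0.001 * batch.mean()
    running_std = 0.999 * running_std + 0.001 * batch.std()

print(running_mean.item(), running_std.item())   # close to 2.0 and 3.0
```

Averaging per-batch stds gives a slightly low estimate of the population std, because the sample std is biased downward. For batches of 32, the effect is small.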
Now let's use the running mean and std in our evaluation mode:
@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
    ...
    hpreact = bngain * (hpreact - bnmean_running) / bnstd_running + bnbias
    h = torch.tanh(hpreact) # (N, n_hidden)
    logits = h @ W2 + b2 # (N, vocab_size)
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())
And that's basically it! Let's compare our running estimate of the mean with the exact bnmean we calculated earlier:
>>>bnmean_running
tensor([[-2.3338, 0.6988, -0.9011, 0.9966, 1.0906, 1.0759, 1.7426, -2.1253,...
>>>bnmean
tensor([[-2.3145, 0.6885, -0.9134, 0.9972, 1.0878, 1.0841, 1.7470, -2.1102,...
The approximation is really close! And let's see our final training and validation loss:
train 2.0666308403015137
val 2.1051523685455322
Great!
Removing the Bias
One small thing to note here, maybe the last one: we can actually remove the bias of the linear layer if we're going to apply BatchNorm right after it. This is simply because normalization doesn't care about our data being offset by some constant; we subtract the mean anyway, and that cancels the bias out. (Think of the bias as something that moves the distribution left and right. It is meaningless here, because BatchNorm moves the distribution back to center around zero, and bnbias takes over the job of shifting it.)
So let's just comment out the initialization of b1 (and drop the + b1 from the forward passes, since b1 no longer exists). This saves a few parameters and a little computation:
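Here is a minimal sketch showing both effects with random, hypothetical data. Adding a bias before BatchNorm doesn't change the normalized output, and the bias receives (numerically) no gradient:

```python
import torch

def bn(h, eps=1e-5):
    # normalize over the batch dimension, without gain/bias
    mean = h.mean(0, keepdim=True)
    var = h.var(0, keepdim=True, unbiased=False)
    return (h - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
embcat = torch.randn(32, 30)               # hypothetical batch of concatenated embeddings
W1 = torch.randn(30, 200) * 0.2
b1 = torch.randn(200, requires_grad=True)  # a deliberately large bias

with_bias = bn(embcat @ W1 + b1)
without_bias = bn(embcat @ W1)
same = torch.allclose(with_bias, without_bias, atol=1e-4)
print(same)   # True: the bias is cancelled by the mean subtraction

# ...and it gets (numerically) no gradient, so training it is wasted effort
torch.tanh(with_bias).sum().backward()
print(b1.grad.abs().max().item())   # tiny, on the order of float rounding error
```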
#b1 = torch.randn(n_hidden, generator=g) * 0.01
If you scrutinize advanced models that use BatchNorm, like ResNet, you will notice that the convolutional layers followed by a BatchNorm layer are created with bias=False.
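For comparison, here is roughly what the same Linear, BatchNorm, tanh block looks like with PyTorch's built-in modules. The sizes are hypothetical. Note the bias=False on the Linear layer, and momentum=0.001, which matches our beta = 0.999 under PyTorch's convention. Calling .train() or .eval() switches between batch statistics and running statistics, which is the training/inference split we implemented by hand:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_hidden = 30, 200   # hypothetical sizes (n_embd * block_size, n_hidden)

layer = nn.Sequential(
    nn.Linear(n_in, n_hidden, bias=False),     # no bias: BatchNorm's shift (beta) takes over
    nn.BatchNorm1d(n_hidden, momentum=0.001),  # PyTorch momentum = 1 - beta, so beta = 0.999
    nn.Tanh(),
)

layer.train()                       # training mode: batch statistics, running stats get updated
_ = layer(torch.randn(32, n_in))

layer.eval()                        # eval mode: running_mean / running_var are used instead
with torch.no_grad():
    single = layer(torch.randn(1, n_in))   # a batch of one is fine in eval mode

bn = layer[1]
print(single.shape)                                  # torch.Size([1, 200])
print(bn.running_mean.shape, bn.running_var.shape)   # torch.Size([200]) torch.Size([200])
```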
Summary
When I first learned about BatchNorm from a book, I didn't even know why we needed to normalize stuff, and I got so confused at that time. But now, after actually implementing a model of my own (well, it's from Andrej, sorry if that bothers you), I realize that BatchNorm is a really significant milestone. It solves a lot of problems in neural nets, and those problems are really easy to notice once you look for them.
Today we fixed a lot of bugs in our model and brought the holy BatchNorm to light. We went from problems with initialization and vanishing gradients, to BatchNorm and its implementation, its pros and cons, the EMA, removing the bias, and more.
Well, another productive day of us! Thanks for reading and have a good day.









