Using Mini Batches During Training

How to see more data at once, and improve our neural networks even further.

At this point, we have trained a neural network to recognize handwritten digits with ~96.5% accuracy. This is great, but there is one more optimization we can make to train our neural networks faster and help them generalize better. This is called "mini batching".

Our code so far (which you can grab from here) learns from a single example at a time. In practice, when training neural networks it is common to group a set of random examples into what is called a "mini batch". This takes advantage of the parallel nature of matrix multiplication and of the neural net architecture.

To see how this works, let's recall what our input currently looks like. We take 28x28 grayscale images and flatten them into 1x784 vectors to feed into our neural network. Each 1x784 vector is passed through our first linear layer, which performs the matrix multiplication 1x784 * 784x300. As we learned earlier, a matrix multiplication is valid as long as the innermost dimensions match, and the result has the shape of the outer dimensions. This means we can stack multiple inputs into a single matrix and leave our first layer unchanged. Stacking 4 inputs gives a 4x784 matrix, which we call a batch of size 4. The output of the first layer is now 4x300 instead of 1x300. That is fine, because our next linear layer is 300x10, so 4x300 * 300x10 gives a 4x10 output. Regardless of the batch size, the math still lines up.
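To make the shape argument concrete, here is a minimal sketch using a small hypothetical `Matrix` struct and a naive `MatMul` (neither is the series' Tensor class). It only demonstrates that the weight matrices stay the same while the batch dimension flows through to the output.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical row-major matrix used only for this illustration.
struct Matrix {
    std::size_t rows, cols;
    std::vector<float> data;
    Matrix(std::size_t r, std::size_t c, float fill = 0.0f)
        : rows(r), cols(c), data(r * c, fill) {}
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// (m x k) * (k x n) -> (m x n). Only the inner dimension k has to match,
// so m (the batch size) can be anything.
Matrix MatMul(const Matrix& a, const Matrix& b) {
    if (a.cols != b.rows) throw std::invalid_argument("inner dimensions must match");
    Matrix out(a.rows, b.cols);
    for (std::size_t i = 0; i < a.rows; ++i)
        for (std::size_t k = 0; k < a.cols; ++k)
            for (std::size_t j = 0; j < b.cols; ++j)
                out.at(i, j) += a.at(i, k) * b.at(k, j);
    return out;
}
```

Multiplying a 4x784 batch by the same 784x300 weights yields 4x300, and multiplying that by 300x10 yields 4x10; a 1x784 input still works with the exact same weights.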

If we organize our inputs into these mini batches, our computation will be much faster, because the matrix multiplications on the input are the part of our network that gets parallelized, and a batch gives them many rows to process at once.

Let's see what this looks like in code. We will have to update our data loader to grab a set of random examples and put them into a batch. There is a reason we originally split our data loader into two methods, DataAt() and DataLength(): it lets us write a generic base class with the batching functionality. As long as a data loader implements these two methods, it gets batching for free.

Make a new base class called Dataloader in include/neural/data/dataloader.h
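The original header listing is not reproduced here, so what follows is a minimal sketch of what such a header could look like, assuming DataAt() returns an (input, label) pair of tensors. The `Tensor` stand-in, the `Example` alias and the `m_indices` name are hypothetical; m_shouldRandomize and m_currentIdx are the member names the text uses. In this sketch the default values live in member initializers rather than in a constructor.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for the series' Tensor class.
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<float> data;
};

// Assumption: one example is an (input, label) pair of {1 x N} tensors.
using Example = std::pair<Tensor, Tensor>;

// Sketch of include/neural/data/dataloader.h
class Dataloader {
public:
    virtual ~Dataloader() = default;

    // Left to the child class: fetch one example, and report how many exist.
    virtual Example DataAt(std::size_t idx) const = 0;
    virtual std::size_t DataLength() const = 0;

    // Implemented once in src/dataloader.cpp and shared by every subclass.
    std::size_t GetNumBatches(std::size_t batchSize) const;
    std::pair<Tensor, Tensor> GetNextBatch(std::size_t batchSize);

protected:
    bool m_shouldRandomize = true;       // false keeps a fixed order for unit tests
    std::vector<std::size_t> m_indices;  // hypothetical name: indices 0..DataLength()-1
    std::size_t m_currentIdx = 0;        // next position to read from m_indices
};
```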

As you can see, it has our two virtual functions, which it leaves to the child class to implement. It also has some member variables that we will get into when we implement the GetNextBatch() function. Go ahead and add the implementation file src/dataloader.cpp.

The GetNumBatches() function tells our main loop how many iterations to run for a given batch size. It simply uses the subclass's DataLength() function to work out how many batches there will be. Our constructor simply initializes some default values.
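As a sketch of the arithmetic, here is GetNumBatches() written as a hypothetical free function that takes the data length directly. It assumes integer division, so a trailing partial batch is dropped; rounding up instead would give a final, smaller batch.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical free-function version of Dataloader::GetNumBatches().
// Assumption: integer division, so any leftover examples that do not
// fill a whole batch are skipped for this epoch.
std::size_t GetNumBatches(std::size_t dataLength, std::size_t batchSize) {
    if (batchSize == 0) throw std::invalid_argument("batchSize must be > 0");
    return dataLength / batchSize;
}
```

With MNIST's 60000 training images and a batch size of 100, this gives 600 iterations per epoch.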

Next we will implement the GetNextBatch() function. You may have noticed the m_shouldRandomize boolean and the vector of indices in our member variables. These handle our randomization logic. For unit testing purposes, we do not always want to randomize the data. If we do want to randomize, we simply keep a shuffled vector of the indices 0 to DataLength() - 1.
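A minimal sketch of building that index vector with the standard library: `std::iota` fills 0..length-1 and `std::shuffle` permutes it only when randomization is on. The `MakeIndices` name is hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: the index order for one pass over the data.
// Every index appears exactly once, so no example is drawn twice per epoch.
std::vector<std::size_t> MakeIndices(std::size_t length, bool shouldRandomize,
                                     std::mt19937& rng) {
    std::vector<std::size_t> indices(length);
    std::iota(indices.begin(), indices.end(), std::size_t{0});  // 0, 1, ..., length-1
    if (shouldRandomize) std::shuffle(indices.begin(), indices.end(), rng);
    return indices;
}
```

Passing `shouldRandomize = false` keeps the order fixed, which is what makes unit tests deterministic.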

Now that we have this vector, we can keep a pointer into it to know which random data to grab next, without grabbing the same data twice. This is what our m_currentIdx variable is for. The rest of the implementation is as follows.

We first create a tensor of size batchSize x input size. Then we use the child class's DataAt() method to populate this tensor. There is a new helper function here in our tensor class for setting a row in a tensor, called SetRow(). This only works for tensors that have a valid matrix shape.
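Putting the pieces together, here is a minimal, self-contained sketch of GetNextBatch(). It is not the original implementation. It assumes DataAt() returns an (input, label) pair of {1 x N} tensors, uses a hypothetical row-major `Tensor` stand-in with a SetRow() method, and starts a fresh (re-shuffled) pass when fewer than batchSize indices remain. `ResetIndices`, `m_indices` and `m_rng` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for the series' Tensor: a row-major matrix.
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<float> data;
    Tensor(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
    void SetRow(std::size_t index, const Tensor& row) {
        if (row.rows != 1 || row.cols != cols || index >= rows)
            throw std::invalid_argument("bad row");
        std::copy(row.data.begin(), row.data.end(), data.begin() + index * cols);
    }
};

// Assumption: DataAt() returns ({1 x inputSize}, {1 x outputSize}).
using Example = std::pair<Tensor, Tensor>;

class Dataloader {
public:
    explicit Dataloader(bool shouldRandomize = true, unsigned seed = 42)
        : m_shouldRandomize(shouldRandomize), m_rng(seed) {}
    virtual ~Dataloader() = default;

    virtual Example DataAt(std::size_t idx) const = 0;
    virtual std::size_t DataLength() const = 0;

    // Returns a {batchSize x inputSize} input tensor and a matching label tensor.
    std::pair<Tensor, Tensor> GetNextBatch(std::size_t batchSize) {
        if (m_indices.size() != DataLength()) ResetIndices();
        if (batchSize == 0 || batchSize > m_indices.size())
            throw std::invalid_argument("invalid batch size");
        // Not enough unseen data left for a full batch: start a new pass.
        if (m_currentIdx + batchSize > m_indices.size()) ResetIndices();

        // Peek at the next example only to learn the input/label widths.
        const Example first = DataAt(m_indices[m_currentIdx]);
        Tensor inputs(batchSize, first.first.cols);
        Tensor labels(batchSize, first.second.cols);
        for (std::size_t row = 0; row < batchSize; ++row) {
            const Example ex = DataAt(m_indices[m_currentIdx++]);
            inputs.SetRow(row, ex.first);
            labels.SetRow(row, ex.second);
        }
        return {inputs, labels};
    }

private:
    void ResetIndices() {
        m_indices.resize(DataLength());
        std::iota(m_indices.begin(), m_indices.end(), std::size_t{0});
        if (m_shouldRandomize) std::shuffle(m_indices.begin(), m_indices.end(), m_rng);
        m_currentIdx = 0;
    }

    bool m_shouldRandomize;
    std::mt19937 m_rng;
    std::vector<std::size_t> m_indices;
    std::size_t m_currentIdx = 0;
};
```

Because m_currentIdx only moves forward through the index vector, each example is used at most once per pass; the leftover examples that do not fill a batch are skipped, matching the integer-division batch count.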

Let's add both SetRow() and GetRow() to our tensor class to make this happen.

SetRow() takes an index and a tensor as input, and sets the row at that index to the values in the tensor. The implementation first checks that we have a matrix, then makes sure the input tensor is valid. After this, it computes an offset by multiplying the index by the number of columns in the row tensor. It then skips to this offset and copies the data in.

GetRow() is pretty similar logic, using the offset to copy data into a new tensor that is the size of the row.
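Here is a minimal sketch of both helpers on a hypothetical tensor that stores a shape vector and a flat, row-major buffer. It follows the steps described above (matrix check, row validation, offset = index * columns, copy) but is not the series' actual Tensor class.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical minimal tensor: a shape plus a flat, row-major buffer.
class Tensor {
public:
    explicit Tensor(std::vector<std::size_t> shape)
        : m_shape(std::move(shape)), m_data(Size(m_shape), 0.0f) {}

    const std::vector<std::size_t>& Shape() const { return m_shape; }
    std::vector<float>& Data() { return m_data; }
    const std::vector<float>& Data() const { return m_data; }
    bool IsMatrix() const { return m_shape.size() == 2; }

    // Copy `row` (any tensor holding exactly `cols` values) into row `index`.
    void SetRow(std::size_t index, const Tensor& row) {
        if (!IsMatrix()) throw std::logic_error("SetRow requires a matrix");
        const std::size_t cols = m_shape[1];
        if (index >= m_shape[0]) throw std::out_of_range("row index out of range");
        if (row.m_data.size() != cols) throw std::invalid_argument("row has wrong number of columns");
        const std::size_t offset = index * cols;  // skip `index` full rows
        std::copy(row.m_data.begin(), row.m_data.end(), m_data.begin() + offset);
    }

    // Return row `index` as a new {1, cols} tensor.
    Tensor GetRow(std::size_t index) const {
        if (!IsMatrix()) throw std::logic_error("GetRow requires a matrix");
        const std::size_t cols = m_shape[1];
        if (index >= m_shape[0]) throw std::out_of_range("row index out of range");
        Tensor row({1, cols});
        const std::size_t offset = index * cols;
        std::copy(m_data.begin() + offset, m_data.begin() + offset + cols, row.m_data.begin());
        return row;
    }

private:
    static std::size_t Size(const std::vector<std::size_t>& shape) {
        std::size_t n = 1;
        for (std::size_t d : shape) n *= d;
        return n;
    }
    std::vector<std::size_t> m_shape;  // declared before m_data, so initialized first
    std::vector<float> m_data;
};
```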

We can now make our MNISTDataloader class a subclass of Dataloader to get the batching functionality.
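The MNIST loader from the earlier posts is not shown here, so this is a hypothetical sketch of the subclass shape only. It assumes the images are already decoded into memory (file parsing omitted), scales pixels to [0, 1], and one-hot encodes labels. All three are assumptions, not necessarily what the series does. The base class is reduced to the two methods a subclass must provide.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for the series' Tensor class.
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<float> data;
};
using Example = std::pair<Tensor, Tensor>;

// Reduced base class: only the two methods a subclass must provide.
class Dataloader {
public:
    virtual ~Dataloader() = default;
    virtual Example DataAt(std::size_t idx) const = 0;
    virtual std::size_t DataLength() const = 0;
};

// Hypothetical MNISTDataloader holding decoded 784-byte images and 0-9 labels.
class MNISTDataloader : public Dataloader {
public:
    MNISTDataloader(std::vector<std::vector<std::uint8_t>> images,
                    std::vector<std::uint8_t> labels)
        : m_images(std::move(images)), m_labels(std::move(labels)) {
        if (m_images.size() != m_labels.size())
            throw std::invalid_argument("images and labels differ in length");
    }

    // `override` makes the compiler reject a signature that does not match the base.
    Example DataAt(std::size_t idx) const override {
        const auto& pixels = m_images.at(idx);
        Tensor input{1, pixels.size(), {}};
        for (std::uint8_t p : pixels) input.data.push_back(p / 255.0f);  // scale to [0, 1]
        Tensor label{1, 10, std::vector<float>(10, 0.0f)};
        label.data.at(m_labels.at(idx)) = 1.0f;  // one-hot; throws if label > 9
        return {input, label};
    }

    std::size_t DataLength() const override { return m_images.size(); }

private:
    std::vector<std::vector<std::uint8_t>> m_images;
    std::vector<std::uint8_t> m_labels;
};
```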

Notice we mark the virtual methods with "override". This ensures at compile time that we have overridden the correct functions. We can now update our training loop to use mini batches instead of a single data point at a time.

At the top of our loop, let's define a variable batchSize to specify how big we want our mini batches to be.

We will use a batch size of 100, which means our net now learns from 100 examples at a time! This means we will only have 60000/100 = 600 iterations per epoch. Let's update our logging and accuracy calculation to only occur once every 100 iterations, and once every epoch.
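A skeleton of the epoch loop with this logging cadence might look like the sketch below. `RunEpoch` and the `trainStep` callback (standing in for "get the next batch, run forward and backward, update the weights") are hypothetical; the function returns how many times it logged so the cadence is easy to check.

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>

// Hypothetical epoch loop: one iteration per mini batch.
std::size_t RunEpoch(std::size_t dataLength, std::size_t batchSize,
                     std::size_t logEvery,
                     const std::function<float(std::size_t)>& trainStep) {
    const std::size_t numBatches = dataLength / batchSize;  // 60000 / 100 = 600
    std::size_t logs = 0;
    for (std::size_t i = 0; i < numBatches; ++i) {
        const float loss = trainStep(i);
        if (i % logEvery == 0) {  // log every `logEvery` iterations, not every step
            std::printf("batch %zu/%zu loss %f\n", i, numBatches, loss);
            ++logs;
        }
    }
    // The per-epoch accuracy calculation would run here, after the loop.
    return logs;
}
```

With 60000 examples, a batch size of 100 and `logEvery = 100`, the loop runs 600 iterations and logs 6 times per epoch.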

We should also update our CalcAccuracy code to use the batch size.
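One way to make the accuracy calculation batch-aware is to count correct predictions row by row. The sketch below is a hypothetical helper, not the series' CalcAccuracy. It assumes the network output and the labels are row-major {batchSize x numClasses} buffers with one-hot labels, and takes the argmax of each row as the predicted class. Accuracy would then be the total count divided by numBatches * batchSize.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical batched helper: how many rows have argmax(output) == argmax(label)?
std::size_t CountCorrect(const std::vector<float>& outputs,
                         const std::vector<float>& labels,
                         std::size_t batchSize, std::size_t numClasses) {
    if (outputs.size() != batchSize * numClasses || labels.size() != outputs.size())
        throw std::invalid_argument("shape mismatch");
    std::size_t correct = 0;
    for (std::size_t row = 0; row < batchSize; ++row) {
        const std::size_t base = row * numClasses;  // start of this example's row
        std::size_t predicted = 0, actual = 0;
        for (std::size_t c = 1; c < numClasses; ++c) {
            if (outputs[base + c] > outputs[base + predicted]) predicted = c;
            if (labels[base + c] > labels[base + actual]) actual = c;
        }
        if (predicted == actual) ++correct;
    }
    return correct;
}
```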

It should now take only about 25 seconds to process an epoch, whereas before it took about 2 minutes. If you compare the logs of our earlier code without mini batches to the code with them, you will notice that, as a function of examples seen, the accuracy is not as high. In other words, our accuracy after the first epoch is much lower with the mini batching code than with our one-example-per-batch code.

This is expected: we are seeing 100 examples at a time and essentially averaging their gradients, so the weights get one update per 100 examples instead of 100 separate updates. After ~20 epochs we will be back over 70% accuracy on the test set, and we will eventually reach the same accuracy as when seeing one example at a time. You can download the full code from this example here.
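To illustrate what "averaging their gradients" means, here is a toy sketch on a one-weight model y = w * x with per-example loss 0.5 * (w*x - t)^2, not on the network itself. The gradient of the mean loss over a batch equals the average of the per-example gradients (w*x - t) * x, and each batch produces exactly one update. `BatchGradient` and `SgdStep` are hypothetical names.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Gradient of the mean loss over the batch = average of per-example gradients.
double BatchGradient(double w, const std::vector<double>& xs, const std::vector<double>& ts) {
    if (xs.empty() || xs.size() != ts.size()) throw std::invalid_argument("bad batch");
    double sum = 0.0;
    for (std::size_t i = 0; i < xs.size(); ++i) sum += (w * xs[i] - ts[i]) * xs[i];
    return sum / static_cast<double>(xs.size());
}

// One gradient descent update per mini batch instead of one per example.
double SgdStep(double w, double learningRate,
               const std::vector<double>& xs, const std::vector<double>& ts) {
    return w - learningRate * BatchGradient(w, xs, ts);
}
```

For example, with w = 0 and the batch x = {1, 2}, t = {2, 4}, the per-example gradients are -2 and -8, so the batch gradient is their average, -5.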

Batching is also valuable when deploying your code in the real world, because it lets you process much more data (in this case, images) per second. You will almost never see neural networks processing single examples at a time, so mini batching in your data loaders and models will soon become second nature. In the next post we will go over a few more metrics, such as precision and recall, to get a better sense of how well our networks perform on classification tasks.