Improving Our Neural Network Outputs With Softmax Classification
How to predict probabilities from a linear layer, to achieve higher accuracy on the MNIST dataset.
In the last post we wrote a function to calculate accuracy on the MNIST test data set in order to track how well our neural network was learning. We were able to achieve ~90% accuracy on the test set with our last configuration, but there was something that seemed off in the output. The prediction values had an infinite range, and didn't really have a meaning behind them.
It would be much nicer if we could output a probability distribution that summed to one, giving us the probability that the input was a 1 vs. a 7. This would match our target output, since technically a one hot encoding is also a probability distribution. This is usually done with a function called the "softmax" function.
This function takes an n-dimensional vector and rescales its values so that each one lies in the range (0,1) and together they sum to 1. It is called the "softmax" function because it highlights the larger values in the input and suppresses the smaller ones. The fact that all the values sum to 1 makes it a good function to model a probability distribution.
In this example it has taken the 10 outputs from our network, which have an unconstrained range (they can be positive or negative, and can be greater than 1), and squashed them between 0 and 1. This output is now saying there is a 23% chance the input is a 1 and a 17% chance it is a 7, with a lower probability for all the others.
Having a function that constrains the output to the same range as our target output should help our neural network become more accurate. Let's implement a softmax layer and see if this is indeed the case. Add the header in our layers directory.
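For reference, the standard softmax definition for an input vector x with n entries can be written as below. The symbol names (x, n, i, j) are an assumed notation and may differ from the original figure.

```latex
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}, \qquad i = 1, \dots, n
```

Every term is positive because e^x is always positive, and dividing by the sum of all terms is what forces the outputs to add up to 1.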
We will implement the same interface as all of our layers so far, with our Forward() and Backward() methods.
To get a better sense of how the softmax layer should behave, let's write a unit test. Add the file tests/softmax_layer_test.cpp
For the inputs [1,2,3] we should get the probability distribution [0.0900, 0.2447, 0.6652]. You can see the math in the comments.
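The arithmetic behind that expected output can be worked through by hand. This is a sketch rounded to four decimal places, so the last digit may differ slightly from what the test compares against.

```latex
\begin{aligned}
e^{1} &\approx 2.7183, \quad e^{2} \approx 7.3891, \quad e^{3} \approx 20.0855 \\
\textstyle\sum_j e^{x_j} &\approx 2.7183 + 7.3891 + 20.0855 = 30.1929 \\
\mathrm{softmax}(x) &\approx \left[ \tfrac{2.7183}{30.1929},\ \tfrac{7.3891}{30.1929},\ \tfrac{20.0855}{30.1929} \right] \approx [0.0900,\ 0.2447,\ 0.6652]
\end{aligned}
```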
Now let's write the implementation and try to make this unit test pass.
Following the equation from above, we run each input through the exponential function and sum the results, then divide each exponentiated input (e^x) by that sum. This should make our first unit test pass.
There are some numerical stability issues with the softmax as implemented above, as explained here. The sum of the exponentials in the denominator can become very large.
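A minimal sketch of those two steps, assuming a plain std::vector<double> in place of the tensor class used in this series. The function name NaiveSoftmax is hypothetical and is not the layer's actual Forward() method.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive softmax over a plain vector (a stand-in for the post's tensor type).
std::vector<double> NaiveSoftmax(const std::vector<double>& inputs) {
  assert(!inputs.empty());
  std::vector<double> outputs(inputs.size());
  double sum = 0.0;
  // Step 1: exponentiate every input and accumulate the sum.
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    outputs[i] = std::exp(inputs[i]);
    sum += outputs[i];
  }
  // Step 2: divide each exponentiated input by that sum.
  for (double& value : outputs) {
    value /= sum;
  }
  return outputs;
}
```

Calling NaiveSoftmax({1.0, 2.0, 3.0}) should give values close to the ones in the unit test above.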
Let's see this in action with another unit test.
Taking the exponential of a number larger than about 709 overflows a 64-bit double (for a 32-bit float the limit is about 88), so the result is infinity, and dividing infinity by infinity then gives NaN.
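To make the overflow concrete, here is a small sketch assuming IEEE 754 floating point, which is what most platforms use. The helper names MaxFiniteExpArg and NaiveSoftmaxTerm are made up for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Largest argument for which std::exp stays finite in floating type T:
// roughly 709.78 for double and 88.72 for float.
template <typename T>
T MaxFiniteExpArg() {
  return std::log(std::numeric_limits<T>::max());
}

// First term of a naive two-input softmax: e^x / (e^x + e^y).
// For large x and y both exponentials overflow to infinity, and
// infinity divided by infinity is NaN.
double NaiveSoftmaxTerm(double x, double y) {
  assert(!std::isnan(x) && !std::isnan(y));
  return std::exp(x) / (std::exp(x) + std::exp(y));
}
```

With inputs such as 1000 and 999, NaiveSoftmaxTerm returns NaN even though the true answer (about 0.73) is perfectly representable.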
Just look at the graph for e^x
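Even at x=23, y is already about 9.7*10^9, and it just gets steeper from there. To fix this numerical instability, we simply subtract the row's max value from each number in the row. This way the max value becomes zero and every other number is negative or zero, so no exponential can exceed 1. The result is unchanged, because subtracting the max scales the numerator and the denominator by the same factor.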
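A sketch of the max-subtraction trick, again assuming a plain std::vector<double> rather than the series' tensor class. StableSoftmax is a hypothetical name, and std::max_element stands in for the MaxVal() helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax with the row maximum subtracted before exponentiating.
std::vector<double> StableSoftmax(const std::vector<double>& inputs) {
  assert(!inputs.empty());
  const double max_val = *std::max_element(inputs.begin(), inputs.end());
  std::vector<double> outputs(inputs.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    // The shifted input is <= 0, so the exponential never exceeds 1.
    outputs[i] = std::exp(inputs[i] - max_val);
    sum += outputs[i];
  }
  // The max element contributes exp(0) = 1, so sum >= 1 and the
  // division is always safe.
  for (double& value : outputs) {
    value /= sum;
  }
  return outputs;
}
```

Since [1000, 1001, 1002] and [1, 2, 3] both shift to [-2, -1, 0], they produce the same distribution, and neither overflows.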
We will need another function in our tensor class called MaxVal(), but this is pretty simple since we already have MaxIdx().
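The idea can be sketched as follows, assuming free functions over a std::vector<double>. The real MaxIdx() and MaxVal() are methods on the tensor class, whose exact signatures are not shown here, so treat these as stand-ins.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Index of the largest element (the first one wins on ties).
std::size_t MaxIdx(const std::vector<double>& values) {
  assert(!values.empty());
  std::size_t best = 0;
  for (std::size_t i = 1; i < values.size(); ++i) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}

// The largest value is just the element at the index MaxIdx() returns.
double MaxVal(const std::vector<double>& values) {
  return values[MaxIdx(values)];
}
```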
This should finish up our forward pass! Feel free to run the unit tests to make sure everything still works.
Next up we will have to take the derivative of the softmax function. I'm not going to lie, this is probably the most involved derivative we have done yet. For the sake of brevity, I am going to link you a couple blogs here and here that have all the math for the derivative and just give you the code for it here. Eventually we will update our code to do automatic differentiation so we will no longer need to take derivatives, but that's a task for another day.
Note that we are re-running the Forward() computation in the backward pass. This is because the derivative requires it, and all of our interfaces so far have simply passed in the original inputs to the Backward() method. Eventually we will standardize on a new interface for Backward() that is a little more flexible and allows for caching of any variables we want, but for now, the inefficient method will work.
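Here is a minimal sketch of that backward pass, assuming free functions on std::vector<double> instead of the layer class. It uses the standard softmax Jacobian, ds_i/dx_j = s_i * (delta_ij - s_j), which collapses to grad_input_j = s_j * (grad_output_j - sum_i grad_output_i * s_i). The names Softmax, SoftmaxBackward, and grad_output are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable forward pass, needed again by the backward pass.
std::vector<double> Softmax(const std::vector<double>& inputs) {
  assert(!inputs.empty());
  const double max_val = *std::max_element(inputs.begin(), inputs.end());
  std::vector<double> outputs(inputs.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    outputs[i] = std::exp(inputs[i] - max_val);
    sum += outputs[i];
  }
  for (double& value : outputs) {
    value /= sum;
  }
  return outputs;
}

// Gradient of the loss with respect to the softmax inputs, given the
// gradient of the loss with respect to the softmax outputs.
std::vector<double> SoftmaxBackward(const std::vector<double>& inputs,
                                    const std::vector<double>& grad_output) {
  assert(inputs.size() == grad_output.size());
  // Re-run the forward pass because only the original inputs are passed in.
  const std::vector<double> s = Softmax(inputs);
  double dot = 0.0;
  for (std::size_t i = 0; i < s.size(); ++i) {
    dot += grad_output[i] * s[i];
  }
  std::vector<double> grad_input(s.size());
  for (std::size_t j = 0; j < s.size(); ++j) {
    grad_input[j] = s[j] * (grad_output[j] - dot);
  }
  return grad_input;
}
```

One quick sanity check is that the entries of the returned gradient sum to zero, since shifting every input by the same amount leaves the softmax output unchanged.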
This is all we need to add our softmax layer to the end of our network and get probabilities instead of arbitrary linear values. Let's head on over to our training loop to make it happen. First add our header.
Then define our new layer in the main function.
Then update the training loop to use our new layer after the second linear layer.
Finally we will pass our new layer into our CalcAccuracy function as well.
With everything in place, it is time to see our net in action! You should see us doing better than random by the 10,000th example. My training accuracy was around 21% and test accuracy around 17%.
Look at the outputs from our network now! They are all between 0 and 1 and sum to 1. If you look at this particular example, it is deciding between 3 and 7, and ended up with the highest probability of 22% for 7, but 3 was right behind with the second highest probability at 13%. This is the exact behavior we were expecting. Now the network can start nudging these probabilities in the correct direction, and we will see the accuracy rising.
By the end of the first epoch we have a training accuracy of 90.2% and a test accuracy of 85.7%, and it seems to be getting more confident in its answers, in this case predicting with a 99% probability that the digit is a 6.
Note that it is normal to see a higher training accuracy than test accuracy, since the test examples are ones we have never seen before. If the gap between train and test accuracy is too big, this means we are "overfitting" to our training set, and this is not good. We will talk about this, and how to fix it in a later blog post.
By the end of 10 epochs the accuracy is above 95% on the training set and above 91% on the test set. This is better than we did without the softmax, if you remember back to the last post (last time we were at train = 93% and test = 89%).
Now that we have the softmax function outputting proper probabilities, we can switch to an even better loss function for classification tasks. This is called the cross-entropy loss.
Our current loss function, mean squared error, is really meant more for regression tasks, where you are trying to minimize the difference between a predicted value and the target value. In our case we just happen to be using 1 and 0 as the target values. We really have a classification task on our hands, where we are trying to decide between 10 classes.
The cross entropy loss is defined as:
Where p(x) is the true probability distribution (ie. our one-hot encoded target) and q(x) is the estimated probability distribution.
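For reference, the standard form of the cross-entropy between p and q, in the notation just described, is shown below. The index c for the true class is an added symbol, not part of the original definition.

```latex
H(p, q) = -\sum_{x} p(x) \, \log q(x)
\qquad\Longrightarrow\qquad
H(p, q) = -\log q(c) \quad \text{when } p \text{ is one-hot with true class } c
```

In other words, with a one-hot target only the predicted probability of the correct class contributes to the loss.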
When we write it out in code we will give some examples to show why this makes sense. Add our new header for the loss, and use the same interface we did for mean squared error.
Then implement the function above for all the inputs and targets in our Forward() pass.
In the comments you can see how this error function would work in the most extreme cases where y=0 and yhat=1, etc. Work through this math and make sure it makes sense to you. Feel free to plot out the log(x) function on google to see how it behaves at different values.
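A minimal sketch of such a forward pass for a single example, assuming plain std::vector<double> inputs and the natural log. The function name CrossEntropy is hypothetical, and the real loss class may average over a batch, which this sketch does not do.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cross-entropy between a predicted distribution and a target
// distribution (for example a one-hot vector) for one example.
double CrossEntropy(const std::vector<double>& predictions,
                    const std::vector<double>& targets) {
  assert(predictions.size() == targets.size());
  double loss = 0.0;
  for (std::size_t i = 0; i < targets.size(); ++i) {
    // A zero target contributes nothing; skipping it also avoids
    // evaluating 0 * log(0), which would be NaN.
    if (targets[i] == 0.0) {
      continue;
    }
    loss -= targets[i] * std::log(predictions[i]);
  }
  return loss;
}
```

With a target of 1, a prediction of 1 gives a loss of 0, a prediction of 0.9 gives about 0.105, a prediction of 0.01 gives about 4.6, and a prediction of exactly 0 gives infinity. The loss grows without bound as the network becomes confidently wrong.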
Next we will need to add the backward pass, calculating the partial derivatives for the equation above. The derivative of the natural log of x is simply 1/x, leaving us with the following implementation.
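As a sketch of that derivative for a single example: since the loss is -sum_i p_i * log(q_i) and d/dq log(q) = 1/q, each partial derivative is -p_i / q_i. The function name CrossEntropyBackward and the std::vector<double> types are assumptions for illustration, not the actual loss class.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradient of the cross-entropy loss with respect to each prediction.
std::vector<double> CrossEntropyBackward(const std::vector<double>& predictions,
                                         const std::vector<double>& targets) {
  assert(predictions.size() == targets.size());
  std::vector<double> grad(predictions.size(), 0.0);
  for (std::size_t i = 0; i < targets.size(); ++i) {
    // Entries with a zero target do not appear in the loss, so their
    // gradient stays 0.
    if (targets[i] != 0.0) {
      grad[i] = -targets[i] / predictions[i];
    }
  }
  return grad;
}
```

Note that the gradient gets larger in magnitude as the predicted probability of the true class gets smaller, which is what pushes hard on confidently wrong predictions.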
Since the MeanSquaredErrorLoss class and the CrossEntropyLoss classes have the same interface, we can simply swap them out in our training loop.
Once you do, you will see we are now achieving an A+ on the training set and a solid A on the test set!
We are now reaching an accuracy that is close to what researchers have been seeing on the MNIST website. We are not quite state of the art, but we would be somewhere on the leaderboard. If we wanted, we could try different hidden layer sizes and different learning rates, and get even higher accuracy.
This is a lot of good work! We bumped up our accuracy ~7 percent with some minor changes to our training regimen. It always gets harder and harder to gain a percent in accuracy the closer you get to 90% and above, so you should feel great about the progress. If you want the source code so far, there is a git repo for this blog post here.