Writing The Code For Precision And Recall
Diving into how well we are performing on MNIST
After reading the last section, it should be clear why accuracy alone does not give us the full picture of how well our model is doing. In this section, we will implement the code for a generic "metrics" interface, then implement that interface with classes that can calculate precision, recall, and accuracy given our outputs and targets. Feel free to clone a clean version of the last repository here as a starting point.
By the end of this section, you will see that we can achieve roughly 90% precision at roughly 90% recall if our confidence threshold is set to 0.5. You can skip this section and just use the completed code if you are comfortable with the concepts covered in the last section, but I always recommend writing the code to let it all sink in :)
Let's start by making a new directory in our includes, and adding the interface file metrics.h.
Then let's add the methods we want our metrics to be able to handle. GetName() is simply for debugging and logging the results.
AddResults() is for storing the outputs and the targets over time. For example, we will want to calculate the precision over the entire test set between epochs of training. In order to do this, we will call AddResults(outputs, targets) each time we run a forward pass, and calculate the totals at the end.
Calculate() will simply take whatever results we added, and calculate the metrics over the data, given a confidence threshold (which defaults to 0.5).
In order to store the results, we will want a data structure that can easily flush out early data and keep the latest. The deque container from the standard library is perfect for this. Add two protected data members that our subclasses will have access to.
Next, let's add the implementation for our only non-pure-virtual function, AddResults(), which will simply append the results to the end of each deque.
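Putting the pieces so far together, the base class might look something like this. This is a sketch, not the repo's exact code: the class name `Metric`, the `m_` member prefix, and the row type are my guesses, and the `ExampleCounter` subclass at the bottom exists only to demonstrate that `AddResults()` accumulates as expected.

```cpp
#include <deque>
#include <string>
#include <vector>

// One row of outputs (confidences) or targets (one-hot), standing in
// for the repo's matrix type.
using Row = std::vector<double>;

// Hypothetical sketch of the base class in metrics.h.
class Metric {
 public:
  virtual ~Metric() = default;
  virtual std::string GetName() const = 0;
  virtual double Calculate(double threshold = 0.5) const = 0;

  // The only non-pure-virtual method: append the latest batch of
  // outputs/targets to the back of the deques.
  virtual void AddResults(const std::vector<Row>& outputs,
                          const std::vector<Row>& targets) {
    for (const auto& row : outputs) m_outputs.push_back(row);
    for (const auto& row : targets) m_targets.push_back(row);
  }

 protected:
  // deque lets us cheaply pop old results from the front if we ever
  // only want to keep the most recent window of data.
  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};

// Demonstration-only subclass (not part of the real repo): Calculate()
// just reports how many examples have been accumulated so far.
class ExampleCounter : public Metric {
 public:
  std::string GetName() const override { return "ExampleCounter"; }
  double Calculate(double /*threshold*/ = 0.5) const override {
    return static_cast<double>(m_outputs.size());
  }
};
```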
This is a good starting point for our interface and base class functionality. Now let's add in our Precision and Recall subclasses.
Let's start with precision:
Now we will simply have to implement the two methods GetName() and Calculate().
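A sketch of what precision.h might look like at this stage, with both methods stubbed out. The base class is repeated here (in compressed form) so the snippet compiles on its own; in the repo it would come from metrics.h.

```cpp
#include <deque>
#include <string>
#include <vector>

using Row = std::vector<double>;

// Compressed copy of the metrics.h base class so this stands alone.
class Metric {
 public:
  virtual ~Metric() = default;
  virtual std::string GetName() const = 0;
  virtual double Calculate(double threshold = 0.5) const = 0;
  virtual void AddResults(const std::vector<Row>& outputs,
                          const std::vector<Row>& targets) {
    for (const auto& row : outputs) m_outputs.push_back(row);
    for (const auto& row : targets) m_targets.push_back(row);
  }
 protected:
  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};

// precision.h sketch: just the stubs for now. Calculate() will be
// filled in once the unit tests are written.
class Precision : public Metric {
 public:
  std::string GetName() const override { return "Precision"; }
  double Calculate(double threshold = 0.5) const override {
    (void)threshold;  // stub: the real logic comes later
    return 0.0;
  }
};
```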
I have just filled in the stubs here, because we should add some unit tests before we implement the rest of the functionality. This will help us understand whether our implementation is correct, as well as whether our interface is reasonable.
Here we are pretending that we have four data points belonging to three different classes. Each row of l_outputs contains a confidence distribution for an example. For example, the first row indicates a confidence of 0.75 that the correct answer is the first class (0th column). Each row of l_targets is a one-hot encoding of where the correct answer lies.
In this case, if we took the index of the maximum value in each row l_outputs.at(row_idx) and checked whether l_targets.at(row_idx, max_idx) holds a 1, we would see that we got two answers correct, but only one of them surpasses the threshold of 0.7. The other two answers are incorrect, and both of those also fall below the threshold. This leaves us with 1 true positive, 1 false negative (a correct answer that fell below the threshold), and 2 true negatives. Since precision is defined as P = TP / (TP + FP), we get a precision of 1 / (1 + 0) = 1.0. Walk through the comments for more examples to really hammer it home.
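The walkthrough above can be checked with a small standalone computation. The data matches the description (four examples, three classes, first row 0.75-confident and correct); the bookkeeping follows the convention used in this section: the prediction is the argmax of the row, and it only counts as a "positive" when the max confidence is at or above the threshold. The repo's actual unit tests may be structured differently.

```cpp
#include <cstddef>
#include <vector>

using Row = std::vector<double>;

std::size_t ArgMax(const Row& row) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < row.size(); ++i)
    if (row[i] > row[best]) best = i;
  return best;
}

// Precision = tp / (tp + fp) over the given outputs/targets.
double PrecisionAt(const std::vector<Row>& outputs,
                   const std::vector<Row>& targets, double threshold) {
  std::size_t tp = 0, fp = 0;
  for (std::size_t i = 0; i < outputs.size(); ++i) {
    const std::size_t pred = ArgMax(outputs[i]);
    if (outputs[i][pred] < threshold) continue;  // not a "positive"
    if (targets[i][pred] == 1.0) ++tp; else ++fp;
  }
  return (tp + fp == 0) ? 0.0 : static_cast<double>(tp) / (tp + fp);
}

// Four examples, three classes -- matches the walkthrough above.
const std::vector<Row> l_outputs = {
    {0.75, 0.15, 0.10},  // correct, above 0.7 -> true positive
    {0.60, 0.30, 0.10},  // correct, below 0.7 -> false negative
    {0.20, 0.50, 0.30},  // wrong, below 0.7   -> true negative
    {0.30, 0.40, 0.30},  // wrong, below 0.7   -> true negative
};
const std::vector<Row> l_targets = {
    {1, 0, 0},
    {1, 0, 0},
    {0, 0, 1},
    {1, 0, 0},
};
```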
Let's hop back into the implementation to make this test pass. Based on our definitions of precision and recall from last time, it is clear that all of our metrics can be calculated from the four values TP, FP, TN, and FN. It makes sense to put these calculations in the base class so that all metrics can use them.
Starting with an easy one, let's calculate the true positives in the dataset.
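A first, direct pass at counting true positives, roughly what p_CalcNumTruePositives() might look like before any refactoring. The surrounding class is compressed to the minimum needed to compile, and the member names are my guesses at the repo's conventions.

```cpp
#include <cstddef>
#include <deque>
#include <vector>

using Row = std::vector<double>;

// Minimal scaffolding so the helper compiles on its own.
class Metric {
 public:
  void AddResults(const std::vector<Row>& outputs,
                  const std::vector<Row>& targets) {
    for (const auto& r : outputs) m_outputs.push_back(r);
    for (const auto& r : targets) m_targets.push_back(r);
  }

  // First pass: walk every stored example, find the argmax of the
  // output row, and count it as a true positive when it is both
  // confident enough and the correct class.
  std::size_t p_CalcNumTruePositives(double threshold) const {
    std::size_t count = 0;
    for (std::size_t i = 0; i < m_outputs.size(); ++i) {
      const Row& out = m_outputs[i];
      std::size_t pred = 0;
      for (std::size_t j = 1; j < out.size(); ++j)
        if (out[j] > out[pred]) pred = j;
      if (out[pred] >= threshold && m_targets[i][pred] == 1.0) ++count;
    }
    return count;
  }

 protected:
  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};
```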
This will work, but if you start writing the functions to calculate the other values, you will see that they all follow a similar pattern, just with a different comparator function deciding whether an example is a true positive, false positive, and so on.
Luckily, C++ lets you pass a function as a parameter, so let's abstract this logic into a more generic function called p_IterAndCountWithFn() and pass it a function called p_ExampleIsTruePositive() to determine whether an example is a true positive.
Now we can move the bulk of the work into p_IterAndCountWithFn() and shorten p_CalcNumTruePositives() to a couple of lines.
The syntax looks a little wonky when passing a function as a parameter, especially when it belongs to a class. Just remember: all we are really doing here is passing a comparator function with four parameters to p_IterAndCountWithFn() and letting it do all the heavy lifting.
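If member-function-pointer syntax is new to you, here is a minimal, self-contained illustration of the pattern with nothing MNIST-specific in it (the class and method names here are made up for the demo):

```cpp
#include <cstddef>
#include <vector>

class Counter {
 public:
  // Two "comparator" member functions with the same signature.
  bool IsPositive(int v) const { return v > 0; }
  bool IsEven(int v) const { return v % 2 == 0; }

  // Takes a pointer to any bool(Counter::*)(int) const member function
  // and applies it to every element. The (this->*fn)(v) call is the
  // wonky-looking part.
  std::size_t CountWithFn(const std::vector<int>& vals,
                          bool (Counter::*fn)(int) const) const {
    std::size_t n = 0;
    for (int v : vals)
      if ((this->*fn)(v)) ++n;
    return n;
  }
};
```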
This makes the rest of our calculations easy peasy; we just need to declare the new comparators and implement them.
**Warning: the implementations are short but repetitive. If you understand the first example, feel free to just copy and paste.**
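Here is one way the full set of helpers might look. I am guessing that the four comparator parameters are (max confidence, predicted class, target class, threshold), and the true-negative/false-negative split follows the convention from the walkthrough above: a correct-but-unconfident answer is a false negative, a wrong-and-unconfident answer is a true negative.

```cpp
#include <cstddef>
#include <deque>
#include <vector>

using Row = std::vector<double>;

class Metric {
 public:
  void AddResults(const std::vector<Row>& outputs,
                  const std::vector<Row>& targets) {
    for (const auto& r : outputs) m_outputs.push_back(r);
    for (const auto& r : targets) m_targets.push_back(r);
  }

  // Each p_CalcNum* is now just the generic driver plus a comparator.
  std::size_t p_CalcNumTruePositives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsTruePositive, t);
  }
  std::size_t p_CalcNumFalsePositives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsFalsePositive, t);
  }
  std::size_t p_CalcNumTrueNegatives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsTrueNegative, t);
  }
  std::size_t p_CalcNumFalseNegatives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsFalseNegative, t);
  }

 protected:
  // Comparators: (confidence, predicted class, target class, threshold).
  bool p_ExampleIsTruePositive(double conf, std::size_t pred,
                               std::size_t target, double t) const {
    return conf >= t && pred == target;
  }
  bool p_ExampleIsFalsePositive(double conf, std::size_t pred,
                                std::size_t target, double t) const {
    return conf >= t && pred != target;
  }
  bool p_ExampleIsTrueNegative(double conf, std::size_t pred,
                               std::size_t target, double t) const {
    return conf < t && pred != target;
  }
  bool p_ExampleIsFalseNegative(double conf, std::size_t pred,
                                std::size_t target, double t) const {
    return conf < t && pred == target;
  }

  // Generic driver: finds the argmax and target class for each example
  // and defers the actual decision to the comparator passed in.
  std::size_t p_IterAndCountWithFn(
      bool (Metric::*fn)(double, std::size_t, std::size_t, double) const,
      double threshold) const {
    std::size_t count = 0;
    for (std::size_t i = 0; i < m_outputs.size(); ++i) {
      const Row& out = m_outputs[i];
      std::size_t pred = 0, target = 0;
      for (std::size_t j = 1; j < out.size(); ++j)
        if (out[j] > out[pred]) pred = j;
      for (std::size_t j = 0; j < m_targets[i].size(); ++j)
        if (m_targets[i][j] == 1.0) target = j;
      if ((this->*fn)(out[pred], pred, target, threshold)) ++count;
    }
    return count;
  }

  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};
```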
Whew, that was a mouthful, but at the end of the day it makes the subclasses short and sweet. Let's look at how straightforward Calculate() on the Precision class becomes.
This is very readable for some one who simply wants to see the formula for precision, and can look at the base class for the nitty gritty details.
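Assembled end to end, it might look like the sketch below. The base-class plumbing is compressed (only the two helpers precision needs) so the snippet runs standalone; the point is the last few lines, where Calculate() is nothing but the precision formula.

```cpp
#include <cstddef>
#include <deque>
#include <string>
#include <vector>

using Row = std::vector<double>;

// Compressed base class: just the machinery Precision needs.
class Metric {
 public:
  virtual ~Metric() = default;
  virtual std::string GetName() const = 0;
  virtual double Calculate(double threshold = 0.5) const = 0;
  virtual void AddResults(const std::vector<Row>& outputs,
                          const std::vector<Row>& targets) {
    for (const auto& r : outputs) m_outputs.push_back(r);
    for (const auto& r : targets) m_targets.push_back(r);
  }

 protected:
  bool p_ExampleIsTruePositive(double c, std::size_t p, std::size_t t,
                               double th) const { return c >= th && p == t; }
  bool p_ExampleIsFalsePositive(double c, std::size_t p, std::size_t t,
                                double th) const { return c >= th && p != t; }

  std::size_t p_IterAndCountWithFn(
      bool (Metric::*fn)(double, std::size_t, std::size_t, double) const,
      double threshold) const {
    std::size_t count = 0;
    for (std::size_t i = 0; i < m_outputs.size(); ++i) {
      std::size_t pred = 0, target = 0;
      for (std::size_t j = 1; j < m_outputs[i].size(); ++j)
        if (m_outputs[i][j] > m_outputs[i][pred]) pred = j;
      for (std::size_t j = 0; j < m_targets[i].size(); ++j)
        if (m_targets[i][j] == 1.0) target = j;
      if ((this->*fn)(m_outputs[i][pred], pred, target, threshold)) ++count;
    }
    return count;
  }

  std::size_t p_CalcNumTruePositives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsTruePositive, t);
  }
  std::size_t p_CalcNumFalsePositives(double t) const {
    return p_IterAndCountWithFn(&Metric::p_ExampleIsFalsePositive, t);
  }

  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};

// The payoff: Calculate() is now just the precision formula.
class Precision : public Metric {
 public:
  std::string GetName() const override { return "Precision"; }
  double Calculate(double threshold = 0.5) const override {
    const double tp = p_CalcNumTruePositives(threshold);
    const double fp = p_CalcNumFalsePositives(threshold);
    return (tp + fp == 0.0) ? 0.0 : tp / (tp + fp);
  }
};
```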
Now implementing recall will be quite easy; we just need to remember the formula with respect to true positives and false negatives. The header for recall.h is very similar to the header for precision.h.
Then the Calculate() function in the implementation will simply use the formula
Recall = TP / (TP + FN)
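A compact sketch of what that might look like, with the shared base-class machinery inlined here so the snippet runs on its own (in the repo, Recall would inherit from the Metric base class just like Precision does):

```cpp
#include <cstddef>
#include <deque>
#include <string>
#include <vector>

using Row = std::vector<double>;

// recall.h sketch, base-class machinery inlined for brevity.
class Recall {
 public:
  std::string GetName() const { return "Recall"; }
  void AddResults(const std::vector<Row>& outputs,
                  const std::vector<Row>& targets) {
    for (const auto& r : outputs) m_outputs.push_back(r);
    for (const auto& r : targets) m_targets.push_back(r);
  }

  // Recall = TP / (TP + FN): of the answers we had right at argmax,
  // how many did we actually keep by clearing the threshold?
  double Calculate(double threshold = 0.5) const {
    std::size_t tp = 0, fn = 0;
    for (std::size_t i = 0; i < m_outputs.size(); ++i) {
      std::size_t pred = 0, target = 0;
      for (std::size_t j = 1; j < m_outputs[i].size(); ++j)
        if (m_outputs[i][j] > m_outputs[i][pred]) pred = j;
      for (std::size_t j = 0; j < m_targets[i].size(); ++j)
        if (m_targets[i][j] == 1.0) target = j;
      if (pred != target) continue;  // only argmax-correct examples matter
      if (m_outputs[i][pred] >= threshold) ++tp;  // confident and correct
      else ++fn;                                  // correct but rejected
    }
    return (tp + fn == 0) ? 0.0 : static_cast<double>(tp) / (tp + fn);
  }

 private:
  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};
```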
Go ahead and write some similar unit tests for our Recall class to verify it all works as expected. Hop back over to the definitions from the last post if you need inspiration for tests.
Finally, we will want an implementation for accuracy that also uses our definitions of TP, FP, TN, and FN. I'll spare you the repetitive code (I think you can figure it out from the interfaces above) and just give the implementation for Calculate().
You'll notice here that accuracy is not concerned with the confidence threshold at all; it just takes the max value of each prediction as the answer, like we did when we first trained the net.
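A sketch of that Calculate(), again with the base-class machinery inlined so it runs standalone. Note the threshold parameter is accepted (to match the Metric interface) but deliberately ignored:

```cpp
#include <cstddef>
#include <deque>
#include <string>
#include <vector>

using Row = std::vector<double>;

// accuracy.h sketch, base-class machinery inlined for brevity.
class Accuracy {
 public:
  std::string GetName() const { return "Accuracy"; }
  void AddResults(const std::vector<Row>& outputs,
                  const std::vector<Row>& targets) {
    for (const auto& r : outputs) m_outputs.push_back(r);
    for (const auto& r : targets) m_targets.push_back(r);
  }

  // Accuracy = correct / total, using the argmax only -- the
  // confidence threshold plays no part here.
  double Calculate(double /*threshold*/ = 0.5) const {
    if (m_outputs.empty()) return 0.0;
    std::size_t correct = 0;
    for (std::size_t i = 0; i < m_outputs.size(); ++i) {
      std::size_t pred = 0, target = 0;
      for (std::size_t j = 1; j < m_outputs[i].size(); ++j)
        if (m_outputs[i][j] > m_outputs[i][pred]) pred = j;
      for (std::size_t j = 0; j < m_targets[i].size(); ++j)
        if (m_targets[i][j] == 1.0) target = j;
      if (pred == target) ++correct;
    }
    return static_cast<double>(correct) / m_outputs.size();
  }

 private:
  std::deque<Row> m_outputs;
  std::deque<Row> m_targets;
};
```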
Great, hopefully you have tested each one of these implementations and everything is making sense. Now we can use them in our main training loop to evaluate Precision, Recall, and Accuracy on the test set after each epoch.
If you take a look at tools/feedforward_net/main.cpp, you'll see we were simply counting correct versus incorrect predictions over our test set to calculate accuracy. Let's change this function to use our new classes.
Here we are passing in each layer of our model, as well as a set of confidence thresholds we want to compute precision and recall for. We instantiate our Precision and Recall classes (we could also add Accuracy if we wanted), then run forward passes over all the examples in the test set to accumulate results. Once all the outputs have been calculated and accumulated, we can test the precision and recall at each confidence level.
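The overall shape of RunOnTestSet might look something like the sketch below. The real function takes the network's layers; here the forward pass is abstracted behind a hypothetical `ForwardFn` callable, and the metric math is inlined rather than going through the Metric classes, purely so the snippet is self-contained and runnable.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

using Row = std::vector<double>;
// Hypothetical stand-in for running one forward pass through the net.
using ForwardFn = std::function<Row(const Row&)>;

struct PrAtThreshold { double threshold, precision, recall; };

// Sketch of RunOnTestSet: run every test example through the model
// once, then evaluate precision/recall at each requested threshold.
std::vector<PrAtThreshold> RunOnTestSet(const ForwardFn& forward,
                                        const std::vector<Row>& inputs,
                                        const std::vector<Row>& targets,
                                        const std::vector<double>& thresholds) {
  std::vector<Row> outputs;
  for (const auto& in : inputs) outputs.push_back(forward(in));

  std::vector<PrAtThreshold> results;
  for (double t : thresholds) {
    std::size_t tp = 0, fp = 0, fn = 0;
    for (std::size_t i = 0; i < outputs.size(); ++i) {
      std::size_t pred = 0, target = 0;
      for (std::size_t j = 1; j < outputs[i].size(); ++j)
        if (outputs[i][j] > outputs[i][pred]) pred = j;
      for (std::size_t j = 0; j < targets[i].size(); ++j)
        if (targets[i][j] == 1.0) target = j;
      const bool confident = outputs[i][pred] >= t;
      if (confident && pred == target) ++tp;
      else if (confident && pred != target) ++fp;
      else if (!confident && pred == target) ++fn;
    }
    const double p = (tp + fp == 0) ? 0.0 : double(tp) / (tp + fp);
    const double r = (tp + fn == 0) ? 0.0 : double(tp) / (tp + fn);
    results.push_back({t, p, r});
  }
  return results;
}
```

The forward passes happen once, up front, and only the cheap counting is repeated per threshold, which is why accumulating results with AddResults() before calculating is the right interface shape.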
Down in our training loop, we will want to use our new Accuracy class and our new RunOnTestSet() function to print out how well the net is doing.
I let my network train for about 500 epochs, and was able to get an accuracy of ~79%, with a precision and recall curve that looks like this:
Or if we graphed it out
It looks like the sweet spot for precision and recall is around a confidence threshold of 0.5, which gives us a recall of ~90% with a precision of ~88%. We could get up to 99% precision by raising the confidence threshold to 0.75, but that comes at the cost of recall dropping to 33%.
Adding a confidence threshold is the way to go in real-world applications, and it can provide a good downstream signal to other systems, depending on your application. In the next post, we will look at different activation functions that may help increase performance.