Evaluation Pt 2: Precision And Recall Curves
Gaining a better understanding of how well our network is performing
In our previous posts we have only used the "accuracy" metric to judge how well our network is doing. There are a few other useful metrics for understanding how well a classification system is performing. Imagine a case where you have a very imbalanced dataset.
If we put our decision boundary where the light blue line is and computed our accuracy, we would be doing reasonably well. There are 20 total data points and we have correctly classified 18 of them, giving us an accuracy of 90%! We got an A in the class, so we can go home, right? Well, not quite.
Look closer: sure, we are doing really well on the red data points, but we are doing terribly on the blue ones. Only one of the three blue data points is classified correctly. In fact, if we moved our decision boundary way off into "no man's land" and always guessed red, we would still get 85% correct (17 of 20).
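The arithmetic above can be checked with a few lines of Python. The labels are hypothetical, rebuilt only from the counts in the text (17 red points, 3 blue points, and a boundary that gets exactly one blue point right); the `accuracy` helper is illustrative, not part of our library.

```python
def accuracy(predictions, targets):
    """Fraction of predictions that match their targets."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)

# Hypothetical labels matching the counts in the text: 17 red, 3 blue.
targets = ["red"] * 17 + ["blue"] * 3

# The light blue decision boundary: every red point right, 1 of 3 blue right.
boundary_preds = ["red"] * 17 + ["blue", "red", "red"]

# A degenerate "classifier" that always answers red.
always_red = ["red"] * 20

print(accuracy(boundary_preds, targets))  # 0.9
print(accuracy(always_red, targets))      # 0.85
```

Both numbers look respectable, yet the always-red baseline never finds a single blue point.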
In machine learning we are often looking for a needle in a haystack: one blue data point in a sea of red ones. If we always guessed "red" in a sea of red data, sure, we might be 99% accurate, but that is not useful for the task at hand.
This is why understanding your dataset and your metrics is important. Simply shooting for 99% accuracy is rarely a good idea. There are a few more metrics that help with this, namely precision and recall.
Precision vs. Recall
Precision can be thought of as "given how many guesses you made, how many were correct?". Recall can be thought of as "given how many correct answers there were, how many did you find?".
Let's put this in terms of our MNIST problem. For each image in the MNIST database, our model produces a probability distribution over which digit it thinks the image contains. Logging a single example, we might see the outputs and the targets look something like the following:
This output says the net is 29.5% confident the image contains a 4, 27.1% confident it contains a 9, and has fairly low confidence in all other digits. It is easy to see why the net might be confused on this example, because 9s and 4s are quite similar in shape.
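Outputs like this are commonly produced by a softmax over the network's raw scores (logits). Here is a minimal NumPy sketch, assuming a 10-class output. The logits are made up rather than taken from our model, so the printed probabilities will not match the 29.5% / 27.1% figures exactly. The sketch also shows that identical logits give every digit the same 10% share.

```python
import numpy as np

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / exps.sum()

# Hypothetical logits for digits 0-9 (not from a real model): 4 and 9 score highest.
logits = np.array([0.1, -0.5, 0.2, 0.0, 1.6, -0.3, 0.3, 0.1, 0.2, 1.5])
probs = softmax(logits)

print(probs.round(3))
print(int(np.argmax(probs)))  # 4

# With identical logits, every digit gets the same probability of 0.1.
print(softmax(np.zeros(10)))
```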
When looking at softmax output distributions, it is a good idea to set a "confidence threshold" that decides whether we should guess at all. If we look at the guesses the network makes when it is first learning from the data, we would likely see a distribution like this:
All the guesses are around 10% confident, because the network has no idea what it is seeing yet: it assigns roughly equal probability to every class. Sure, we might get some answers correct by taking the max value, but should we really count an answer as correct if it is only 11% confident? Probably not.
A distribution where the highest value is much higher than the second highest is much more convincing evidence that the net isn't just guessing randomly. For example, this next output is 55% confident the answer is 4 and only 22% confident it is a 9.
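A confidence threshold can be sketched as a small helper that only returns a guess when the top probability reaches it. The helper `confident_guess` and its 0.5 default are hypothetical. The uniform distribution mimics an untrained network; the second one borrows only the 55% / 22% split from the text, and its other eight values are made up so the whole thing sums to 1. Whether a value exactly at the threshold counts is a design choice; this sketch uses `>=`.

```python
import numpy as np

def confident_guess(probs, threshold=0.5):
    """Return the top digit, or None if its probability is below the threshold."""
    top = int(np.argmax(probs))
    return top if probs[top] >= threshold else None

# Untrained-looking output: every digit at 10%.
uniform = np.full(10, 0.1)

# Hypothetical confident output: 55% on a 4, 22% on a 9, the rest spread thinly.
confident = np.array([0.03, 0.02, 0.03, 0.03, 0.55, 0.03, 0.03, 0.03, 0.03, 0.22])

print(confident_guess(uniform))    # None
print(confident_guess(confident))  # 4
```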
Taking this intuition into account, let's go back to our definitions of precision and recall.
Let's say we define a confidence threshold of 0.5, meaning the network has to be at least 50% confident in its answer for us to even consider it. This means we will not guess on images we are unsure about, but when we do guess, we are much more likely to be correct, so our precision will be higher. Remember that we defined precision as "given how many guesses you made, how many were correct?".
What about recall? We defined recall as "given how many correct answers there were, how many did you find?". Since we no longer guess on images where we are not confident, we will miss some correct answers, so our recall will be lower. A plot of precision vs. recall at different confidence thresholds may look like the following:
For example, at one threshold you may have a precision of 85% but a recall of only 30%. This means that when you do guess, you are right 85% of the time, but you found only 30% of the correct answers.
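To see the trade-off numerically, here is a sketch that sweeps a few thresholds over a handful of made-up predictions. Each hypothetical `(top_confidence, top_answer_correct)` pair stands in for one image, and "guessing" means the top confidence reaches the threshold. The helper name `precision_recall_at` and the data are illustrative assumptions. When nothing reaches the threshold, precision is undefined; the sketch reports 0.0 as a convention.

```python
def precision_recall_at(results, threshold):
    """results: list of (top_confidence, top_answer_correct) pairs, one per image."""
    # Correctness of every prediction we actually guessed on.
    guessed = [correct for conf, correct in results if conf >= threshold]
    found = sum(guessed)                                    # correct guesses
    total_correct = sum(correct for _, correct in results)  # correct answers available
    precision = found / len(guessed) if guessed else 0.0
    recall = found / total_correct if total_correct else 0.0
    return precision, recall

# Hypothetical predictions, sorted by confidence for readability.
results = [
    (0.95, True), (0.90, True), (0.80, True), (0.70, False), (0.65, True),
    (0.55, False), (0.45, True), (0.40, False), (0.30, False), (0.20, True),
]

for threshold in (0.8, 0.6, 0.4, 0.1):
    p, r = precision_recall_at(results, threshold)
    print(f"threshold={threshold:.1f}  precision={p:.2f}  recall={r:.2f}")
```

Lowering the threshold from 0.8 to 0.1 raises recall from 0.50 to 1.00 while precision falls from 1.00 to 0.60; plotting the pairs gives a precision-recall curve.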
Finding the right balance between precision and recall usually depends on the problem you are trying to solve. For example, imagine you are making decisions in a military setting. If you are trying to decide whether or not to take a shot at an enemy, you want to make darn sure it is the enemy. In this case you want a high-precision decision rather than high recall: when you take a shot, it is important that you are right and that you are not targeting a civilian. It would be bad if the system shot at everyone just to improve the probability of hitting an enemy.
On the other hand, consider a search engine like Google, where higher recall is probably more important. It is fine if the top web page is not exactly what you are looking for; what matters more is that the answer is somewhere in the top results. Search engines generally favor higher recall, with less focus on precision.
With all this in mind, let's dive into how we calculate precision and recall. In a classification task, a prediction has four possible outcomes: a true positive, a false positive, a true negative, or a false negative. We can use these four outcomes to calculate precision and recall. Don't worry if these terms sound confusing at first; we will define them in terms of our MNIST task to make them clear.
MNIST as a Game Show
We can think of classification tasks as buzzing in during a game show like Jeopardy! In the MNIST task the game would be "who can identify the image correctly, in the shortest amount of time".
During inference, our neural network outputs a probability distribution that represents how likely the image is to be each digit. We can look at this distribution and define a threshold at which we are "confident enough" to buzz in. If we are confident enough, say over 50%, we buzz in. If we are not over 50% certain, we keep our thoughts to ourselves.
True Positives (TP)
The first possible outcome is called a true positive. This is when your confidence is above the threshold and your prediction is correct. For example, if the network was shown an image of a 9 and predicted a 9 with a confidence of 0.95, this would be a true positive. Keep the confidence threshold in mind: if your correct answer only had a confidence of 0.45 and the threshold was 0.5, you would not buzz in, and this would be what is called a false negative (which we will cover below).
False Positives (FP)
A false positive is when we are very confident in a result, but it is in fact incorrect. If for example the network had a confidence of 0.95 that the image was a 3, but in fact it was an 8 (a very common mistake), this would constitute a false positive. The network guessed, and it guessed wrong.
True Negatives (TN)
If your highest confidence answer is incorrect, but your confidence level is lower than the threshold, this is a true negative. A true negative can be thought of as "when you are correct for holding your tongue, because you know you are wrong". Imagine you are playing Jeopardy! and you literally have no idea what the question is asking about. It is probably better to not buzz in rather than guess a random city in Japan.
False Negatives (FN)
A false negative is when none of your answers exceed the confidence threshold, but the highest confidence value is still in fact correct. Think of this as when you have a gut feeling about something, but really aren't sure if it is correct, so you don't buzz in. Then you later find out it was the right answer. This may cause you to want to lower your confidence threshold to increase your recall, but remember that would most likely cause precision to go down as well.
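The four outcomes boil down to two yes/no questions: did we buzz in (top confidence at or above the threshold), and was the top answer right? Here is a minimal sketch with a hypothetical `outcome` helper and the 0.5 threshold used throughout this post. The true positive, false positive, and false negative calls reuse the examples from the definitions; the true negative's confidence is made up.

```python
def outcome(top_confidence, top_is_correct, threshold=0.5):
    """Label one prediction TP, FP, TN or FN using the buzz-in rules."""
    buzzed = top_confidence >= threshold
    if buzzed:
        return "TP" if top_is_correct else "FP"
    return "FN" if top_is_correct else "TN"

print(outcome(0.95, True))   # TP: shown a 9, predicted 9 at 0.95
print(outcome(0.95, False))  # FP: predicted 3 at 0.95, but it was an 8
print(outcome(0.12, False))  # TN: wrong and unsure, so we stayed quiet
print(outcome(0.45, True))   # FN: right answer, but 0.45 is below the threshold
```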
Using TP, FP, TN, FN to Calculate Precision and Recall
Now that we understand the four types of outcomes with a classification task, we can simply count up how many times each one of them happens, and come up with our precision and recall scores.
Precision is defined as TP / (TP + FP). In other words, it is the number of true positives divided by the total number of true and false positives. Looking back at our definitions, we can see that the word "positive" indicates that we made a guess, and "true" or "false" simply means whether we were correct or incorrect. This matches our initial definition of precision as "given how many guesses you made, how many were correct?".
Recall, on the other hand, is defined as TP / (TP + FN). Here we take the number of true positives and divide by the sum of the true positives and false negatives. The numerator is again all of our correct guesses, counting only the times we did guess. The denominator this time is the total number of answers you could, in theory, have gotten correct. Remember that a false negative is when your top answer was right, but you were not confident enough to buzz in.
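Counting up the outcomes and applying the two formulas can be sketched like this. The per-image labels are hypothetical, chosen so the result matches the 85% precision / 30% recall example from earlier. `precision_recall_from_counts` is an illustrative name, and returning 0.0 for an empty denominator is one common convention rather than the only option.

```python
from collections import Counter

def precision_recall_from_counts(counts):
    """Precision = TP / (TP + FP); recall = TP / (TP + FN). TN is not used."""
    tp, fp, fn = counts["TP"], counts["FP"], counts["FN"]
    # Report 0.0 instead of dividing by zero when a denominator is empty.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical per-image outcome labels; Counter tallies each kind.
labels = ["TP"] * 51 + ["FP"] * 9 + ["FN"] * 119 + ["TN"] * 821
counts = Counter(labels)

precision, recall = precision_recall_from_counts(counts)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.85 recall=0.30
```

A `Counter` returns 0 for missing keys, so a run with no false positives (for example) still works without special cases.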
Notice that neither precision nor recall takes true negatives into account: the cases where you did not guess and would have been wrong anyway. In a giant dataset there can be a lot of true negatives, and they often do not give much signal about how well you are classifying the data. Imagine searching for a needle in a haystack and reporting every time you "did not find a needle". Someone yelling "found hay! found hay! found hay! found hay!" would be quite annoying to work with.
There is another definition of accuracy that does take true negatives into account: (TP + TN) / (TP + TN + FP + FN). This can be useful for well-balanced datasets.
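Using hypothetical counts (51 TP, 9 FP, 119 FN), the sketch below adds more and more true negatives. Accuracy that credits true negatives climbs toward 100%, while precision and recall stay fixed, which is the "found hay!" problem in numbers. The function names are illustrative.

```python
def accuracy_from_counts(tp, fp, tn, fn):
    """Accuracy that credits true negatives: (TP + TN) / (TP + TN + FP + FN)."""
    return (tp + tn) / (tp + tn + fp + fn)

def precision(tp, fp):
    return tp / (tp + fp)

def recall(tp, fn):
    return tp / (tp + fn)

# Hypothetical counts; only the number of true negatives changes.
tp, fp, fn = 51, 9, 119
for tn in (0, 821, 100_000):
    print(f"TN={tn:>6}  accuracy={accuracy_from_counts(tp, fp, tn, fn):.3f}  "
          f"precision={precision(tp, fp):.2f}  recall={recall(tp, fn):.2f}")
```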
Now that we understand what precision and recall are, and how they give us a better picture of how well our classifier is doing, in the next post we will write the code for them and integrate some "Metrics" classes into our library. Look forward to seeing you there!