Evaluation Pt 2: Precision And Recall Curves
Gaining a better understanding of how well our network is performing
In our previous posts we have only used the "accuracy" metric to judge how well our network is doing. There are a few other useful metrics that give a fuller picture of how a classification system is performing. Imagine a case where you have a very imbalanced dataset.
If we put our decision boundary where the light blue line is and computed our accuracy, we would seem to be doing reasonably well. There are 20 data points in total and we have correctly classified 18 of them, giving us an accuracy of 90%! We got an A in the class, so we can go home, right? Well, not quite.
Look a little closer, though. We are doing really well on the red data points, but pretty terribly on the blue ones: we only classified one of the three blue data points correctly. In fact, if we moved our boundary way off into "no man's land" and simply guessed "red" for every point, we would still get 17 of the 20 points, or 85%, correct.
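To make the arithmetic concrete, here is a minimal sketch that rebuilds the 20-point example as two Python lists. The counts (17 red, 3 blue, 18 correct) come from the text; the order of the points and the list names are hypothetical. It compares the boundary classifier with a baseline that always answers "red", and also reports accuracy restricted to the blue points.

```python
# Hypothetical reconstruction of the 20-point example: 17 red points, 3 blue points.
targets = ["red"] * 17 + ["blue"] * 3

# Boundary classifier: all red points are correct, but only 1 of the 3 blue points is.
boundary_preds = ["red"] * 17 + ["blue", "red", "red"]

# Baseline that ignores the input and always answers "red".
always_red_preds = ["red"] * 20


def accuracy(preds, targets):
    """Fraction of predictions that match their target."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def per_class_accuracy(preds, targets, label):
    """Accuracy restricted to the points whose true label is `label`."""
    pairs = [(p, t) for p, t in zip(preds, targets) if t == label]
    return sum(p == t for p, t in pairs) / len(pairs)


print(accuracy(boundary_preds, targets))                    # 0.9
print(accuracy(always_red_preds, targets))                  # 0.85
print(per_class_accuracy(boundary_preds, targets, "blue"))  # 0.3333333333333333
```

The overall numbers look fine, and only the per-class view exposes that two thirds of the blue points are missed.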
There are often times in machine learning where we are looking for a needle in a haystack: one blue data point in a sea of red ones. If we always guessed "red" in a sea of red data, we might well be 99% accurate, but that is useless for the task at hand.
This is why understanding your dataset and your metrics is important. Simply chasing 99% accuracy, without looking at anything else, is never a good idea. There are a couple more metrics we can add to address this, namely precision and recall.
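Precision can be thought of as "given how many guesses you made, how many were correct?". Recall can be thought of as "given how many correct answers there were, how many did you guess on?". In terms of counts, precision = TP / (TP + FP) and recall = TP / (TP + FN), where true positives (TP) are correct guesses, false positives (FP) are wrong guesses, and false negatives (FN) are correct answers you did not find.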
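Applied to the red and blue example, treating "blue" as the class we are hunting for, the definitions can be sketched as below. The lists are the same hypothetical reconstruction as before (17 red, 3 blue, one blue found), and the function name `precision_recall` is made up for illustration. When a classifier never guesses the class, precision is 0/0 and undefined; this sketch simply reports 0.0 in that case.

```python
def precision_recall(preds, targets, positive):
    """Precision and recall for one class, treating `positive` as the class we are looking for."""
    pairs = list(zip(preds, targets))
    tp = sum(p == positive and t == positive for p, t in pairs)  # guessed it, and it was right
    fp = sum(p == positive and t != positive for p, t in pairs)  # guessed it, but it was wrong
    fn = sum(p != positive and t == positive for p, t in pairs)  # missed a real one
    # Precision is undefined (0/0) if we never guessed the class; report 0.0 in that case.
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall


# Hypothetical reconstruction of the 20-point example.
targets = ["red"] * 17 + ["blue"] * 3
boundary_preds = ["red"] * 17 + ["blue", "red", "red"]
always_red_preds = ["red"] * 20

print(precision_recall(boundary_preds, targets, "blue"))    # (1.0, 0.3333333333333333)
print(precision_recall(always_red_preds, targets, "blue"))  # (0.0, 0.0)
```

The boundary classifier is never wrong when it says "blue" (precision 1.0), yet it only finds a third of the blue points (recall about 0.33), which is exactly the weakness that 90% accuracy hid.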
Let's put this in terms of our MNIST problem. For each image in the MNIST dataset, our network produces a probability distribution over which digit it thinks the image contains. Logging a single example, we might see the outputs and the targets looking something like the following:
This output says the net is 29.5% confident the image contains a 4, 27.1% confident it contains a 9, and has pretty low confidence in all other values. It is easy to see why the net may be confused on this example, because 9s and 4s are pretty similar in shape.
When looking at softmax output distributions, it is often a good idea to set a "confidence threshold" that decides whether we should guess a value at all. Think about when the net is first learning from the data. It produces outputs that look like this:
Pretty much every guess is around 10% confident, because the network has no idea what it is seeing yet. Sure, we might get some answers correct if we take the max value, but should we really count an answer as correct if the network is only 11% confident? Probably not.
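To see why an untrained network lands near 10% for every digit, here is a minimal softmax sketch in plain Python. The logits are hypothetical values, all close to zero, standing in for a network that has not learned anything yet. With ten nearly equal scores, softmax spreads the probability almost evenly, so even the top guess is barely above 10%.

```python
import math


def softmax(logits):
    """Turn raw scores into a probability distribution that sums to 1."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits from an untrained network: all ten digits score about the same.
untrained_logits = [0.02, -0.01, 0.00, 0.03, 0.05, -0.02, 0.01, 0.00, -0.03, 0.04]
probs = softmax(untrained_logits)

# The argmax still picks a digit, but with very little confidence behind it.
best = max(range(len(probs)), key=lambda i: probs[i])
print(best, round(probs[best], 3))
```

Taking the argmax here still "answers" with a digit, which is why the raw max value alone is not a trustworthy signal early in training.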
A distribution where the highest value is much higher than the second highest is much more convincing evidence that the net isn't just randomly guessing. For example, this next output is 55% confident the answer is a 4 and only 22% confident it is a 9.
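One way to put a number on "the highest value is much higher than the second highest" is the gap between the top two probabilities. In the sketch below, only the 4 and 9 values come from the text; `make_dist` is a hypothetical helper that spreads the remaining probability evenly over the other eight digits so each distribution sums to 1.

```python
def top_two_margin(probs):
    """Gap between the highest and second-highest probability in a distribution."""
    first, second = sorted(probs, reverse=True)[:2]
    return first - second


def make_dist(peaks, n_classes=10):
    """Hypothetical helper: place the given probabilities, spread the rest evenly."""
    rest = (1.0 - sum(peaks.values())) / (n_classes - len(peaks))
    return [peaks.get(i, rest) for i in range(n_classes)]


confused = make_dist({4: 0.295, 9: 0.271})   # the 4-vs-9 example
confident = make_dist({4: 0.55, 9: 0.22})    # the more convincing example

print(round(top_two_margin(confused), 3))   # 0.024
print(round(top_two_margin(confident), 3))  # 0.33
```

A margin of a couple of percentage points says the net is torn between two digits, while a margin of 33 points says one answer clearly dominates.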
Taking this intuition about our outputs and targets into account, let's go back to our definitions of precision and recall.
Let's say we define a confidence threshold of 0.5, meaning the network has to be at least 50% confident in an answer for us to even consider it. We will no longer guess on images we are unsure about, but when we do guess, we are more likely to be correct. This means our precision will be higher, given our definition of "out of the guesses you made, how many were correct?".
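Here is a minimal sketch of the "don't guess unless confident" rule, assuming the network's output is a plain list of ten softmax probabilities. The function name `guess` and both example distributions are hypothetical; `None` stands for "no guess".

```python
def guess(probs, threshold=0.5):
    """Return the predicted digit, or None if the top probability is below `threshold`."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return best if probs[best] >= threshold else None


# Hypothetical outputs, made up for illustration.
unsure = [0.09, 0.10, 0.10, 0.09, 0.11, 0.10, 0.10, 0.10, 0.10, 0.11]
sure = [0.01, 0.01, 0.02, 0.03, 0.80, 0.02, 0.03, 0.02, 0.01, 0.05]

print(guess(unsure))                 # None: the top probability is only 0.11
print(guess(sure))                   # 4
print(guess(unsure, threshold=0.0))  # 4: with no threshold we always take the argmax
```

Setting the threshold to 0.0 recovers the plain argmax behavior, so the threshold is a single knob between "always guess" and "only guess when sure".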
What about recall? Recall was defined as "given how many correct answers there were, how many did you guess on?". Since we now skip images where we are not confident, we will miss some answers we would otherwise have gotten right, so our recall will be lower. A plot of precision vs. recall at different confidence thresholds may look like the following:
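The curve comes from sweeping the threshold and recording one (recall, precision) point per value. The sketch below follows the definitions used here: precision is correct guesses over guesses made, and recall is correct guesses over all images, since every MNIST image has exactly one correct answer. The ten `records` are hypothetical (top confidence, was the argmax correct) pairs, not real network output.

```python
def precision_recall_at(records, threshold):
    """Precision and recall when we only guess on images whose top confidence >= threshold.

    `records` holds one (confidence, is_correct) pair per image:
      precision = correct guesses / guesses made
      recall    = correct guesses / total images
    """
    guessed = [is_correct for confidence, is_correct in records if confidence >= threshold]
    correct = sum(guessed)
    # With no guesses at all, precision is undefined; this sketch reports 0.0.
    precision = correct / len(guessed) if guessed else 0.0
    recall = correct / len(records)
    return precision, recall


# Hypothetical results for ten images, made up for illustration.
records = [
    (0.95, True), (0.90, True), (0.85, True), (0.70, True), (0.65, False),
    (0.55, True), (0.45, False), (0.40, True), (0.30, False), (0.15, False),
]

for threshold in [0.0, 0.5, 0.8]:
    p, r = precision_recall_at(records, threshold)
    print(f"threshold={threshold:.1f}  precision={p:.2f}  recall={r:.2f}")
```

On this toy data, raising the threshold from 0.0 to 0.8 moves precision from 0.60 up to 1.00 while recall falls from 0.60 to 0.30. At a threshold of 0.0 every image gets a guess, so precision and recall both reduce to plain accuracy.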
You may have a precision of 85%, but this may mean your recall is only 30%. That is, when you do guess, you are right 85% of the time, but you only found 30% of the correct answers.
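To see what those two numbers imply together, here is the arithmetic for a hypothetical test set of 1,000 images, using the definition where recall is measured against every image.

```python
# Hypothetical numbers: 1,000 test images, precision 0.85, recall 0.30.
n_images = 1000
precision, recall = 0.85, 0.30

correct_guesses = recall * n_images          # 300 images guessed correctly
guesses_made = correct_guesses / precision   # about 353 guesses in total
skipped = n_images - guesses_made            # about 647 images never guessed on

print(round(correct_guesses), round(guesses_made), round(skipped))  # 300 353 647
```

So the network only commits to an answer on roughly a third of the images, and is wrong on about 53 of those.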
Finding the right balance between precision and recall is always a trade-off, and the best balance usually depends on the problem you are trying to solve. For example, imagine you are making decisions in a hospital that affect people's lives. You will probably lean towards high-precision diagnoses rather than high recall. When you tell someone they have cancer, it is important that you are right. It would be really bad to tell someone they had cancer when they didn't, just because you wanted to improve your chances of finding everyone who does.
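One simple way to act on such a preference is to fix a minimum acceptable precision and then pick the threshold that keeps as much recall as possible. The sketch below assumes the same hypothetical `records` and precision/recall definitions as the sweep over thresholds; `pick_threshold` is a made-up name, and ties on recall are broken by the higher precision.

```python
def precision_recall_at(records, threshold):
    """Precision = correct guesses / guesses made; recall = correct guesses / all images."""
    guessed = [ok for confidence, ok in records if confidence >= threshold]
    precision = sum(guessed) / len(guessed) if guessed else 0.0
    return precision, sum(guessed) / len(records)


def pick_threshold(records, thresholds, min_precision):
    """Among thresholds meeting `min_precision`, return (threshold, precision, recall) with the best recall.

    Returns None if no threshold reaches the precision target.
    """
    candidates = []
    for t in thresholds:
        p, r = precision_recall_at(records, t)
        if p >= min_precision:
            candidates.append((r, p, t))  # compare by recall first, then precision
    if not candidates:
        return None
    r, p, t = max(candidates)
    return t, p, r


# Hypothetical (top confidence, is_correct) pairs for ten images.
records = [
    (0.95, True), (0.90, True), (0.85, True), (0.70, True), (0.65, False),
    (0.55, True), (0.45, False), (0.40, True), (0.30, False), (0.15, False),
]
thresholds = [i / 10 for i in range(10)]  # 0.0, 0.1, ..., 0.9

print(pick_threshold(records, thresholds, 0.95))  # (0.7, 1.0, 0.4)
print(pick_threshold(records, thresholds, 0.6))   # (0.4, 0.75, 0.6)
```

A strict precision floor pushes the threshold up to 0.7 and leaves recall at 0.4, while a looser floor allows a lower threshold and more recall. Note that precision is not guaranteed to rise monotonically with the threshold (here it dips from 0.83 at 0.5 to 0.80 at 0.6), which is why the sketch checks every candidate instead of stopping at the first one.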
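On the other hand, consider a search engine like Google. It is probably more important for them to have high recall than high precision: a user can skip past a few irrelevant results, but can never find a relevant page that was not returned at all.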