Linear Regression Using A Single Neuron Continued, The Learning Process
Using Gradient Descent To Minimize Error Function
In the last post we went over linear regression with one variable, and how you can use it to make predictions. We also looked at how you can calculate error to tell how well your linear model is following the data. Now we need to figure out how to get the best linear model, i.e. draw the best line.
This post will be a little more math-y than the past few, but if you can make it through, you should have no problem dealing with the rest of the math of neural networks. There is a bit of calculus to fully understand how the learning process works, but I will do my best to give some intuition behind these equations that make most people's eyes glaze over, including my own.
Remember Lisa? She is a dreamer. She dreams of drawing the perfect line through a dataset. Shoot for the stars Lisa! We are going to show how you can start with any random line, and without knowing exactly where you are going to end up, slowly work your way to the perfect line.
This seems like a trivial task. You may be thinking: "Hey I can whip out my ruler and draw that line no problem, what's the big whoop?". Well you are right, but I am giving the most simple example possible, so that we can work out the solution by hand. Imagine a larger dataset, with hundreds of points. Could you draw a perfect line through that?
Sure you could get pretty close, and that's all well and good, but what if you had more than 1 dimension for the input? What if you had 100 different variables that gave you 1 output? We can't even visualize what this would look like in our primitive 3-dimensional brains. Luckily for you, I am going to go through the math for a trivial model, and a small dataset. The same exact math, with the same exact steps will apply to very complicated datasets, and complicated models. At the end of this post you will be able to write a computer program that can do all the complicated math for you!
In fact, although this post deals with a single neuron, in another post we will chain these neurons together in a larger neural network, and use the same math to train it. The nice thing about how these neural networks are designed is that they are very modular. You can work out the math for one little module at a time, then stack them all together, and they will start influencing each other and learning together.
Back to where we left off in the last post. To better understand exactly what Edgar and his error function are doing, let’s draw a graph of it. If you recall, given all n data points in a dataset, with w_0 and w_1 being Lisa's weight and bias, the mean squared error function is:
MSE = (1/n) * sum over i of (y_i - (w_0*x_i + w_1))^2
In order to graph this function, first let’s rewrite it a little. Remember Edgar has two inputs, the output from "Lisa the Linear Function", and the target output value from the data.
We will call Lisa’s input x_0 and the target value y_0. The subscript 0 is just telling us that this is the first datapoint in our dataset. Since we are only using one datapoint we can ignore the summation, and put the mean squared error equation in terms of this first datapoint, where w_0 and w_1 are Lisa's weight and bias:
e = (y_0 - (w_0*x_0 + w_1))^2
We have named this function "e" for error function. Since technically the innermost parenthesis of the equation is the output of Lisa, let’s call this L(x_0) and factor this out as well:
e = (y_0 - L(x_0))^2
Below is a graph of Lisa's current line, and some of the variables marked for clarity.
You might start to recognize the error equation is starting to look similar to the equation for a parabola:
Let's plot this equation on Google to verify it is a parabola and see what it looks like. Notice the "y" value simply shifts the parabola along the x-axis.
This means every time Lisa passes her output to Edgar, and he runs her output through his function, we are going to get a value somewhere on a parabola.
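To make the parabola concrete, here is a minimal Python sketch. The function name `squared_error` and the target value of 3 are my own choices for illustration. It evaluates the error for a single datapoint over a handful of possible outputs from Lisa:

```python
def squared_error(output, target):
    """Squared difference between the target and Lisa's output for one datapoint."""
    return (target - output) ** 2


target = 3.0  # assumed target value, chosen only for illustration
for output in [1.0, 2.0, 3.0, 4.0, 5.0]:
    # The error is smallest when the output equals the target.
    print(output, squared_error(output, target))
```

The error is 0 when the output matches the target and grows by the same amount on either side of it, which is the parabola shape described above.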
With a general idea of what the mean squared error function looks like, let’s think back to the problem. We want to minimize this output, giving the lowest error possible. This means every time Lisa passes Edgar a value, we want it to be closer to the bottom of the parabola than last time. The problem is Edgar does not have access to change Lisa’s weights, and cannot change her output. We need a way for Edgar to communicate back to Lisa how to change her weights, only given what he is seeing. It is important to keep our system modular and separate in this way so it is easier to model in software later.
Edgar is already going to send Lisa her error given the datapoint, but we need to add more to this message. The next thing we are going to need to add is called a derivative. This is a mathematical construct that Edgar can calculate on his own, without knowing anything about Lisa, but still gives Lisa enough information to change her weights in a way that next time she makes a prediction, the error will be lower. In other words, it will contain a value that indicates to Lisa which way to move her output to get closer to the correct answer. This may seem like magic at first, but we will walk through in detail how this works.
If you have not taken calculus, the term derivative may be foreign, but it is actually a pretty simple concept. The derivative of a function is the rate of change of a function at a specific point. In order to find this rate of change, we can draw what is called a tangent line.
A tangent line is a line that "just touches" the curve at that point. The derivative at this point can be visualized as the slope of the tangent.
Another way to think of it, is if you had two points moving along the curve, approaching each other, you take the slope of the line they create when they are infinitely close to one another, and that is the derivative at that point.
Why is a derivative important? How does this get Lisa closer to the correct answer?
Well getting closer to the correct answer from Edgar's point of view means that the error is lower next time. If Edgar looks at the error every time he calculates it, he knows where he would prefer the next value to be, given the same input. He can quantify this as the derivative at this point, and relay this number back to Lisa.
If we take little steps in the downhill direction that the derivative points out each time, we will eventually reach the bottom of the curve, where error is approximately equal to 0.
Following the slope of a function down to a minimum value is called Gradient Descent. You can think of gradient descent as a hiker who wants to get to the bottom of a mountain but cannot see the trail. He scans the land around him, and takes one step in the direction that is steepest downhill. Every time Lisa adjusts her weights so that the error gets smaller, she is taking a step downhill, and each time she will get a more optimal line.
One important thing to note about gradient descent is that you should treat the derivative as the direction in which to move, and not the magnitude. In our case Edgar will tell Lisa whether to move her line up or down, but not how far to move it. If you take too big of steps in the direction, you might miss the minimum value of the function.
Not only might you miss the minimum value: with this quadratic function, if you take too big of steps, you might overshoot the minimum and oscillate off into infinity. This should make sense given our graph of Lisa and her target value: if you move the line up too far, the error is just going to be bigger next time. If you only move it up a little bit, it will get closer to the actual data point.
In general, if you take smaller steps, there is a better chance you will get to the minimum error, it just might take you a few more steps. The size of this step is called the learning rate, and we will look at it in more detail later.
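A rough sketch of this effect, assuming a one-variable parabola e(x) = (target - x)^2 with a target of 3 and its derivative -2 * (target - x) (derived later in this post). The function name `descend` and the two step sizes are hypothetical, picked only to show both behaviours:

```python
def descend(start, learning_rate, steps, target=3.0):
    """Walk down the parabola e(x) = (target - x)^2 starting from `start`."""
    x = start
    for _ in range(steps):
        slope = -2.0 * (target - x)    # derivative of the parabola at x
        x = x - learning_rate * slope  # step against the slope (downhill)
    return x


small_steps = descend(start=0.0, learning_rate=0.1, steps=100)
big_steps = descend(start=0.0, learning_rate=1.1, steps=100)
print(small_steps)  # settles very close to the target of 3
print(big_steps)    # swings back and forth, farther from 3 every step
```

With the small step size each step shrinks the distance to the minimum. With the large one each step jumps over the minimum and lands farther away than before.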
Calculating Derivatives
Now that we have a basic understanding of derivatives, you may be asking: how do I calculate them? If you have taken calculus, you may remember there are a few rules for derivatives that are useful to remember. If you haven’t taken calculus, don’t worry, we are going to go over the intuition for the few rules that we need for our problem.
The first rule is pretty simple. It is called the power rule. If you have an equation with x raised to an exponent, and you want to take the derivative with respect to x, you simply move the exponent in front of x, then subtract 1 from the exponent:
d/dx x^n = n * x^(n-1)
In our case, y = x^2, the exponent goes to 1, leaving us with 2x.
There are two ways we will write a derivative mathematically, the first way looks like a fraction:
dy/dx = 2x
The way you say this is "the derivative of y with respect to x equals 2x". It is always stated "the derivative of the numerator, with respect to the denominator". The other way you may see derivatives is with a tick mark. This is called prime notation.
y'(x) = 2x
Here we say y-prime of x is 2x. It means the same thing as "the derivative of y with respect to x equals 2x".
Don’t take my word for it, let’s see if this derivative makes sense given our graph of x^2 and what we know about derivatives. Say we have the point x=2, this would mean y=4.
Drawing the tangent and calculating its slope gives us:
The tangent line goes up 8 and over 2, giving us the slope of 4. This checks out with our equation for the derivative dy/dx = 2x seeing as x=2 and 2*2 = slope of 4.
What about for x = 1? This would give us y = 1 because 1^2 = 1.
The slope of the tangent here is up 2 and over 1 or just 2. This also checks out with our equation dy/dx = 2x. It is looking like the power rule checks out, but feel free to try more points on the graph.
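If you would rather not draw tangents by hand, a small Python sketch can try more points. It uses the "two points infinitely close together" idea from earlier: the slope of the line through two nearby points on y = x^2. The helper names `secant_slope` and `square`, and the gap of 1e-6, are assumptions made for this example:

```python
def secant_slope(f, x, h):
    """Slope of the line through (x, f(x)) and (x + h, f(x + h))."""
    return (f(x + h) - f(x)) / h


def square(x):
    return x ** 2


for x in [1.0, 2.0, 3.0]:
    approx = secant_slope(square, x, 1e-6)  # two points very close together
    # Print the point, the approximate slope, and the power-rule answer 2x.
    print(x, approx, 2 * x)
```

The approximate slopes land very close to 2x at each point, which agrees with the power rule.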
We are going to need more than the power rule to fully calculate the derivative of the mean squared error function. The next couple rules are pretty intuitive and you could have probably guessed them from the definition of a derivative.
The constant rule says that the derivative of a constant value is zero. This is pretty obvious seeing as the line y = c, for any constant c, is flat: it rises 0 over any run, so its slope is 0.
d/dx c = 0
Next, if you take the derivative of a linear function, the derivative is just the slope of the line, or the value that x is being multiplied by.
d/dx (m*x) = m
We also have addition and subtraction in the MSE equation, so let’s look at the rules for these operations.
First there is the "sum rule" for addition.
d/dx (f(x) + g(x)) = f'(x) + g'(x)
In the case of mean squared error, we aren’t adding functions, but we are adding constants. If you add a constant to a function, it is not going to change the slope. We already know that the derivative of a constant is 0, so this rule seems to check out.
The same is true for subtraction, this is called the "difference rule".
d/dx (f(x) - g(x)) = f'(x) - g'(x)
The same intuition should apply here.
The last rule we need has a little less intuition, but will become very important once we dive into deeper neural networks. This rule deals with when you have a composition of functions, or a function that takes another function as input, and is called the "chain rule".
y = h(g(x))
In this case the function "h" depends on the output of "g". You may recognize that this is exactly what the interaction between Lisa and Edgar looks like. Lisa passes the output of her function as the input to Edgar's function.
The derivative for such chains of functions can be computed by multiplying the derivative of the outer function by the derivative of the inner function.
dy/dx = h'(g(x)) * g'(x)
Maybe I will add some intuition for this later, but for now The Khan Academy has some good videos on it.
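In the meantime, one way to build some trust in the chain rule is to check it numerically. The sketch below uses two made-up functions, g(x) = 3x + 1 and h(u) = u^2 (these are not Lisa's or Edgar's functions, just an illustration), and compares the chain rule answer against a slope measured directly from the composed function:

```python
def g(x):
    return 3.0 * x + 1.0  # inner function (linear, slope 3)


def g_prime(x):
    return 3.0            # linear rule


def h(u):
    return u ** 2         # outer function


def h_prime(u):
    return 2.0 * u        # power rule


def chain_rule_derivative(x):
    """Derivative of h(g(x)): outer derivative times inner derivative."""
    return h_prime(g(x)) * g_prime(x)


def numeric_derivative(f, x, step=1e-5):
    """Slope between two points just either side of x."""
    return (f(x + step) - f(x - step)) / (2.0 * step)


x = 2.0
print(chain_rule_derivative(x))
print(numeric_derivative(lambda v: h(g(v)), x))
```

Both lines should print values that agree to several decimal places, which is what the chain rule promises.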
Derivatives in Practice
We should now have all the equations we need to calculate our derivatives. Hopefully this was a good refresher if you have taken calculus, and a good jumping off point if you have not. There will be a lot of math and equations in this section, none of which are particularly difficult, but I find it can still be hard to follow. The hardest thing for most people to follow is where all the variables came from, given all the subscripts and derivative notation. It might be good to open two windows side by side with this page loaded so that every time you see a variable in an equation, you can trace it back to where it came from and refresh your memory of why we are using it.
Have your two windows open? Let's see these derivatives in action! To make it concrete, we are going to work with an actual dataset. With Lisa and Edgar working together, they should be able to incrementally make progress on finding the correct line.
Let’s take the small dataset from before, with only two points [(2,3), (4,5)]
Lisa randomly chooses a starting weight of -0.5 and bias of 2.5 to construct her line. She hasn’t seen any data yet, so has not learnt what these values should be.
Initially the line is clearly below the two points, and the slope is in the wrong direction. Let’s see how we can use the error function and its derivative to adjust Lisa's weights so it is closer to these two points.
First consider the first data point (2,3). With Lisa’s current weights, we would pass the input of x = 2.0 through her linear function and get the output 1.5.
L(2.0) = -0.5 * 2.0 + 2.5 = 1.5
This means Lisa’s best guess is the point y=1.5 given x=2.
Edgar sees Lisa’s guess, but notices that the real output should be y=3 given x=2. He calculates that the MSE for this single point, or the difference between the truth and our prediction, squared, is 2.25.
e = (3.0 - 1.5)^2 = 2.25
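The same two steps can be sketched in a few lines of Python. The function names `predict` and `squared_error` are hypothetical stand-ins for Lisa and Edgar. The weight, bias, and data point come from the example above:

```python
def predict(x, weight, bias):
    """Lisa: a linear function of the input."""
    return weight * x + bias


def squared_error(output, target):
    """Edgar: squared difference between the target and Lisa's output."""
    return (target - output) ** 2


weight, bias = -0.5, 2.5  # Lisa's random starting values
x_0, y_0 = 2.0, 3.0       # first data point

guess = predict(x_0, weight, bias)
error = squared_error(guess, y_0)
print(guess, error)  # prints 1.5 2.25
```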
Now we know that calculating the mean squared error is only half of Edgar’s job. Now he has to do the second half of calculating his derivative. Remember, the derivative is a number that indicates the direction Lisa should move her line to get closer to the data.
We will use our rules from earlier, and Edgar's function, to figure out an equation for the derivative. If you remember, we put Edgar’s equation in terms of Lisa’s output L(x_0) and the correct output y_0 to see what the graph would look like.
e = (y_0 - L(x_0))^2
When you have an equation with multiple inputs (as we do here), this means you can calculate the derivative with respect to any of the inputs. This is called a partial derivative, and tells you the impact this variable has on the slope of the tangent. This is why we use the notation dy/dx. "dy" states that "y" is the function we are interested in, and "dx" means x is the value we are interested in.
When passing the derivative back to Lisa, we want to take the partial derivative of the error function e with respect to the value Lisa calculated L(x_0).
When taking partial derivatives, you can treat all the other input variables as constants; in this case y_0 will be treated as a constant.
Since L(x_0) is not alone inside the parentheses being squared, we have a function inside of a function. This means we are going to have to use the chain rule. Remember the chain rule states that given:
y = h(g(x))
The derivative is:
dy/dx = h'(g(x)) * g'(x)
Let's define h and g in terms of variables from our mean squared error equation. g(x) will be the inner equation for the chain rule, with Lisa's output L(x_0) playing the role of x, and looks like:
g(x) = y_0 - L(x_0)
h(g(x)) takes g(x) as input and squares it:
h(g(x)) = g(x)^2
Calculating the derivatives individually, we can use the difference rule for the inner part, g’, where y_0 is a constant and goes to zero, and -L(x_0) is linear in L(x_0), with a slope of -1.
g'(x) = 0 - 1 = -1
We then use the power rule for the outer equation h:
h'(g(x)) = 2 * g(x)
Leaving us with our equation for the full derivative:
h'(g(x)) * g'(x) = 2 * g(x) * (-1) = -2 * g(x)
Or in terms of our variables and the error function:
de/dL(x_0) = -2 * (y_0 - L(x_0))
This is the equation we have been looking for. The derivative of the error function with respect to its input, L(x_0). Let’s plug in the values and send them back to Lisa.
Remember Lisa's guess?
And the error that Edgar calculated?
We should have all the info we need from looking at these graphs and our equation for the derivative.
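The input x value:
x_0 = 2.0
Lisa's output (her guess from above):
L(x_0) = 1.5
The correct target value y:
y_0 = 3.0
Therefore our derivative:
de/dL(x_0) = -2 * (3.0 - 1.5) = -3.0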
Perfect, we have our derivative calculated for the first data point x_0. Since we have another datapoint, let’s do the same for x_1, and pass the average derivative back to Lisa.
Here Lisa guesses incorrectly again at y=0.5 for x=4.
This time Edgar calculates an error of 20.25:
e = (5.0 - 0.5)^2 = 20.25
Lisa is not very close this time, let’s take a look at the derivatives for this data point x_1. The input was x = 4.0:
x_1 = 4.0
Lisa's output was 0.5:
L(x_1) = -0.5 * 4.0 + 2.5 = 0.5
The correct target output was 5.0:
y_1 = 5.0
Plugged into our equation for the derivative of the error with respect to Lisa's output:
de/dL(x_1) = -2 * (y_1 - L(x_1)) = -2 * (5.0 - 0.5)
And we get:
de/dL(x_1) = -9.0
Notice that the derivative for this point is much larger than the first. Intuitively this is because the error is higher, so we will be farther up the parabola, and the slope will be greater.
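As a sanity check on the derivation, here is a short Python sketch that evaluates the derivative -2 * (y - L(x)) for both data points and compares it with a slope measured directly on the error parabola. The function names are hypothetical, and the numeric check is my own addition, not part of Edgar's job:

```python
def squared_error(output, target):
    return (target - output) ** 2


def error_derivative(output, target):
    """de/dL = -2 * (target - output), from the chain rule derivation."""
    return -2.0 * (target - output)


def numeric_derivative(output, target, step=1e-5):
    """Slope of the error parabola measured between two nearby outputs."""
    rise = squared_error(output + step, target) - squared_error(output - step, target)
    return rise / (2.0 * step)


# (Lisa's output, target) for the two data points in the example.
for output, target in [(1.5, 3.0), (0.5, 5.0)]:
    print(error_derivative(output, target), numeric_derivative(output, target))
```

The formula gives -3.0 for the first point and -9.0 for the second, and the measured slopes agree to several decimal places.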
Now we have considered both points in the dataset, and Edgar can write his message to Lisa:
Your line is a little off. The error for the first point was 2.25, and your error on the second point was 20.25. I have averaged these errors for you for a total average error of 11.25 over the dataset.
I have calculated some derivatives for you, I hope you find them useful. The derivative for the first point with respect to your output was -3.0, and the derivative for the second point with respect to your output was -9.0. Averaging over the dataset we have an average derivative of -6.0.
Hope this information finds you well.
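The numbers in Edgar's message can be reproduced with a short Python sketch. The variable names are my own, and this is only one plausible way to organise the calculation:

```python
def predict(x, weight, bias):
    return weight * x + bias


dataset = [(2.0, 3.0), (4.0, 5.0)]
weight, bias = -0.5, 2.5

errors = []
derivatives = []
for x, y in dataset:
    output = predict(x, weight, bias)
    errors.append((y - output) ** 2)         # squared error for this point
    derivatives.append(-2.0 * (y - output))  # de/dL for this point

average_error = sum(errors) / len(errors)
average_derivative = sum(derivatives) / len(derivatives)

print(errors, average_error)            # [2.25, 20.25] 11.25
print(derivatives, average_derivative)  # [-3.0, -9.0] -6.0
```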
Lisa receives the message, and is a little disappointed with her performance, but knows this is constructive feedback. With this information she can update her weights, and do better next time.
She wants to know how much influence each weight had on the error. So she needs to calculate the derivative of her function as well. This derivative is a rather easy one to calculate.
Remember Lisa’s function has a few variables:
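In terms similar to what we were using above, where w_0 is the weight (the slope) and w_1 is the bias, whose input is always 1:
L(x_i) = w_0 * x_i + w_1 * 1
She wants to figure out two separate derivatives: the derivative of L with respect to w_0 and the derivative of L with respect to w_1.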
When thinking of what derivatives you need to take, I find it useful to remember that you want to take the derivative of your output, with respect to the thing you want to change. Edgar took the derivative of e (his output) with respect to L(x), because he wanted L(x) to change. Lisa will take the derivative of L(x) (her output) with respect to her weights, because she wants to change her weights.
Remember, when taking partial derivatives, the rest of the variables in the L(x,b) equation are essentially constants, so we get:
dL(x_i)/dw_0 = x_i
dL(x_i)/dw_1 = 1
Hopefully this makes sense. Think back to the equation for a line you know and love: y=mx+b. In our case w is the thing that is changing, so you can picture it as y=mw+b. The slope of this line is just “m”, and the constants go to zero. Now look at our equation L(x_i) and try to figure out the slope with respect to w_0 and w_1.
More mathematically, if you use the sum rule, the linear rule, and the constant rule from above, you should get the same answers.
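If you want to double check these two answers without the rules, you can nudge one weight at a time and measure how much Lisa's output moves. This is a rough Python sketch. The function names and the nudge size of 1e-6 are assumptions for illustration:

```python
def lisa(x, w_0, w_1):
    """Lisa's linear function: weight times input, plus bias times 1."""
    return w_0 * x + w_1 * 1.0


def numeric_partial_w0(x, w_0, w_1, step=1e-6):
    """Change in output per unit change in w_0, holding w_1 fixed."""
    return (lisa(x, w_0 + step, w_1) - lisa(x, w_0 - step, w_1)) / (2.0 * step)


def numeric_partial_w1(x, w_0, w_1, step=1e-6):
    """Change in output per unit change in w_1, holding w_0 fixed."""
    return (lisa(x, w_0, w_1 + step) - lisa(x, w_0, w_1 - step)) / (2.0 * step)


for x in [2.0, 4.0]:
    # Expect roughly x for the weight and roughly 1 for the bias.
    print(x, numeric_partial_w0(x, -0.5, 2.5), numeric_partial_w1(x, -0.5, 2.5))
```

Nudging w_0 moves the output by about x times the nudge, and nudging w_1 moves it by about the nudge itself, matching the derivatives x_i and 1.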
Now that Lisa has her derivatives for her output given w_0 and w_1, she wants to know how to affect the error function so that it is lower. The problem is, she is not the one who computes the error function, Edgar is. In other words not only does she want to know how to change her weights to affect L(x), she wants to change L(x) so that the next module's output (the error) is lower.
In mathematical terms she wants to know de/dw_0 and de/dw_1.
Since e is a function that took L(x_i) as input, we cannot compute it directly. Luckily with the chain rule we know that when we have functions of functions we simply multiply the derivatives together, in this case:
de/dw_0 = de/dL(x_i) * dL(x_i)/dw_0
de/dw_1 = de/dL(x_i) * dL(x_i)/dw_1
The chain rule comes in really handy in this modular approach to learning. You don't have to know the function the next module computes, all you have to know is the value of the derivative for the next function, at the point you compute. I know all this notation looks complicated, but look at the fractions we have here. dL(xi) essentially cancels out, making the chain rule a little easier to remember.
When dealing with all these derivatives, my brain starts to melt, and all the notation starts to blur together. Sometimes it is difficult to follow. I personally find it helpful to explicitly say what the derivatives are in my head or out loud, then to think about what that means. Remember, the notation is "derivative of numerator, with respect to denominator". Think about what the variable in the numerator is, and what the variable in the denominator is. Here we want to know the rate of change of e, as we change w_0 and w_1. We want to know the derivative of e with respect to w_0 and the derivative of e with respect to w_1. In other words, if we change a weight, what is the effect on the error?
Try to trace back where we got all the numbers from and you will realize these complicated looking equations are really just additions, subtractions, and multiplications.
Again let’s consider the first point in our dataset x_0. We want to find the impact on the error w_0 contributed. We can't see the error function, but we have its derivative from Edgar.
We know from Edgar that:
de/dL(x_0) = -3.0
And from Lisa that dL(x_0)/dw_0 is just equal to the input x_0:
dL(x_0)/dw_0 = x_0 = 2.0
Which leaves us with:
de/dw_0 = -3.0 * 2.0 = -6.0
We can do the same for the second weight w_1 while looking at the first data point. Again, Edgar calculated:
de/dL(x_0) = -3.0
The input for the bias is always just 1 so:
dL(x_0)/dw_1 = 1
de/dw_1 = -3.0 * 1 = -3.0
It looks like the derivatives for this first point are telling us to change w_0 twice as much as w_1, in other words, change the slope a lot, and the bias a little.
Now we have to consider the second point in our dataset, x_1, and find how our weights played a role in its error.
Remember, Edgar already calculated the derivative of the error with respect to L(x_1) above:
de/dL(x_1) = -9.0
And the derivative of L(x_1) with respect to w_0 is easy, it is just equal to x_1 which is 4:
dL(x_1)/dw_0 = x_1 = 4.0
This gives us:
de/dw_0 = -9.0 * 4.0 = -36.0
Now time for the derivative of the error with respect to w_1, given our second data point x_1:
de/dw_1 = de/dL(x_1) * dL(x_1)/dw_1
Again, Edgar has done work for the first part already:
de/dL(x_1) = -9.0
But Lisa needs to do her part for her weight, and since the bias input is always 1, this is easy:
dL(x_1)/dw_1 = 1
All in all giving us the derivative of the error with respect to w_1 for our second data point.
de/dw_1 = -9.0 * 1 = -9.0
Now we have our derivatives for our weights over our entire dataset! Let’s average them out before using them.
average de/dw_0 = (-6.0 + -36.0) / 2 = -21.0
average de/dw_1 = (-3.0 + -9.0) / 2 = -6.0
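The per-point derivatives and their averages can be reproduced with a short Python sketch. The function name `weight_derivatives` is hypothetical. It simply multiplies Edgar's derivative by Lisa's derivative, as the chain rule says:

```python
dataset = [(2.0, 3.0), (4.0, 5.0)]
w_0, w_1 = -0.5, 2.5


def weight_derivatives(x, y, w_0, w_1):
    """Return (de/dw_0, de/dw_1) for a single data point."""
    output = w_0 * x + w_1       # Lisa's guess
    de_dL = -2.0 * (y - output)  # Edgar's derivative
    dL_dw0 = x                   # Lisa's derivative for the weight
    dL_dw1 = 1.0                 # Lisa's derivative for the bias
    return de_dL * dL_dw0, de_dL * dL_dw1


per_point = [weight_derivatives(x, y, w_0, w_1) for x, y in dataset]
avg_dw0 = sum(d[0] for d in per_point) / len(per_point)
avg_dw1 = sum(d[1] for d in per_point) / len(per_point)

print(per_point)         # [(-6.0, -3.0), (-36.0, -9.0)]
print(avg_dw0, avg_dw1)  # -21.0 -6.0
```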
Remember how we talked about the learning rate in the process of finding the minimum value of a function earlier? Now it is time to see it in action.
These derivatives tell us, over the entire dataset, the direction we want the weights to move. Technically the derivatives tell us the direction in which we should move to maximize the error. Since we want to minimize the error, we are going to subtract this direction from the weight. But how much are we going to subtract? The learning rate moderates this.
The learning rate is a value you multiply against your derivatives before subtracting them from your weights. Think back to the statement we made earlier about taking little steps in the direction of the derivative. If we simply subtracted the derivatives we have here from the weights, this would be too big of a step, even though it is a step in the right direction. This essentially would be a learning rate of 1.0. It would change the slope by 21 and the bias by 6. A slope of 21 is pretty steep, so I won't graph it here, but think about where this line would end up in your head, it would be way past our target value.
If you set the learning rate low, it has a better chance of finding the minimal error, even if it takes a lot of steps to get there. Let's start with a learning rate of 0.01.
gradient = [de/dw_0, de/dw_1] = [-21.0, -6.0]
This vector notation of derivatives is called "the gradient". This is the reason we call it gradient descent, because we end up with a gradient, and we are using it to descend down to a minimum value. We now subtract the gradient, scaled by the learning rate, from our weights to get a new set of weights that should perform slightly better next time.
w_0 = -0.5 - 0.01 * (-21.0) = -0.29
w_1 = 2.5 - 0.01 * (-6.0) = 2.56
Let’s see how our new weights do on our dataset. It looks like the new slope is a little less steep, and the new bias is a little higher on the y-axis.
After all that crazy math, the new green line is Lisa's second guess. It is definitely closer to the target value than last time! If we have Edgar calculate the average error it looks like it will be less than 11.25. In fact if you go through the exercise with this new line, you will see the average error is now 7.0002. This is a good improvement! Great teamwork Edgar and Lisa.
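If you would rather not redo the exercise by hand, this Python sketch applies one update with a learning rate of 0.01 to the averaged gradient of (-21.0, -6.0) worked out above, then recomputes the average error. The function name `average_error` is my own:

```python
dataset = [(2.0, 3.0), (4.0, 5.0)]


def average_error(w_0, w_1):
    """Mean of the squared errors over the whole dataset."""
    return sum((y - (w_0 * x + w_1)) ** 2 for x, y in dataset) / len(dataset)


w_0, w_1 = -0.5, 2.5
gradient = (-21.0, -6.0)  # averaged de/dw_0 and de/dw_1 from the worked example
learning_rate = 0.01

# Subtract the scaled gradient to step downhill.
new_w0 = w_0 - learning_rate * gradient[0]
new_w1 = w_1 - learning_rate * gradient[1]

print(average_error(w_0, w_1))        # error before the update
print(new_w0, new_w1)                 # roughly -0.29 and 2.56
print(average_error(new_w0, new_w1))  # roughly 7.0002
```

The printed values may show tiny floating point rounding in the last digits, but they match the hand calculation.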
In fact, if you run this process over and over you will see the error dropping each iteration through the dataset:
It would be tedious to do this by hand. It was exhausting working out the first iteration by hand, but we did it so we have a step by step algorithm we can follow for the next iterations. In fact, in the next post we will write a program to do it for us. As a sneak peek, I have plotted the output from this program. We can watch the system learn below:
The dark blue line is the real data we saw on our graph before, and the light blue line is the prediction from our linear model. It starts with the same weights, biases and learning rate as we did in our example, and if you watch the “epoch” counter at the top, you can see it gets closer and closer to an exact fit as we get closer to epoch 1000. An epoch is terminology for seeing all of the data exactly once. We chose a small learning rate, so you can see it slowly moving towards the correct line.
If you are curious about how this works with a higher learning rate, I made a few more reference gifs. This next one is with learning rate at 0.05:
You can see it quickly jumps up and over the line after epoch 0, then lines up pretty well by epoch 200. This is about 5x faster than the other learning rate which makes sense.
If you start to get ambitious, and think you can set the learning rate even higher, to make it learn even faster, BEWARE. Remember, we could miss the local minimum and oscillate out into infinity. To illustrate, here is the learning rate at 0.1 (note: I had to change the y-axis scale so we could see what happens):
In the first couple epochs we overshoot the target, then by epoch 30, we are so far off that we can no longer see our prediction on the graph. This is a classic example of having a learning rate that is too large.
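For readers who want to experiment before the next post, here is a minimal Python sketch of the whole loop. It is not the program that produced the plots above. The function names `train` and `average_error` are hypothetical, and it assumes the gradient is averaged over the dataset once per epoch, as in the worked example:

```python
def train(dataset, w_0, w_1, learning_rate, epochs):
    """Run gradient descent on a line w_0 * x + w_1 for a number of epochs."""
    n = len(dataset)
    for _ in range(epochs):
        grad_w0 = 0.0
        grad_w1 = 0.0
        for x, y in dataset:
            output = w_0 * x + w_1        # Lisa's guess
            de_dL = -2.0 * (y - output)   # Edgar's derivative
            grad_w0 += de_dL * x / n      # chain rule, averaged over the data
            grad_w1 += de_dL * 1.0 / n
        w_0 -= learning_rate * grad_w0    # step downhill
        w_1 -= learning_rate * grad_w1
    return w_0, w_1


def average_error(dataset, w_0, w_1):
    return sum((y - (w_0 * x + w_1)) ** 2 for x, y in dataset) / len(dataset)


dataset = [(2.0, 3.0), (4.0, 5.0)]
for learning_rate, epochs in [(0.01, 1000), (0.05, 200), (0.1, 30)]:
    w_0, w_1 = train(dataset, -0.5, 2.5, learning_rate, epochs)
    print(learning_rate, epochs, w_0, w_1, average_error(dataset, w_0, w_1))
```

With learning rates of 0.01 and 0.05 the error ends up far below the starting value of 11.25, while with 0.1 it ends up far above it, in line with the three gifs.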
I know this post had a lot of math, but you made it! We could now start with any dataset, and automatically find the best line that fits it (assuming you tinker with the learning rate). You might be thinking.. “What’s the big deal? All we did was fit a line. This isn’t AI, I’m leaving!”. Be patient young Padawan, you must master the fundamentals before moving onto larger problems. After we code up this solution, we will dive into larger neural networks, which build upon everything we did here.
I will walk through the C++ code in the next post. I will be following the exact steps we did above, so there is really nothing new. It would be a good exercise to try to implement the algorithm above in your favorite programming language before looking at my implementation (it really shouldn’t take that long).