Gradient Descent Optimization


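In many machine learning algorithms, the goal is to find a function, or the parameters of a function, that lets us approximate or model unknown observed data. Those data could come from device measurements, web crawling, empirical observations, etc. Generally speaking, we have $N$ samples $\mathbf{y}_1, \dots, \mathbf{y}_N$ of an observation vector $\mathbf{y}$. For example, such a vector could be the coordinates of an object in space. We want to approximate these data with parameters $\mathbf{w}$ through a known function $f$ of the inputs $\mathbf{x}_i$, so that

$$\mathbf{y}_i \approx f(\mathbf{x}_i; \mathbf{w}), \quad i = 1, \dots, N.$$

One common way to do that is to minimize an error function, for example the root mean square (RMS) error

$$E_{\mathrm{RMS}}(\mathbf{w}) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left\| \mathbf{y}_i - f(\mathbf{x}_i; \mathbf{w}) \right\|^2}.$$

Since the square root is monotonically increasing, the same $\mathbf{w}$ also minimizes the mean squared error $E(\mathbf{w}) = \frac{1}{N} \sum_{i=1}^{N} \left\| \mathbf{y}_i - f(\mathbf{x}_i; \mathbf{w}) \right\|^2$, which is easier to differentiate. This is the error function $E$ used in the rest of this post.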
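Both error measures are straightforward to compute with NumPy. The sketch below assumes scalar observations and uses two hypothetical helper names, `rms_error` and `mse`. It shows that the RMS error is simply the square root of the mean squared error, so minimizing one also minimizes the other.

```python
import numpy as np

def rms_error(y, y_pred):
    """Root mean square error between observations y and predictions y_pred."""
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y - y_pred) ** 2)))

def mse(y, y_pred):
    """Mean squared error: the square of the RMS error, hence the same minimizer."""
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y - y_pred) ** 2))

y = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 5.0]
print(mse(y, y_pred))        # 4/3, since only the last residual (-2) is non-zero
print(rms_error(y, y_pred))  # sqrt(4/3)
```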

Linear Regression Example

To keep things concrete, we'll focus on the case of linear regression. For example, imagine you have a dataset $\{(x_i, y_i)\}_{i=1}^{N}$ of $N$ points in the plane and you want to fit a line through them.


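The regression function then has the form

$$f(x) = w_0 + w_1 x,$$

but more generally, for a linear regression, this function has the form

$$f(\mathbf{x}) = w_0 + \sum_{j=1}^{d} w_j x_j,$$

where $\mathbf{x} = (x_1, \dots, x_d)$ is the input vector, $y$ is the observed variable we want to regress on $\mathbf{x}$, and $w_0, \dots, w_d$ are the regression parameters. The model only needs to be linear in the parameters: the components of $\mathbf{x}$ may themselves be nonlinear functions of a raw input, for example $\mathbf{x} = (x, x^2)$ for a polynomial of order 2. To be more generic, we can also define the augmented vector $\tilde{\mathbf{x}} = (x_0, x_1, \dots, x_d)$ with $x_0 = 1$, so that the regression function now becomes

$$f(\mathbf{x}) = \sum_{j=0}^{d} w_j x_j = \mathbf{w}^\top \tilde{\mathbf{x}}, \quad \mathbf{w} = (w_0, \dots, w_d).$$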
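In code, the augmented vector $\tilde{\mathbf{x}}$ amounts to prepending a column of ones to the input matrix. Then $f(\mathbf{x}) = \mathbf{w}^\top \tilde{\mathbf{x}}$ becomes a single matrix-vector product over all samples. This is a minimal sketch: the helper names `augment` and `predict` are hypothetical. The last lines build the inputs $(x, x^2)$ of a second-order polynomial to show that the same linear machinery applies.

```python
import numpy as np

def augment(X):
    """Prepend the constant feature x_0 = 1 to every input vector (row of X)."""
    X = np.atleast_2d(np.asarray(X, dtype=float))
    return np.hstack([np.ones((X.shape[0], 1)), X])

def predict(w, X):
    """Linear regression function f(x) = w^T x_tilde, evaluated for every row of X."""
    return augment(X) @ np.asarray(w, dtype=float)

# Two 2-dimensional inputs and parameters (w_0, w_1, w_2) = (1, 2, 3)
X = [[1.0, 1.0],
     [2.0, 0.0]]
w = [1.0, 2.0, 3.0]
print(predict(w, X))  # [6. 5.]

# Polynomial of order 2: the inputs are (x, x^2), still linear in w
x = np.array([0.0, 1.0, 2.0])
X_poly = np.column_stack([x, x ** 2])
print(augment(X_poly))  # rows (1, x_i, x_i^2)
```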

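We are looking for the vector $\mathbf{w}$ that minimizes the error function $E(\mathbf{w})$. The error function is continuous and differentiable, so at the minimum its gradient vanishes:

$$\nabla E(\mathbf{w}) = \left( \frac{\partial E}{\partial w_0}, \dots, \frac{\partial E}{\partial w_d} \right) = \mathbf{0}.$$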
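For the mean squared error of the linear model, the gradient has the closed form $\nabla E(\mathbf{w}) = -\frac{2}{N} \tilde{X}^\top (\mathbf{y} - \tilde{X}\mathbf{w})$. Here $\tilde{X}$ is notation introduced for this sketch: the matrix whose rows are the augmented vectors $\tilde{\mathbf{x}}_i$. The sketch checks this formula against central finite differences on synthetic, noise-free data from an assumed parabola. It also shows that the gradient vanishes at the parameters that generated the data. All function names are hypothetical.

```python
import numpy as np

def mse_loss(w, X_tilde, y):
    """E(w) = (1/N) * sum_i (y_i - w^T x_tilde_i)^2."""
    residuals = y - X_tilde @ w
    return float(np.mean(residuals ** 2))

def mse_gradient(w, X_tilde, y):
    """Analytic gradient of E: -(2/N) * X_tilde^T (y - X_tilde w)."""
    n = X_tilde.shape[0]
    return -(2.0 / n) * (X_tilde.T @ (y - X_tilde @ w))

def numerical_gradient(f, w, eps=1e-6):
    """Central finite differences, used only to sanity-check the analytic formula."""
    grad = np.zeros_like(w)
    for j in range(w.size):
        step = np.zeros_like(w)
        step[j] = eps
        grad[j] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad

# Synthetic noise-free data from the assumed parabola y = 0.5 - x + 2 x^2
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=20)
X_tilde = np.column_stack([np.ones_like(x), x, x ** 2])
y = 0.5 - x + 2.0 * x ** 2

w = np.array([0.1, 0.2, 0.3])
analytic = mse_gradient(w, X_tilde, y)
numeric = numerical_gradient(lambda v: mse_loss(v, X_tilde, y), w)
print(np.allclose(analytic, numeric, atol=1e-6))  # True

# At the generating parameters the residuals are (numerically) zero, so is the gradient
w_star = np.array([0.5, -1.0, 2.0])
print(mse_gradient(w_star, X_tilde, y))  # approximately [0, 0, 0]
```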

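In general (for instance when $f$ is nonlinear in $\mathbf{w}$), this system of equations has no analytic solution. One technique is then to iteratively update the weights in the direction opposite to the gradient, i.e. along the direction of steepest descent:

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} - \eta \, \nabla E\big(\mathbf{w}^{(t)}\big),$$

where $\eta > 0$ is the learning rate. The learning rate has to be small enough not to overshoot the minimum or oscillate around it, but large enough for the algorithm to converge in a reasonable number of iterations. For the mean squared error of the linear model, the gradient is

$$\nabla E(\mathbf{w}) = -\frac{2}{N} \sum_{i=1}^{N} \left( y_i - \mathbf{w}^\top \tilde{\mathbf{x}}_i \right) \tilde{\mathbf{x}}_i.$$

In the special case of linear regression with a polynomial of order 2, $\tilde{\mathbf{x}}_i = (1, x_i, x_i^2)$ and the update takes the form

$$w_j^{(t+1)} = w_j^{(t)} + \frac{2\eta}{N} \sum_{i=1}^{N} \left( y_i - w_0^{(t)} - w_1^{(t)} x_i - w_2^{(t)} x_i^2 \right) x_i^{j}, \quad j = 0, 1, 2.$$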
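The trade-off on the learning rate can be seen on the simplest possible case. The sketch below runs gradient descent on the toy error $E(w) = w^2$, whose gradient is $2w$. This is an illustrative function chosen here, not the regression error. Each update multiplies $w$ by $1 - 2\eta$: a tiny $\eta$ barely moves $w$, a moderate $\eta$ converges quickly, an $\eta$ between 0.5 and 1 makes $w$ flip sign at every step while shrinking (oscillation), and $\eta > 1$ makes $|w|$ grow (divergence). The function name and the chosen values of `eta` are assumptions for illustration.

```python
def gradient_descent_1d(eta, w0=1.0, steps=20):
    """Minimize the toy error E(w) = w**2 (gradient 2*w) with a fixed learning rate eta."""
    w = w0
    for _ in range(steps):
        w = w - eta * 2.0 * w  # w_{t+1} = w_t - eta * dE/dw
    return w

# Too small, well chosen, oscillating, divergent
for eta in (0.01, 0.4, 0.9, 1.1):
    print(eta, gradient_descent_1d(eta))
```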

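In practice, we don't look for the exact $\mathbf{w}$ that satisfies $\nabla E(\mathbf{w}) = \mathbf{0}$. Instead, we keep iterating as long as the vector $\mathbf{w}^{(t)}$ changes significantly (for instance, while $\|\mathbf{w}^{(t+1)} - \mathbf{w}^{(t)}\|$ stays above a small tolerance), or we fix the number of iterations in advance. Such an algorithm takes only a few lines of Python to implement.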
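A minimal sketch of batch gradient descent for a polynomial fit follows. It combines both stopping rules from the paragraph above: the Euclidean norm of the parameter change below `tol`, or `max_iter` iterations. The function name `fit_polynomial_gd`, its default values, and the zero initialization are assumptions of this sketch, and the defaults are not tuned for arbitrary data. It uses `numpy.vander` with `increasing=True` to build the augmented vectors $(1, x_i, x_i^2, \dots)$.

```python
import numpy as np

def fit_polynomial_gd(x, y, order=2, eta=0.1, tol=1e-8, max_iter=100_000):
    """Fit y ~ w_0 + w_1 x + ... + w_order x^order by batch gradient descent on the MSE.

    Stops when the parameter vector changes by less than `tol` (Euclidean norm)
    between two iterations, or after `max_iter` iterations.
    Returns the parameters and the number of iterations performed.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Augmented inputs x_tilde_i = (1, x_i, x_i^2, ..., x_i^order), one per row
    X_tilde = np.vander(x, order + 1, increasing=True)
    n = x.size
    w = np.zeros(order + 1)
    for iteration in range(1, max_iter + 1):
        residuals = y - X_tilde @ w
        grad = -(2.0 / n) * (X_tilde.T @ residuals)  # gradient of the MSE
        w_new = w - eta * grad                       # step against the gradient
        if np.linalg.norm(w_new - w) < tol:          # w no longer changes significantly
            return w_new, iteration
        w = w_new
    return w, max_iter

# Noise-free samples of a known parabola, used only to check the fit
x = np.linspace(-1.0, 1.0, 21)
y = 0.5 - x + 2.0 * x ** 2
w, n_iter = fit_polynomial_gd(x, y)
print(w, n_iter)  # w should be close to [0.5, -1.0, 2.0]
```

Keeping the inputs in a small range such as $[-1, 1]$ keeps the columns $1, x, x^2$ on comparable scales. With widely different scales, a single learning rate that is stable for the steepest direction can make progress along the flattest direction very slow.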

