Implementing Gradient Descent in Go

11 Dec 2019 | CPOL | 4 min read
How to implement Gradient Descent in Go
This tutorial uses the Gradient Descent algorithm to implement Simple Linear Regression, this time in Go instead of Rust.

Some time ago, I showed you how to implement Simple Linear Regression using statistical methods. In this tutorial, we will implement the powerful Gradient descent algorithm to achieve the same goal. Instead of Rust, we are going to use Go this time.

You can find the entire source code for this project on https://github.com/phillikus/go_linear_regression.

Requirements

  • Basic knowledge of the Go programming language
  • Basic high school math

Understanding Gradient Descent

Consider a set of five points plotted on a 2D grid, which can be represented as the following list:

Bash
[(1, 1), (2, 3), (3, 2), (4, 3), (5, 5)]

Based on these points, the goal of the Gradient Descent algorithm is to find the line that comes closest to all these points.

In 2D-space, such a line can be described by the following function:

Bash
y(x) = m*x+n

where n is the so-called intercept (the value of y at x == 0) and m is the coefficient (the slope of the line).

The Gradient Descent algorithm finds the optimal line by iteratively updating m and n until they converge to the optimal solution.

Loss Function

To update m and n, we use the so-called loss function. Based on m and n, this loss function returns a value representing how close we are to the optimal line. Our goal is to minimize this value so it gets as close as possible to zero.

For linear regression, we will use the Mean Squared Error (MSE) function to get the loss at each step. The MSE is represented by the following formula:
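Written out in its standard form, with N as the number of data points (a symbol introduced here for illustration), the MSE is:

```latex
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - y'_i \right)^2
```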

Here, y_i represents the actual values and y'_i the values we predict at each step. We can substitute y' with y(x) from above:
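With N denoting the number of data points (notation assumed here), the substitution gives:

```latex
\mathrm{MSE}(m, n) = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - (m x_i + n) \right)^2
```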

To minimize this cost function, we have to calculate its partial derivatives with respect to m and n:
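Applying the chain rule to the MSE with y(x) = m*x + n substituted in, and writing N for the number of data points (notation assumed here), the two partial derivatives work out to:

```latex
\frac{\partial\,\mathrm{MSE}}{\partial m} = -\frac{2}{N} \sum_{i=1}^{N} x_i \left( y_i - (m x_i + n) \right)

\frac{\partial\,\mathrm{MSE}}{\partial n} = -\frac{2}{N} \sum_{i=1}^{N} \left( y_i - (m x_i + n) \right)
```

These are the expressions that the step function in the implementation accumulates as m_gradient and n_gradient.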

We will use both of these functions to update m and n in each iteration of our algorithm, getting closer to our optimal line each time. With the basics covered, let’s dive into the implementation.
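Before moving on, it can help to see the loss as code. The sketch below evaluates the MSE for a given m and n on the points from above. The helper name meanSquaredError is hypothetical and not part of the original project; it is only meant to illustrate the formula.

```go
package main

import "fmt"

// meanSquaredError is a hypothetical helper (not part of the original project).
// It returns the MSE of the line y = m*x + n over the given points.
func meanSquaredError(m, n float32, x_values, y_values []float32) float32 {
	var sum float32
	for i := range x_values {
		// Difference between the actual and the predicted y value.
		residual := y_values[i] - (m*x_values[i] + n)
		sum += residual * residual
	}
	return sum / float32(len(x_values))
}

func main() {
	x_values := []float32{1, 2, 3, 4, 5}
	y_values := []float32{1, 3, 2, 3, 5}

	// Loss of the starting line (m = 0, n = 0) versus a line that fits better.
	fmt.Printf("Loss at m=0, n=0: %f\n", meanSquaredError(0, 0, x_values, y_values))
	fmt.Printf("Loss at m=0.8, n=0.4: %f\n", meanSquaredError(0.8, 0.4, x_values, y_values))
}
```

The first loss should come out at roughly 9.6 and the second at roughly 0.48, which shows how a better-fitting line lowers the value we are trying to minimize.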

Implementation

To keep the algorithm separate from the code using it, create a new gradient_descent.go file and add a public Regression function with the following signature:

Go
func Regression(
    x_values []float32, 
    y_values []float32, 
    epochs int, 
    learning_rate float32) (float32, float32) {..}

The first two slices hold our data points, and epochs is the number of iterations the algorithm performs before returning a result. With a suitable learning rate, more epochs bring the result closer to the optimal line, at the cost of a longer run time. Feel free to play around with this number; for our purposes, a value between 100 and 200 should be sufficient.

The learning_rate controls how large a step our algorithm takes in each iteration. With a big learning_rate, it moves toward the optimal value faster, but we run the risk of overshooting and never reaching an optimal value.

Again, try to play around with this value to see what works for you. For me, a learning_rate of 0.05 worked really well.

Inside this Regression function, we’ll initialize our m and n variables with 0:

Go
var m_current float32 = 0
var n_current float32 = 0

To get as close as possible to the optimal value for both m and n, we run the loop epochs times and update both values at each step:

Go
for i := 0; i < epochs; i++ {
   n_current, m_current = step(n_current, m_current, x_values, y_values, learning_rate)
}

The step function is responsible for running our actual algorithm. It will iterate through all the values of our input data (x_values and y_values) and accumulate the gradients for m and n using the derivatives from above:

Go
func step(n_current float32, m_current float32, 
  x_values []float32, y_values []float32, learning_rate float32) (float32, float32) {
	var n_gradient float32 = 0
	var m_gradient float32 = 0

	var length = len(y_values)
	// The factor 2/N from the partial derivatives is the same for every point.
	var two_over_n = float32(2) / float32(length)

	for i := 0; i < length; i++ {
		// Difference between the actual and the predicted y value.
		var residual = y_values[i] - ((m_current * x_values[i]) + n_current)
		n_gradient += -two_over_n * residual
		m_gradient += -two_over_n * x_values[i] * residual
	}
	...
}

Right after the for loop, we can calculate the new (and better) m and n values by subtracting the gradients, scaled by the learning rate, from the current m and n values. The learning rate controls how far our values jump in each step:

Go
var new_n = n_current - (learning_rate * n_gradient)
var new_m = m_current - (learning_rate * m_gradient)

Now, we only have to return our new n and m values so the loop in the Regression function can continue optimizing them:

Go
return new_n, new_m
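Putting the pieces together, the complete algorithm might look like the sketch below. The final return statement of Regression is not shown in the snippets above; returning n_current, m_current (intercept first) is an assumption that mirrors the return order of step. The main function is included only so the sketch runs on its own; leave it out of gradient_descent.go if your project defines main in a separate file.

```go
package main

import "fmt"

// Regression runs gradient descent for the given number of epochs and
// returns the intercept (n) and the coefficient (m), in that order.
func Regression(x_values []float32, y_values []float32,
	epochs int, learning_rate float32) (float32, float32) {
	var m_current float32 = 0
	var n_current float32 = 0

	for i := 0; i < epochs; i++ {
		n_current, m_current = step(n_current, m_current, x_values, y_values, learning_rate)
	}

	// Assumption: same order as step, intercept first.
	return n_current, m_current
}

// step performs one gradient descent update over all data points.
func step(n_current float32, m_current float32,
	x_values []float32, y_values []float32, learning_rate float32) (float32, float32) {
	var n_gradient float32 = 0
	var m_gradient float32 = 0

	var length = len(y_values)
	var two_over_n = float32(2) / float32(length)

	for i := 0; i < length; i++ {
		// Difference between the actual and the predicted y value.
		var residual = y_values[i] - ((m_current * x_values[i]) + n_current)
		n_gradient += -two_over_n * residual
		m_gradient += -two_over_n * x_values[i] * residual
	}

	// Move against the gradient, scaled by the learning rate.
	var new_n = n_current - (learning_rate * n_gradient)
	var new_m = m_current - (learning_rate * m_gradient)
	return new_n, new_m
}

func main() {
	x_values := []float32{1, 2, 3, 4, 5}
	y_values := []float32{1, 3, 2, 3, 5}

	intercept, coefficient := Regression(x_values, y_values, 200, 0.05)
	fmt.Printf("Intercept: %f\n", intercept)
	fmt.Printf("Coefficient: %f\n", coefficient)
}
```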

Testing the Algorithm

To make sure everything works, add a main.go file to the project. We will use it to initialize some x and y values and then call our Regression function:

Go
package main

import "fmt"

func main() {
	x_values := []float32{1, 2, 3, 4, 5}
	y_values := []float32{1, 3, 2, 3, 5}

	intercept, coefficient := Regression(x_values, y_values, 200, 0.05)

	fmt.Printf("Intercept: %f\n", intercept)
	fmt.Printf("Coefficient: %f\n", coefficient)
}

Now, use go build from Terminal to build the project and run the generated binary file. You should see output similar to this:

Bash
Intercept: 0.394524
Coefficient: 0.801517

If we plug these values into the y = m*x + n function, we end up with y = 0.801517*x + 0.394524. Plotted on our grid, this line runs right through the middle of the points.

The line is close to every point of the grid, so we reached a pretty good solution. For comparison, the exact least-squares line for these points is y = 0.8*x + 0.4. Of course, it is impossible in this case to find a straight line that goes through all the points.
Try out different x and y values and see what results you can get.
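The learning rate is just as worthwhile to experiment with as the data. The sketch below is a compact restatement of the gradient descent loop from this article under a hypothetical name, fit, which is not part of the original project. It runs 200 epochs with three different learning rates on the same points.

```go
package main

import "fmt"

// fit is a hypothetical, compact version of the gradient descent loop from
// this article. It returns the intercept n and the coefficient m.
func fit(x_values, y_values []float32, epochs int, learning_rate float32) (float32, float32) {
	var n, m float32
	scale := float32(2) / float32(len(y_values))

	for e := 0; e < epochs; e++ {
		var n_gradient, m_gradient float32
		for i := range y_values {
			residual := y_values[i] - (m*x_values[i] + n)
			n_gradient += -scale * residual
			m_gradient += -scale * x_values[i] * residual
		}
		// Step against the gradient, scaled by the learning rate.
		n -= learning_rate * n_gradient
		m -= learning_rate * m_gradient
	}
	return n, m
}

func main() {
	x_values := []float32{1, 2, 3, 4, 5}
	y_values := []float32{1, 3, 2, 3, 5}

	for _, rate := range []float32{0.01, 0.05, 0.1} {
		n, m := fit(x_values, y_values, 200, rate)
		fmt.Printf("learning_rate=%.2f intercept=%g coefficient=%g\n", rate, n, m)
	}
}
```

For this particular data set, a rate of 0.01 should still be noticeably short of the intercept of 0.4 after 200 epochs, 0.05 should land close to the values shown above, and 0.1 should overshoot on every step so that the values grow very large instead of converging.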

Conclusion

That’s it! Feel free to play around with the source code and change some of the parameters. Let me know if you have any questions, hints or comments.

Thank you for reading!

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)