
ReInventing Neural Networks - Part 2

19 Feb 2018, CPOL, 4 min read
In Part 2, the Neural Network made in Part 1 is tested in an environment made in Unity so that we can see how well it performs.

The Full Series

  • Part 1: We create the whole NeuralNetwork class from scratch.
  • Part 2: We create an environment in Unity in order to test the neural network within that environment.
  • Part 3: We make a great improvement to the neural network already created by adding a new type of mutation to the code.

Introduction

A few days ago, I posted an article (Part 1) explaining how you can implement a neural network from scratch in C#. In that article, however, the neural network was only trained on an XOR function. As promised, this time we're going to train simple cars in Unity to drive! Here's what we're aiming for:

After I finished the video, I felt like it was some creepy 90s clip, but it does the job ...

Background

To follow along with this article, you'll need basic C# and Unity programming knowledge. You'll also need to have read my previous article (Part 1), where I first implemented the NeuralNetwork class.

Pre-Programming Resources

In case you're new to C#, you can always search the MSDN docs for the stuff you're not familiar with, but in case you want to look for something Unity-specific, you may want to search Unity's Scripting Reference or Unity's Manual instead.

Using the Code

First off, you gotta know all the classes that are going to be used in the project:

  • Car: The main script that controls the movement of the car object (driven either by a NeuralNetwork or by the user).
  • Wall: A simple script attached to every wall. When a car hits an object with this script on it, the script tells the car to die.
  • Checkpoint: A simple script that increases the fitness (score) of a car the first time the car passes through it.
  • EvolutionManager: Simply waits for all the cars to die, then creates a new generation from the best car.
  • CameraFollow: The script that moves the camera to look at the best car.

Here's how it's all going to work:

  • There is going to be a track with a series of checkpoints along its path.
  • Once a car hits a checkpoint, its fitness increases.
  • If a car hits a wall, it gets destroyed.
  • If all the cars are destroyed, a new generation is created from the best car in the last generation.

Now, we're going to go over each script and explain it in a little more detail.

NeuralNetwork

An entire article was devoted to that one ...

Car

First, we have to define a few variables:

C#
[SerializeField] bool UseUserInput = false;     // Defines whether the car
                                                // uses a NeuralNetwork or user input
[SerializeField] LayerMask SensorMask;          // Defines the layer of the walls ("Wall")
[SerializeField] float FitnessUnchangedDie = 5; // The number of seconds to wait
                                                // before checking if the fitness
                                                // didn't increase

public static NeuralNetwork NextNetwork = new NeuralNetwork
       (new uint[] { 6, 4, 3, 2 }, null);       // The network that will be given to
                                                // the next instantiated car.
                                                // Topology: 6 inputs (one per sensor ray),
                                                // hidden layers of 4 and 3 neurons,
                                                // 2 outputs (vertical and horizontal)

public string TheGuid { get; private set; }     // The Unique ID of the current car

public int Fitness { get; private set; }        // The fitness/score of the current car.
                                                // Represents the number of checkpoints
                                                // that this car hit.

public NeuralNetwork TheNetwork { get; private set; } // The NeuralNetwork of
                                                      // the current car

Rigidbody TheRigidbody;                         // The Rigidbody of the current car
LineRenderer TheLineRenderer;                   // The LineRenderer of the current car

Here's what happens whenever a new car is created (in Awake):

C#
private void Awake()
{
    TheGuid = Guid.NewGuid().ToString(); // Assigns a new Unique ID for the current car

    TheNetwork = NextNetwork;            // Sets the current network to the Next Network
    NextNetwork = new NeuralNetwork(NextNetwork.Topology, null); // Make sure the
       // Next Network is reassigned to avoid having another car use the same network

    TheRigidbody = GetComponent<Rigidbody>(); // Assign Rigidbody
    TheLineRenderer = GetComponent<LineRenderer>(); // Assign LineRenderer

    StartCoroutine(IsNotImproving());   // Start checking whether the score stays
                                        // the same for too long

    TheLineRenderer.positionCount = 17; // Make sure the line is long enough
}
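The static NextNetwork field works as a one-slot hand-off: whoever spawns a car puts a network into the slot, and Awake takes it and immediately refills the slot with a fresh network, so no two cars end up sharing one network object. Here's a minimal plain C# sketch of that pattern outside Unity; Brain and Agent are hypothetical stand-ins for NeuralNetwork and Car, not classes from the project.

```csharp
// Hypothetical stand-in for NeuralNetwork; only object identity matters here.
public class Brain { }

// Hypothetical stand-in for Car, using the same one-slot hand-off as Car.NextNetwork.
public class Agent
{
    // Whoever spawns an agent can put a brain here first.
    public static Brain NextBrain = new Brain();

    public Brain TheBrain { get; }

    public Agent()
    {
        TheBrain = NextBrain;    // Take the pending brain (like TheNetwork = NextNetwork)
        NextBrain = new Brain(); // Refill the slot so the next agent can't pick up the same one
    }
}
```

In Unity this works because Awake runs immediately inside Instantiate (for an active prefab), so the slot is consumed before the next car is spawned.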

This is the IsNotImproving coroutine, which kills the car if its fitness hasn't changed within FitnessUnchangedDie seconds:

C#
// Checks every few seconds if the car didn't make any improvement
IEnumerator IsNotImproving ()
{
    while(true)
    {
        int OldFitness = Fitness;                             // Save the initial fitness
        yield return new WaitForSeconds(FitnessUnchangedDie); // Wait for some time
        if (OldFitness == Fitness)              // Check if the fitness didn't change yet
            WallHit();                                        // Kill this car
    }
}
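Put differently, a car survives only as long as it picks up at least one new checkpoint in every FitnessUnchangedDie-second window, and the windows are back to back rather than a timer that restarts on each hit. Here's a plain C# sketch that replays a list of checkpoint hit times through the same logic. StallCheck, KillTime and FitnessAt are hypothetical names, time is simulated, and real coroutine timing also depends on the frame rate.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, Unity-free model of IsNotImproving: time is simulated,
// and the car's fitness at time t is the number of checkpoint hits so far.
public static class StallCheck
{
    static int FitnessAt(IReadOnlyList<double> hitTimes, double t)
    {
        int count = 0;
        foreach (double hit in hitTimes)
            if (hit <= t) count++;
        return count;
    }

    // Returns the time (in seconds) at which the coroutine would kill the car.
    public static double KillTime(IReadOnlyList<double> hitTimes, double interval)
    {
        if (interval <= 0)
            throw new ArgumentOutOfRangeException(nameof(interval));

        double t = 0;
        while (true)
        {
            int oldFitness = FitnessAt(hitTimes, t);  // int OldFitness = Fitness;
            t += interval;                            // yield return new WaitForSeconds(...);
            if (FitnessAt(hitTimes, t) == oldFitness) // if (OldFitness == Fitness)
                return t;                             //     WallHit();
        }
    }
}
```

For example, with hits at 1, 3 and 7 seconds and a 5-second interval, the checks at 5 and 10 seconds both see progress, so the car is only killed at the 15-second check, 8 seconds after its last checkpoint.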

This is the Move function that (wait for it...) "moves" the car:

C#
// The main function that moves the car.
public void Move (float v, float h)
{
    TheRigidbody.velocity = transform.right * v * 4;
    TheRigidbody.angularVelocity = transform.up * h * 3;
}
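So on every physics step, the car gets a speed of 4 units/s along transform.right (which is effectively its forward direction, since that's where the "Vertical" input pushes it) and a turn rate of 3 rad/s, both scaled by the input. At Unity's default 50 physics steps per second (Fixed Timestep of 0.02 s), that's 0.08 units and 0.06 rad per step at full input. The following is a rough top-down sketch of that motion in plain C#; CarState and Step are hypothetical, the sketch ignores collisions, and it doesn't try to match Unity's axis or rotation-direction conventions.

```csharp
using System;

// Hypothetical top-down model of what Move asks the physics engine to do.
public struct CarState
{
    public double X, Z, Heading; // Position on the ground plane and heading in radians

    public CarState Step(double v, double h, double dt)
    {
        double speed = v * 4;    // Matches transform.right * v * 4
        double turnRate = h * 3; // Matches transform.up * h * 3 (radians per second)
        double heading = Heading + turnRate * dt;
        return new CarState
        {
            X = X + Math.Cos(heading) * speed * dt,
            Z = Z + Math.Sin(heading) * speed * dt,
            Heading = heading
        };
    }
}
```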

Then comes the CastRay function, which casts a ray, visualises it, and returns the distance to whatever it hit. It'll be used later on:

C#
// Casts a ray and makes it visible through the line renderer
double CastRay (Vector3 RayDirection, Vector3 LineDirection, int LinePositionIndex)
{
    float Length = 4; // Maximum length of each ray

    RaycastHit Hit;
    if (Physics.Raycast(transform.position, RayDirection,
                        out Hit, Length, SensorMask)) // Cast a ray
    {
        float Dist = Vector3.Distance(Hit.point, transform.position); // Distance to the hit
        TheLineRenderer.SetPosition(LinePositionIndex,
                                    Dist * LineDirection);    // Draw the line up to the hit
                                                              // (positions are local, so the
                                                              // LineRenderer must not use
                                                              // world space)

        return Dist;                     // Return the distance
    }
    else
    {
        TheLineRenderer.SetPosition(LinePositionIndex,
                                    LineDirection * Length);  // Nothing hit: draw the line
                                                              // at the maximum length

        return Length;                   // Return the maximum distance
    }
}
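CastRay always returns a value between 0 and Length (4), and GetNeuralInputAxis, which comes next, divides each result by 4, so every sensor input ends up in [0, 1]: 0 means a wall right at the car, 1 means nothing within range. A small sketch of that mapping in plain C# follows; SensorMath and Normalize are hypothetical names, and null stands in for "the raycast hit nothing".

```csharp
using System;

// Hypothetical helper mirroring CastRay's return value plus the "/ 4" in GetNeuralInputAxis.
public static class SensorMath
{
    public const double MaxLength = 4; // Same value as Length in CastRay

    // hitDistance is null when the ray hits nothing within MaxLength.
    public static double Normalize(double? hitDistance)
    {
        double distance = hitDistance ?? MaxLength;            // No hit: report the maximum
        distance = Math.Min(Math.Max(distance, 0), MaxLength); // Keep it inside [0, MaxLength]
        return distance / MaxLength;                           // 0 = touching a wall, 1 = clear
    }
}
```

Dividing by the same constant that limits the ray, rather than a separate literal 4, keeps the two in sync if the sensor range ever changes.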

Next comes the GetNeuralInputAxis function, which does a lot of the work for us:

C#
// Casts all the rays, puts them through the NeuralNetwork and outputs the Move Axis
void GetNeuralInputAxis (out float Vertical, out float Horizontal)
{
    // Size the input array from this car's own network (same topology as NextNetwork)
    double[] NeuralInput = new double[TheNetwork.Topology[0]];

    // Cast along the car's four local axes. Each distance is divided by the
    // maximum ray length (4), so every input lies between 0 and 1.
    // Note that the car drives along transform.right (see Move).
    NeuralInput[0] = CastRay(transform.forward, Vector3.forward, 1) / 4;
    NeuralInput[1] = CastRay(-transform.forward, -Vector3.forward, 3) / 4;
    NeuralInput[2] = CastRay(transform.right, Vector3.right, 5) / 4;
    NeuralInput[3] = CastRay(-transform.right, -Vector3.right, 7) / 4;

    // Cast the two front diagonals (45 degrees either side of transform.right).
    // Scaling both components by sqrt(0.5) keeps each direction at unit length.
    float SqrtHalf = Mathf.Sqrt(0.5f);
    NeuralInput[4] = CastRay(transform.right * SqrtHalf + transform.forward * SqrtHalf,
                             Vector3.right * SqrtHalf + Vector3.forward * SqrtHalf, 9) / 4;
    NeuralInput[5] = CastRay(transform.right * SqrtHalf - transform.forward * SqrtHalf,
                             Vector3.right * SqrtHalf - Vector3.forward * SqrtHalf, 13) / 4;

    // Feed through the network
    double[] NeuralOutput = TheNetwork.FeedForward(NeuralInput);

    // Get Vertical Value
    if (NeuralOutput[0] <= 0.25f)
        Vertical = -1;
    else if (NeuralOutput[0] >= 0.75f)
        Vertical = 1;
    else
        Vertical = 0;

    // Get Horizontal Value
    if (NeuralOutput[1] <= 0.25f)
        Horizontal = -1;
    else if (NeuralOutput[1] >= 0.75f)
        Horizontal = 1;
    else
        Horizontal = 0;

    // If the output is just standing still, then move the car forward
    if (Vertical == 0 && Horizontal == 0)
        Vertical = 1;
}
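The second half of the function is a three-way threshold on each output: anything at or below 0.25 becomes -1, anything at or above 0.75 becomes 1, and everything in between becomes 0, with a "just drive forward" fallback when both axes come out as 0. The thresholds assume the network's outputs lie roughly in [0, 1]. Isolated from Unity, the mapping could look like this (OutputMapping, ToAxis and ToMoveAxes are hypothetical names):

```csharp
// Hypothetical, Unity-free version of the output-to-axis mapping in GetNeuralInputAxis.
public static class OutputMapping
{
    // Same thresholds as the original: low -> -1, high -> +1, middle -> 0.
    public static float ToAxis(double output)
    {
        if (output <= 0.25) return -1;
        if (output >= 0.75) return 1;
        return 0;
    }

    // Converts the two network outputs into (vertical, horizontal) and,
    // like the original, drives forward instead of standing still.
    public static (float Vertical, float Horizontal) ToMoveAxes(double[] outputs)
    {
        float vertical = ToAxis(outputs[0]);
        float horizontal = ToAxis(outputs[1]);
        if (vertical == 0 && horizontal == 0)
            vertical = 1;
        return (vertical, horizontal);
    }
}
```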

And this is what we do on every physics step (50 times per second with Unity's default Fixed Timestep of 0.02 s):

C#
private void FixedUpdate()
{
    if (UseUserInput) // If we're gonna use user input
        Move(Input.GetAxisRaw("Vertical"),
        Input.GetAxisRaw("Horizontal")); // Moves the car according to the input
    else // if we're gonna use a neural network
    {
        float Vertical;
        float Horizontal;

        GetNeuralInputAxis(out Vertical, out Horizontal);

        Move(Vertical, Horizontal); // Moves the car
    }
}

We also need a few functions that are called from other scripts (Checkpoint and Wall):

C#
// Called by a checkpoint the first time this car passes through it
public void CheckpointHit ()
{
    Fitness++; // Increase Fitness/Score
}

// Called by walls when hit by the car
public void WallHit()
{
    EvolutionManager.Singleton.CarDead(this, Fitness); // Tell the Evolution Manager
                                                       // that the car is dead
    gameObject.SetActive(false);                       // Make sure the car is inactive
}

Wall

The Wall script simply notifies any car that hits it:

C#
using UnityEngine;

public class Wall : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer 
                                                          // set on each car

    private void OnCollisionEnter(Collision collision)    // Once anything hits the wall
    {
        if (collision.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // Make sure 
                                                                               // it's a car
        {
            collision.transform.GetComponent<Car>().WallHit(); // If it is a car, 
                                                          // tell it that it just hit a wall
        }
    }
}

Checkpoint

The Checkpoint does almost the same thing as the Wall, but with a twist. Checkpoints use a trigger collider (with Is Trigger enabled) instead of a solid one, and each checkpoint makes sure it increases the fitness of each car only once. This is why each Car has a Unique ID: every Checkpoint simply keeps the Guids of all the cars it has already rewarded:

C#
using System.Collections.Generic;
using UnityEngine;

public class Checkpoint : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set 
                                                          // on each car

    List<string> AllGuids = new List<string>();           // The Guids of all the cars
                                                          // already rewarded by this checkpoint

    private void OnTriggerEnter(Collider other)           // Once anything goes through the checkpoint
    {
        if(other.gameObject.layer == LayerMask.NameToLayer(LayerHitName))  // If this object 
                                                                           // is a car
        {
            Car CarComponent = other.transform.parent.GetComponent<Car>(); // Get the component 
                                                                           // of the car
            string CarGuid = CarComponent.TheGuid;        // Get the Unique ID of the car

            if (!AllGuids.Contains(CarGuid))              // If we haven't rewarded
                                                          // this car before
            {
                AllGuids.Add(CarGuid);                    // Make sure we don't 
                                                          // increase it again
                CarComponent.CheckpointHit();             // Increase the car's fitness
            }
        }
    }
}
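The per-checkpoint set of Guids is what stops a car from farming fitness by driving back and forth through the same checkpoint: only distinct checkpoints count. Stripped of Unity, the bookkeeping boils down to a set insert. CheckpointSketch and TryReward are hypothetical names; the sketch uses a HashSet<Guid>, whose Add returns false when the item is already present, so the check and the insert become a single call.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, Unity-free version of the Checkpoint bookkeeping.
public class CheckpointSketch
{
    private readonly HashSet<Guid> rewarded = new HashSet<Guid>();

    // Returns true only the first time a given car passes through.
    public bool TryReward(Guid carId) => rewarded.Add(carId);
}
```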

EvolutionManager

You can't write a script without variables:

C#
public static EvolutionManager Singleton = null; // The current EvolutionManager Instance

[SerializeField] int CarCount = 100;             // The number of cars per generation
[SerializeField] GameObject CarPrefab;           // The Prefab of the car to be created
                                                 // for each instance
[SerializeField] Text GenerationNumberText;      // Some text to write the generation number

int GenerationCount = 0;                         // The current generation number

List<Car> Cars = new List<Car>();                // The list of cars currently alive

NeuralNetwork BestNeuralNetwork = null;          // The best NeuralNetwork
                                                 // currently available
int BestFitness = -1;                            // The Fitness of the
                                                 // best NeuralNetwork ever created

When the program starts:

C#
// On Start
private void Start()
{
    if (Singleton == null) // If no other instances were created
        Singleton = this;  // Make the only instance this one
    else
    {
        gameObject.SetActive(false); // There is another instance already in place.
        return;                      // Make this one inactive and stop here, so it
                                     // doesn't start a second set of generations
    }

    BestNeuralNetwork = new NeuralNetwork(Car.NextNetwork); // Set the BestNeuralNetwork
                                                            // to a random new network

    StartGeneration();
}

That's how a new generation is created:

C#
// Starts a whole new generation
void StartGeneration ()
{
    GenerationCount++; // Increment the generation count
    GenerationNumberText.text = "Generation: " + GenerationCount; // Update generation text

    for (int i = 0; i < CarCount; i++)
    {
        if (i == 0)
            Car.NextNetwork = BestNeuralNetwork; // Make sure one car uses the best network
        else
        {
            Car.NextNetwork = new NeuralNetwork(BestNeuralNetwork); // Clone the best
                                      // neural network and set it to be for the next car
            Car.NextNetwork.Mutate(); // Mutate it
        }

        Cars.Add(Instantiate(CarPrefab, transform.position,
                 Quaternion.identity, transform).GetComponent<Car>()); // Instantiate
                                      // a new car and add it to the list of cars
    }
}
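The loop above is a simple elitist scheme: the best network so far survives unchanged in car 0, and every other car gets a mutated clone of it, so the best network found so far is never lost. A Unity-free sketch of the same idea follows, assuming a genome is just a double[] of weights. GenerationSketch, Build and the one-weight Mutate below are hypothetical stand-ins, not the NeuralNetwork.Mutate from Part 1.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch: genomes are plain double[] weight vectors.
public static class GenerationSketch
{
    // Stand-in mutation: clone the parent, then nudge one random weight.
    static double[] Mutate(double[] parent, Random rng)
    {
        var child = (double[])parent.Clone(); // Clone first so the parent is never modified
        int i = rng.Next(child.Length);       // Pick one weight at random
        child[i] += rng.NextDouble() * 2 - 1; // Nudge it by a value in [-1, 1)
        return child;
    }

    public static List<double[]> Build(double[] best, int count, Random rng)
    {
        var generation = new List<double[]>(count);
        for (int i = 0; i < count; i++)
            generation.Add(i == 0 ? best : Mutate(best, rng)); // Car 0 keeps the best unchanged
        return generation;
    }
}
```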

Stuff called by the Cars:

C#
// Gets called by cars when they die
public void CarDead (Car DeadCar, int Fitness)
{
    Cars.Remove(DeadCar);        // Remove the car from the list
    Destroy(DeadCar.gameObject); // Destroy the dead car

    if (Fitness > BestFitness)   // If it is better than the current best car
    {
        BestNeuralNetwork = DeadCar.TheNetwork; // Make sure it becomes the best car
        BestFitness = Fitness;   // And also set the best fitness
    }

    if (Cars.Count <= 0)         // If there are no cars left
        StartGeneration();       // Create a new generation
}
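Two details of CarDead are easy to miss: BestFitness is never reset between generations, so BestNeuralNetwork is the best network ever seen rather than the best of the latest generation, and the comparison is strictly greater-than, so a tie keeps the older network. The sketch below reproduces just that bookkeeping in plain C#; BestTracker, Spawn and the string ids are hypothetical stand-ins for the manager and its Car objects.

```csharp
using System.Collections.Generic;

// Hypothetical, Unity-free version of the CarDead bookkeeping.
public class BestTracker
{
    private readonly HashSet<string> alive = new HashSet<string>();

    public string BestId { get; private set; }
    public int BestFitness { get; private set; } = -1; // Never reset: best ever, not per generation

    public void Spawn(string id) => alive.Add(id);

    // Returns true when the last living car has died and a new generation should start.
    public bool CarDead(string id, int fitness)
    {
        alive.Remove(id);
        if (fitness > BestFitness) // Strictly better only, so ties keep the older best
        {
            BestId = id;
            BestFitness = fitness;
        }
        return alive.Count == 0;
    }
}
```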

CameraFollow

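Just another simple all-in-one script that does the job. It looks through the children of its own transform, so it has to sit on the same GameObject as the EvolutionManager, which spawns every car as its child: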

C#
using UnityEngine;

public class CameraFollow : MonoBehaviour
{
    Vector3 SmoothPosVelocity;      // Velocity of Position Smoothing

    void FixedUpdate ()
    {
        Car BestCar = transform.GetChild(0).GetComponent<Car>(); // Start by assuming the
                                                                 // first car is the best one

        for (int i = 1; i < transform.childCount; i++)           // Loop over the other cars
        {
            Car CurrentCar = transform.GetChild(i).GetComponent<Car>(); // Get the component 
                                                                        // of the current car

            if (CurrentCar.Fitness > BestCar.Fitness) // If the current car is better than 
                                                      // the best car
            {
                BestCar = CurrentCar;                 // Then, the best car is the current car
            }
        }

        Transform BestCarCamPos = BestCar.transform.GetChild(0); // The target position 
                                                      // of the camera relative to the best car

        Camera.main.transform.position = Vector3.SmoothDamp
               (Camera.main.transform.position, BestCarCamPos.position, 
                ref SmoothPosVelocity, 0.7f);         // Smoothly set the position

        Camera.main.transform.rotation = Quaternion.Lerp(Camera.main.transform.rotation,
                                         Quaternion.LookRotation(BestCar.transform.position - 
                                         Camera.main.transform.position),
                                         0.1f);       // Smoothly set the rotation
    }
}
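Because the rotation is updated in FixedUpdate with a fixed factor of 0.1, it chases its target exponentially: each step closes 10% of the remaining gap, leaving 0.9^n of it after n steps. At 50 steps per second, that's roughly 0.5% left after one second (0.9^50 ≈ 0.0052). This holds exactly for a scalar lerp and only approximately for Quaternion.Lerp, which normalizes its result. A small check of that arithmetic, using a scalar stand-in for the camera angle (LerpConvergence and its methods are hypothetical names):

```csharp
using System;

public static class LerpConvergence
{
    // Fraction of the original gap left after `steps` calls of
    // value = Lerp(value, target, t): each call closes a fraction t of what's left.
    public static double RemainingFraction(double t, int steps) => Math.Pow(1 - t, steps);

    // The same thing done step by step with a scalar lerp.
    public static double SimulateGap(double startGap, double t, int steps)
    {
        double value = 0, target = startGap;
        for (int i = 0; i < steps; i++)
            value += (target - value) * t; // Scalar equivalent of Lerp(value, target, t)
        return (target - value) / startGap;
    }
}
```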

Points of Interest

Now that all the scripts have been explained in detail, you can sleep well knowing that the NeuralNetwork class implemented in Part 1 works well and wasn't a waste of time. It felt really good to watch those cars learn to drive through the track step by step. Also, since each car perceives its surroundings only through its own ray sensors, it can drive on tracks it wasn't trained on. Once I got that done, I felt like my binary children were learning how to drive! I tried as hard as I could to keep this implementation as simple as possible for people who don't want to dig deeply into Unity's stuff.

And... don't think for a second that we're done here. My current target is to implement 3 crossover operators to make evolution a bit more efficient and offer the developer more diversity. After that, backpropagation is the target.

Update on 20th February, 2018

Part 3 is up and running! It shows a substantial improvement over the system discussed in Parts 1 and 2. Tell me what you think!

History

  • 11th December 2017: Version 1.0: Main Implementation

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)