Handwritten Digits Reader UI

8 Jan 2019 | CPOL | 5 min read
An object-oriented C# Neural Network, a trainer, and a Windows Forms user interface for recognizing handwritten digits.

Introduction

This article continues an earlier article on the basics of machine learning. That first part, which is worth reading first, explains mathematically how a Neural Network learns, using a simple six-neuron net in Excel.

See Machine Learning in Excel [1].

The source code here demonstrates how to train and use a Neural Network to interpret handwritten digits.
There won't be any math in this article. It is all about a C# implementation of basic machine learning.

Image 1

Background

The Artificial Neural Network (ANN) is trained using the MNIST handwritten digits dataset [2]. This is a classic problem in the field of data science, also known as the Hello World application of Machine Learning. There are already a few demo applications on this subject posted on CodeProject, but I thought my source could help someone. Complex problems might need more than one explanation, and I tried to make it as simple as possible.

Using the Code

Download, unzip and open the solution in Visual Studio 2017.

The Solution

Image 2

The solution contains five projects:

Project              Description                              Framework
DeepLearningConsole  Entry point for the training console     .NET Core 2.2
DeepLearning         Main library                             .NET Standard 2.0
Data                 Parser for data. Currently only MNIST.   .NET Standard 2.0
MnistTestUi          User interface for manual testing        .NET Framework 4.7.1
Tests                A few unit tests                         .NET Core 2.2

Why So Many Frameworks?

I found that .NET Core performs ~30% faster than .NET Framework, so I wanted it for the training console, but I also wanted to use .NET Framework for the Windows Forms application.
Unfortunately, you can't reference .NET Core libraries from .NET Framework. They are not compatible.
To solve that issue, I used .NET Standard for the common components.

.NET Standard isn't a framework. It is a formal specification of .NET APIs.
All .NET implementations should be compatible with it, not just .NET Core and .NET Framework but also Xamarin, Mono, Unity and Windows Mobile.
That is why it's a good choice of target framework for reusable components.

Parsing Mnist Data

These are the four files in the MNIST database:

  • t10k-images-idx3-ubyte - Test images (10000)
  • t10k-labels-idx1-ubyte - Labels for the 10000 test images
  • train-images-idx3-ubyte - Training images (60000)
  • train-labels-idx1-ubyte - Labels for the 60000 training images

The images from MNIST had to be converted from byte arrays to arrays of doubles ranging from 0 to 1.
Each file also starts with a few header fields, which have to be read past before the actual data begins.

C#
private static List<Sample> LoadMnistImages(string imgFileName, string idxFileName, int imgCount)
{
    var samples = new Sample[imgCount];
    var targets = GetTargets();
    var byte4 = new byte[4];

    using (var imageReader = File.OpenRead(imgFileName))
    using (var labelReader = File.OpenRead(idxFileName))
    {
        //Image file header: four big-endian 32-bit integers.
        imageReader.Read(byte4, 0, 4); //magic number
        imageReader.Read(byte4, 0, 4); //image count
        Array.Reverse(byte4);          //big-endian to little-endian
        //var imgCount = BitConverter.ToInt32(byte4, 0);

        imageReader.Read(byte4, 0, 4); //rows (28)
        imageReader.Read(byte4, 0, 4); //columns (28)

        //Label file header: two big-endian 32-bit integers.
        labelReader.Read(byte4, 0, 4); //magic number
        labelReader.Read(byte4, 0, 4); //label count

        var buffer = new byte[784];
        for (int i = 0; i < imgCount; i++)
        {
            samples[i] = new Sample(); //required if Sample is a class, harmless if it is a struct
            samples[i].Data = new double[784];
            imageReader.Read(buffer, 0, 784);
            for (int b = 0; b < buffer.Length; b++)
                samples[i].Data[b] = buffer[b] / 256d; //scale 0..255 down to 0..1

            samples[i].Label = labelReader.ReadByte();
            samples[i].Targets = targets[samples[i].Label];
        }
    }
    return samples.ToList();
}
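
The header fields are stored as big-endian 32-bit integers, which is why the parser above reverses the bytes before it could use BitConverter. As a minimal sketch of reading one such field, here is a small helper. The class and method names (IdxHeader, ReadBigEndianInt32) are my own for illustration and are not part of the downloadable source.

```csharp
using System;
using System.IO;

public static class IdxHeader
{
    // Reads one 32-bit big-endian integer, the byte order used by the Mnist header fields.
    public static int ReadBigEndianInt32(Stream stream)
    {
        var bytes = new byte[4];
        int read = 0;
        while (read < 4)
        {
            // Stream.Read may return fewer bytes than requested, so loop until all four arrive.
            int n = stream.Read(bytes, read, 4 - read);
            if (n == 0)
                throw new EndOfStreamException("Header is shorter than expected.");
            read += n;
        }
        // Most significant byte first.
        return (bytes[0] << 24) | (bytes[1] << 16) | (bytes[2] << 8) | bytes[3];
    }
}
```

Calling it twice on an image file stream should give the magic number (2051 for image files) followed by the image count, which could replace the hard-coded imgCount parameter.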

The parsing process produces two lists, one of training samples and one of testing samples.
A sample consists of the image pixel array and a target array of length 10, which tells which digit the image shows.
The digit zero is the array: 1,0,0,0,0,0,0,0,0,0.
Digit five is: 0,0,0,0,0,1,0,0,0,0 (a one in the sixth position, index 5) and so on.
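
The GetTargets method used by the parser is not listed in this article. A minimal sketch of how such target arrays could be built is shown here. The names OneHot and BuildTargets are hypothetical, and the real GetTargets may look different.

```csharp
using System;

public static class OneHot
{
    // Builds one target array per digit: targets[d][d] is 1 and every other entry is 0.
    public static double[][] BuildTargets(int classCount = 10)
    {
        var targets = new double[classCount][];
        for (int digit = 0; digit < classCount; digit++)
        {
            targets[digit] = new double[classCount]; // all zeros by default
            targets[digit][digit] = 1d;
        }
        return targets;
    }
}
```

With this, string.Join(",", OneHot.BuildTargets()[5]) gives 0,0,0,0,0,1,0,0,0,0.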

Instantiating and Training the Neural Network

To instantiate a new ANN, you need to provide its topology: the number of layers and the number of neurons in each layer.
The input layer must have 784 neurons for MNIST images (28x28 pixels). The output layer must have 10, one per digit.
The Trainer class trains a Neural Network using TrainData with the specified learn rate.

C#
var neuralNetwork = new NeuralNetwork(rndSeed: 0, sizes: new[] { 784, 200, 10 });
neuralNetwork.LearnRate = 0.3;
var trainer = new Trainer(neuralNetwork, Mnist.Data);
trainer.Train();

Next, each training sample is fed to the network so it can learn.
I found that 200 neurons in the hidden layer made it possible to train the ANN to 98.5% accuracy, which seemed good enough.
With 400 neurons, the accuracy maxed out at 98.8%, but training took twice as long.

Mnist Trainer

The trainer repeatedly trains the ANN, letting it see one training sample at a time.
One pass over all 60000 training images is called an epoch.

The training images are shuffled before each epoch.
After each epoch, the ANN is tested against the test samples and the result is logged to a CSV file.
Whenever the result is the best so far, the ANN is serialized and saved to a file.

C#
public void Train(int epochs = 100)
{            
    var rnd = new Random(0);
    var name = $"Sigmoid LR{NeuralNetwork.LearnRate} HL{NeuralNetwork.Layers[1].Count}";
    var csvFile = $"{name}.csv";
    var bestResult = 0d;
    for (int epoch = 1; epoch <= epochs; epoch++) //<= so that exactly 'epochs' epochs are run
    {
        Shuffle(TrainData.TrainSamples, rnd);
        TrainEpoch();                
        var result = Test();
        Log($"Epoch {epoch} {result.ToString("P")}");
        File.AppendAllText(csvFile, $"{epoch};{result};{NeuralNetwork.TotalError}\r\n");
        if (result > bestResult)
        {
            NeuralNetwork.Save($"{name}.bin");
            Log($"Saved {name}.bin");
            bestResult = result;
        }
    }
}
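
The Shuffle method called by the trainer is not listed here. One common way to do it is the Fisher-Yates shuffle, sketched below under the assumption that the samples live in a list. The class name ListShuffler is made up for this example, and the real Shuffle in the source may differ.

```csharp
using System;
using System.Collections.Generic;

public static class ListShuffler
{
    // Fisher-Yates shuffle: walks the list backwards and swaps each item
    // with a randomly chosen item at or before its own position.
    public static void Shuffle<T>(IList<T> items, Random rnd)
    {
        for (int i = items.Count - 1; i > 0; i--)
        {
            int j = rnd.Next(i + 1); // 0 <= j <= i
            T tmp = items[i];
            items[i] = items[j];
            items[j] = tmp;
        }
    }
}
```

Passing a Random created with a fixed seed, as the trainer does with new Random(0), makes the order of the training samples repeatable between runs.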

The theory behind the following two sections is described in a previously posted article.

See Machine Learning in Excel [1].

Forward Pass

The forward pass calculates each neuron's value by summing the values of all neurons in the previous layer, each multiplied by the weight of its connection.
The sum is then passed through an activation function. The result, or output, can then be read from the last layer, a.k.a. the output layer.

C#
private void Compute(Sample sample, bool train)
{
    for (int i = 0; i < sample.Data.Length; i++)
        Layers[0][i].Value = sample.Data[i];

    for (int l = 0; l < Layers.Length - 1; l++)
    {
        for (int n = 0; n < Layers[l].Count; n++)
        {
            var neuron = Layers[l][n];
            foreach (var weight in neuron.Weights)
                weight.ConnectedNeuron.Value += weight.Value * neuron.Value;
        }

        var neuronCount = Layers[l + 1].Count;
        if (l + 1 < Layers.Length - 1)
            neuronCount--; //skipping bias

        for (int n = 0; n < neuronCount; n++)
        {
            var neuron = Layers[l + 1][n];
            neuron.Value = LeakyReLU(neuron.Value / Layers[l].Count);
        }
    }
}
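
The LeakyReLU function called in the forward pass is not listed in the article. Below is a minimal sketch of it together with its derivative. The slope of 1/20 for negative input is taken from the factor that appears in the back propagation code; the class and member names are my own assumptions.

```csharp
using System;

public static class Activation
{
    // Slope for negative inputs. 1/20 matches the factor used in the back propagation code.
    public const double NegativeSlope = 1 / 20d;

    // Passes positive values through unchanged and scales negative values down.
    public static double LeakyReLU(double x)
    {
        return x > 0 ? x : x * NegativeSlope;
    }

    // Derivative of LeakyReLU: 1 for positive input, otherwise the small slope.
    public static double LeakyReLUDerivative(double x)
    {
        return x > 0 ? 1d : NegativeSlope;
    }
}
```

Since LeakyReLU keeps the sign of its input, the derivative can be chosen by looking at the neuron's output value, which is what the check neuron.Value > 0 does in the back propagation.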

Back Propagation

This algorithm adjusts all the weights between neurons. It makes the network learn and gradually improve its performance.

C#
private void ComputeNextWeights(double[] targets)
{
    var output = OutputLayer;
    for (int t = 0; t < output.Count; t++)
        output[t].Target = targets[t];

    //Output Layer
    foreach (var neuron in output)
    {
        neuron.Error = Math.Pow(neuron.Target - neuron.Value, 2) / 2;
        neuron.Delta = (neuron.Value - neuron.Target) * (neuron.Value > 0 ? 1 : 1 / 20d);
    }
    this.TotalError = output.Sum(n => n.Error);

    foreach (var neuron in Layers[1])
    {
        foreach (var weight in neuron.Weights)
            weight.Delta = neuron.Value * weight.ConnectedNeuron.Delta;
    }
    
    //Hidden Layer
    Parallel.ForEach(Layers[0], GetParallelOptions(), (neuron) => {

        foreach (var weight in neuron.Weights)
        {
            weight.Delta = 0; //reset so the sum below does not carry over from the previous sample
            foreach (var connectedWeight in weight.ConnectedNeuron.Weights)
                weight.Delta += connectedWeight.Value * connectedWeight.ConnectedNeuron.Delta;
            var cv = weight.ConnectedNeuron.Value;
            weight.Delta *= (cv > 0 ? 1 : 1 / 20d); //LeakyReLU derivative
            weight.Delta *= neuron.Value;
        }

    });

    //All deltas are done. Now calculate new weights.
    for (int l = 0; l < Layers.Length - 1; l++)
    {
        var layer = Layers[l];
        foreach (var neuron in layer)
            foreach (var weight in neuron.Weights)
                weight.Value -= (weight.Delta * this.LearnRate);
    }
}

Mnist Test UI

The Test UI is used for testing your own handwriting. It has two panels. The small panel interprets a single drawn digit, and in the larger panel at the bottom you can draw a number.

Image Preprocessing

The Mnist database homepage states that:

"The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field."

The steps below show how to do the same thing using bitmaps and Windows Forms graphics.

Image 3

First, find the smallest square around the drawn digit.

C#
public Rectangle DrawnSquare()
{
    var fromX = int.MaxValue;
    var toX = int.MinValue;
    var fromY = int.MaxValue;
    var toY = int.MinValue;
    var empty = true;
    for (int y = 0; y < Bitmap.Height; y++)
    {
        for (int x = 0; x < Bitmap.Width; x++)
        {
            var pixel = Bitmap.GetPixel(x, y);
            if (pixel.A > 0)
            {
                empty = false;
                if (x < fromX)
                    fromX = x;
                if (x > toX)
                    toX = x;
                if (y < fromY)
                    fromY = y;
                if (y > toY)
                    toY = y;
            }
        }
    }
    if (empty)
        return Rectangle.Empty;
    var dx = toX - fromX + 1; //width in pixels, both edge columns included
    var dy = toY - fromY + 1; //height in pixels, both edge rows included
    var side = Math.Max(dx, dy);
    if (dy > dx)
        fromX -= (side - dx) / 2;
    else
        fromY -= (side - dy)/ 2;

    return new Rectangle(fromX, fromY, side, side);
}

Crop out the square and resize to a new bitmap of size 20x20.

C#
public DirectBitmap CropToSize(Rectangle drawnRect, int width, int height)
{
    var bmp = new DirectBitmap(width, height);
    bmp.Bitmap.SetResolution(Bitmap.HorizontalResolution, Bitmap.VerticalResolution);

    //Dispose the Graphics object as soon as the drawing is done.
    using (var gfx = Graphics.FromImage(bmp.Bitmap))
    {
        gfx.CompositingQuality = CompositingQuality.HighQuality;
        gfx.InterpolationMode = InterpolationMode.HighQualityBicubic;
        gfx.PixelOffsetMode = PixelOffsetMode.HighQuality;
        gfx.SmoothingMode = SmoothingMode.AntiAlias;
        var rect = new Rectangle(0, 0, width, height);
        gfx.DrawImage(Bitmap, rect, drawnRect, GraphicsUnit.Pixel);
    }
    return bmp;
}

And finally, draw the 20 by 20 image with its center of mass centered inside a 28x28 bitmap.

C#
public Point GetMassCenterOffset()
{
    var path = new List<Vector2>();
    for (int y = 0; y < Height; y++)
    {
        for (int x = 0; x < Width; x++)
        {
            var c = GetPixel(x, y);
            if (c.A > 0)
                path.Add(new Vector2(x, y));
        }
    }
    var centroid = path.Aggregate(Vector2.Zero, (current, point) => current + point) / path.Count();
    return new Point((int)centroid.X - Width / 2, (int)centroid.Y - Height / 2);
}

protected DirectBitmap PadAndCenterImage(DirectBitmap bitmap)
{
    var drawnRect = bitmap.DrawnSquare();
    if (drawnRect == Rectangle.Empty)
        return null;

    var bmp2020 = bitmap.CropToSize(drawnRect, 20, 20);

    //Make image larger and center on center of mass.
    //(28 - 20) / 2 = 4 is the padding that would center the 20x20 image without any offset.
    var off = bmp2020.GetMassCenterOffset();
    var bmp2828 = new DirectBitmap(28, 28);
    using (var gfx2828 = Graphics.FromImage(bmp2828.Bitmap))
        gfx2828.DrawImage(bmp2020.Bitmap, 4 - off.X, 4 - off.Y);

    bmp2020.Dispose();
    return bmp2828;
}
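
GetMassCenterOffset above gives every drawn pixel the same weight. The Mnist description talks about a center of mass, which could also be read as weighting each pixel by its intensity. Here is a sketch of that variant, working on a plain byte array instead of a DirectBitmap so that it stands on its own. The MassCenter class and the TryCompute method are hypothetical and not part of the source.

```csharp
using System;

public static class MassCenter
{
    // Computes the intensity-weighted centroid of a row-major grayscale image.
    // Returns false when every pixel is zero, since there is no mass to center.
    public static bool TryCompute(byte[] pixels, int width, int height,
                                  out double centerX, out double centerY)
    {
        double sum = 0, sumX = 0, sumY = 0;
        for (int y = 0; y < height; y++)
        {
            for (int x = 0; x < width; x++)
            {
                double weight = pixels[y * width + x];
                sum += weight;
                sumX += weight * x;
                sumY += weight * y;
            }
        }
        if (sum == 0)
        {
            centerX = 0;
            centerY = 0;
            return false;
        }
        centerX = sumX / sum;
        centerY = sumY / sum;
        return true;
    }
}
```

For a thin pen stroke, the two approaches should give nearly the same offset. The difference only shows when large parts of the digit are faint anti-aliased pixels.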

And then, just extract the bytes from the image and query the ANN with them.

C#
public byte[] ToByteArray()
{
    var bytes = new List<byte>();
    for (int y = 0; y < Bitmap.Height; y++)
    {
        for (int x = 0; x < Bitmap.Width; x++)
        {
            var color = Bitmap.GetPixel(x, y);
            var i = color.A;
            bytes.Add(i);
        }
    }
    return bytes.ToArray();
}
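
Before the bytes can be fed to the ANN, they need the same scaling as the training images, and afterwards the ten output values have to be turned into one digit. The sketch below shows both steps as standalone helpers. The names Prediction, ToNetworkInput and ArgMax are assumptions for illustration, and how the network itself is queried is left out because that API is not shown in this article.

```csharp
using System;

public static class Prediction
{
    // Scales bytes (0..255) to doubles the same way the training images were scaled.
    public static double[] ToNetworkInput(byte[] bytes)
    {
        var input = new double[bytes.Length];
        for (int i = 0; i < bytes.Length; i++)
            input[i] = bytes[i] / 256d;
        return input;
    }

    // Returns the index of the largest output value, which is the digit the network finds most likely.
    public static int ArgMax(double[] outputs)
    {
        if (outputs.Length == 0)
            throw new ArgumentException("At least one output value is required.", nameof(outputs));
        int best = 0;
        for (int i = 1; i < outputs.Length; i++)
        {
            if (outputs[i] > outputs[best])
                best = i;
        }
        return best;
    }
}
```

Using the same divisor as in training matters, because the network has only ever seen pixel values in that range.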

The UI also has a function to show MNIST images if you are curious about what they look like. But I won't go into every detail of the UI, because I feel we are getting off topic.

Finally

I hope you liked my text, and perhaps learned something you didn't already know. If you have any questions, comments or ideas, just drop them here below.

That's all for now, don't forget to vote. Cheers!

Links

  1. Machine Learning in Excel - Kristian Ekman
  2. Mnist handwritten digits dataset - Yann LeCun, Corinna Cortes, Christopher J.C. Burges

History

  • 7th January, 2019 - Version 1.0

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)