
My Little Learner

18 Dec 2017 · CPOL · 2 min read
A simple implementation of a k-NN classifier

Introduction

This tip presents a very simple implementation of the k-Nearest Neighbors (k-NN) classifier used in machine learning.

Background

I am currently watching a YouTube playlist about machine learning on the Google Developers channel. This tip is my attempt to understand what the playlist teaches. The playlist can be found here if you are interested.

The Algorithm

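The k-NN algorithm is a pattern recognition algorithm. To make a prediction for a new input, it finds the k nearest neighbors of that input in the training data and bases the prediction on their values. For classification, the predicted value is the one that occurs most often among those neighbors. The nearest neighbors are identified using Euclidean distance.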
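For two points a and b with the same number of features, the Euclidean distance is the square root of the sum of the squared differences between matching features. The sketch below shows a minimal standalone version. It assumes C# 9 or later for top-level statements, and `EuclideanDistance` is a hypothetical helper name rather than part of the article's classifier.

```csharp
using System;

// Two hypothetical (petal length, petal width) points: sqrt(3^2 + 4^2) = 5
Console.WriteLine(EuclideanDistance(new[] { 1.0, 2.0 }, new[] { 4.0, 6.0 })); // 5

// Euclidean distance: sqrt((a[0] - b[0])^2 + (a[1] - b[1])^2 + ...)
static double EuclideanDistance(double[] a, double[] b)
{
    if (a.Length != b.Length)
        throw new ArgumentException("Both points must have the same number of features.");

    double sumOfSquares = 0;
    for (int i = 0; i < a.Length; i++)
    {
        double difference = a[i] - b[i];
        sumOfSquares += difference * difference;
    }
    return Math.Sqrt(sumOfSquares);
}
```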

For instance, let's say we have records containing the petal lengths and widths of different flowers. Based on this data, we can try to predict the type of a flower from its petal length and width. Here, length and width are referred to as features, and each record can be thought of as a point whose coordinates are its feature values. When the user provides a length and a width, we find the k nearest points (k is also supplied by the user) and predict the flower type that is most common among them.

In the image below, suppose k is 3 and we have to predict which flower the black point represents. We would find its three nearest points: A, B and C. Since A and B are I. versicolor (the majority), we would predict that the black point is also I. versicolor.
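The whole procedure (compute distances, keep the k nearest, take a majority vote) can be sketched in memory, without MongoDB. The measurements below are made up for illustration and are not the data behind the figure. `Classify` and `samples` are hypothetical names, and C# 9 or later is assumed for top-level statements.

```csharp
using System;
using System.Linq;

// Hypothetical (petal length, petal width) measurements in cm, each with its species
var samples = new (double[] Features, string Label)[]
{
    (new[] { 1.4, 0.2 }, "I. setosa"),
    (new[] { 1.5, 0.3 }, "I. setosa"),
    (new[] { 4.5, 1.5 }, "I. versicolor"),
    (new[] { 4.7, 1.4 }, "I. versicolor"),
    (new[] { 5.8, 2.2 }, "I. virginica"),
};

Console.WriteLine(Classify(samples, new[] { 4.4, 1.3 }, 3)); // I. versicolor

// Predicts the label of `query` by majority vote among its k nearest training points
static string Classify((double[] Features, string Label)[] training, double[] query, int k)
{
    if (k <= 0 || k > training.Length)
        throw new ArgumentOutOfRangeException(nameof(k));

    return training
        // Euclidean distance from the query to every training point
        .Select(row => (Distance: Math.Sqrt(row.Features.Zip(query, (a, b) => (a - b) * (a - b)).Sum()),
                        row.Label))
        // Keep the k nearest points
        .OrderBy(neighbor => neighbor.Distance)
        .Take(k)
        // Majority vote. GroupBy keeps first-occurrence order and OrderByDescending is
        // stable, so among tied labels the one with the nearest member wins.
        .GroupBy(neighbor => neighbor.Label)
        .OrderByDescending(group => group.Count())
        .First()
        .Key;
}
```

With k = 3, the two versicolor points and the virginica point are nearest, so the vote is 2 to 1 for I. versicolor.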

[Figure: Iris Distribution]

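While discussing this with a friend, I was told that k-NN tends to work better as the number of features increases, although more features also make the data harder to visualize. This does not hold indefinitely: in very high dimensions, distances between points become less informative, which is known as the curse of dimensionality. I was also told that this algorithm is commonly used in practice for imputation, that is, filling in missing values in a dataset.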
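One common form of k-NN imputation fills a missing feature with the mean of that feature over the k nearest complete records. Nearness is measured only on the features that are present. The sketch below is illustrative only. It assumes at most one missing value per record (marked as `double.NaN`), uses made-up data, and assumes C# 9 or later. `ImputeMissing` is a hypothetical name, not part of the article's code.

```csharp
using System;
using System.Linq;

// Hypothetical complete records: two features followed by the feature that may be missing
var complete = new[]
{
    new[] { 1.0, 1.0, 10.0 },
    new[] { 1.0, 2.0, 20.0 },
    new[] { 9.0, 9.0, 90.0 },
};

// The third value of this record is missing; estimate it from its 2 nearest neighbors
var filled = ImputeMissing(complete, new[] { 1.0, 1.5, double.NaN }, 2);
Console.WriteLine(filled[2]); // 15

// Returns a copy of `record` whose missing value (NaN) is replaced by the mean of that
// feature over the k nearest complete rows. Distance uses only the features present in
// `record`; this sketch assumes at most one missing value per record.
static double[] ImputeMissing(double[][] completeRows, double[] record, int k)
{
    var result = (double[])record.Clone();
    int missing = Array.FindIndex(record, value => double.IsNaN(value));
    if (missing < 0)
        return result; // nothing to fill in

    double DistanceTo(double[] row) => Math.Sqrt(
        Enumerable.Range(0, record.Length)
            .Where(i => i != missing)
            .Sum(i => (row[i] - record[i]) * (row[i] - record[i])));

    result[missing] = completeRows
        .OrderBy(row => DistanceTo(row))
        .Take(k)
        .Average(row => row[missing]);
    return result;
}
```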

The Classifier

The classifier is a simple class that implements the algorithm. It performs the following tasks:

  • Save the names of the features and of the value
  • Save the training data (the initial data load)
  • Perform the prediction
  • Save a new record to the database once the correct value for a prediction is confirmed

It works with MongoDB and uses two types of documents. The first type holds the names of the features and of the value field, as in the following JSON:

JavaScript
{ "0" : "Sepal Length", "1" : "Sepal Width", "2" : 
  "Petal Length", "3" : "Petal Width", "Value" : "Species" }

There is only one document of this kind, stored in the collection named by the FeatureNameCollection app setting.

The second type of document holds the actual training data, one document per record, in the collection named by the DataStoreCollection app setting. Since all the documents are fetched from the collection anyway, MongoDB's _id field is simply left auto-generated.

JavaScript
{ "0" : 5.0, "1" : 3.6, "2" : 1.4, "3" : 0.3, "v" : 0.0 }

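The feature values use the same keys ("0", "1", ...) and the same order as the feature-name document. The value to predict (the species, stored as a number) is kept under the key "v".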
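To make the correspondence between the two documents concrete, the sketch below mirrors both layouts with plain dictionaries instead of `BsonDocument`, so it runs without MongoDB. `NameOf` and the variable names are hypothetical, and C# 9 or later is assumed for top-level statements.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;

// The feature-name document and one data document, as plain dictionaries
var featureNames = new Dictionary<string, string>
{
    ["0"] = "Sepal Length", ["1"] = "Sepal Width",
    ["2"] = "Petal Length", ["3"] = "Petal Width", ["Value"] = "Species",
};
var record = new Dictionary<string, double>
{
    ["0"] = 5.0, ["1"] = 3.6, ["2"] = 1.4, ["3"] = 0.3, ["v"] = 0.0,
};

// Print each stored value next to its human-readable name
foreach (var (key, value) in record)
    Console.WriteLine($"{NameOf(featureNames, key)}: {value.ToString(CultureInfo.InvariantCulture)}");

// Keys "0".."3" line up between the two documents; the data document's "v"
// corresponds to the feature-name document's "Value"
static string NameOf(Dictionary<string, string> names, string fieldKey) =>
    names[fieldKey == "v" ? "Value" : fieldKey];
```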

The code is fairly trivial, IMHO, so a code dump alone should suffice.

C#
using MongoDB.Bson;
using MongoDB.Driver;
using MyLittleLearner.Helpers;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Linq;

namespace MyLittleLearner.Classifiers
{
    /// <summary>
    /// This class implements a k-NN classifier
    /// </summary>
    internal class KNN
    {
        /// <summary>
        /// This method saves the training data to the database
        /// </summary>
        /// <param name="data">2-d array with data to store. 
        /// First dimension will represent record number and second dimension will have 
        /// values for feature. Last element for each row is the value to seek</param>
        internal void SaveDataStore(double[,] data)
        {
            try
            {
                var collection = MongoHelper.GetCollection
                                 (ConfigurationManager.AppSettings["DataStoreCollection"]);

                collection.DeleteMany(new BsonDocument());

                var rowCount = data.GetLength(0);
                var colCount = data.GetLength(1);

                for (var row = 0; row < rowCount; row++)
                {
                    var valuesForBSON = new Dictionary<string, object>();

                    for (var column = 0; column < colCount - 1; column++)
                    {
                        valuesForBSON.Add(column.ToString(), data[row, column]);
                    }

                    valuesForBSON.Add("v", data[row, colCount - 1]);

                    var document = new BsonDocument(valuesForBSON);
                    collection.InsertOne(document);
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Database exception.", exception);
            }
        }

        /// <summary>
        /// This method saves the names of features and values
        /// </summary>
        /// <param name="features">Names of features</param>
        /// <param name="value">Name of value</param>
        internal void SaveFeatureNames(string[] features, string value)
        {
            try
            {
                var collection = MongoHelper.GetCollection
                                 (ConfigurationManager.AppSettings["FeatureNameCollection"]);
                collection.DeleteMany(new BsonDocument());

                var valuesForBSON = new Dictionary<string, object>();

                for (int index = 0; index < features.Length; index++)
                {
                    valuesForBSON.Add(index.ToString(), features[index]);
                }

                valuesForBSON.Add("Value", value);

                var document = new BsonDocument(valuesForBSON);
                collection.InsertOne(document);
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Database exception.", exception);
            }
        }

        /// <summary>
        /// This method predicts the value based on the feature values and k value
        /// </summary>
        /// <param name="features">Array of feature values in same order as initial input</param>
        /// <param name="k">Number of neighbors to compare with</param>
        /// <param name="isExactMatch">
        /// Flag to specify whether the exact features were located in database</param>
        /// <returns>Value prediction</returns>
        internal double Predict(double[] features, int k, out bool isExactMatch)
        {
            if (k <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(k), "k must be at least 1.");
            }

            double returnValue = 0;
            try
            {
                isExactMatch = false;

                // Get the collection from database
                var collection = MongoHelper.GetCollection
                                 (ConfigurationManager.AppSettings["DataStoreCollection"]);

                // Prepare filter BSON based on user input
                BsonDocument filterBSON = new BsonDocument();

                for (int index = 0; index < features.Length; index++)
                {
                    filterBSON.Add(index.ToString(), features[index]);
                }

                // Check if there is an exact match in database
                var exactMatch = collection.Find(filterBSON).FirstOrDefault();

                if (exactMatch == null)
                {
                    // No exact match found: load all documents once. Enumerating the
                    // Find result twice (Take, then Skip) would run the query twice.
                    var documents = collection.Find(new BsonDocument()).ToList();

                    if (documents.Count == 0)
                    {
                        throw new InvalidOperationException("The data store is empty.");
                    }

                    // Compute the distance to every record and keep the k nearest.
                    // A list is used rather than a SortedDictionary keyed by distance,
                    // which would throw when two records are equally far away.
                    var nearestNeighbors = documents
                        .Select(document => new
                        {
                            Distance = CalculateDistance(document, features),
                            Value = document["v"].AsDouble
                        })
                        .OrderBy(neighbor => neighbor.Distance)
                        .Take(k);

                    // Get the most common value among the k neighbors and return it.
                    // GroupBy keeps first-occurrence order and OrderByDescending is stable,
                    // so among tied values the one with the nearest member wins.
                    returnValue =
                        nearestNeighbors.GroupBy(neighbor => neighbor.Value)
                            .OrderByDescending(group => group.Count())
                            .Select(group => group.Key)
                            .First();
                }
                else
                {
                    // Notify caller that exact match was located in database
                    isExactMatch = true;
                    returnValue = exactMatch["v"].AsDouble;
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Database exception.", exception);
            }

            return returnValue;
        }

        /// <summary>
        /// This method calculates the Euclidean distance between the input and a stored record
        /// </summary>
        /// <param name="currentItem">Current item from database</param>
        /// <param name="features">Features from user input</param>
        /// <returns>Distance between the current item and user input</returns>
        private double CalculateDistance(BsonDocument currentItem, double[] features)
        {
            double sumOfSquares = 0;
            try
            {
                for (int index = 0; index < features.Length; index++)
                {
                    // Look the field up by name ("0", "1", ...) instead of by position,
                    // so the result does not depend on where _id sits in the document
                    var difference = currentItem[index.ToString()].AsDouble - features[index];
                    sumOfSquares += difference * difference;
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Unable to calculate distance.", exception);
            }
            return Math.Sqrt(sumOfSquares);
        }

        /// <summary>
        /// This method saves the feature values and prediction value to database
        /// </summary>
        /// <param name="features">Array of feature values in same order as initial input</param>
        /// <param name="value">Correct value for given features</param>
        internal void Confirm(double[] features, double value)
        {
            try
            {
                var collection = MongoHelper.GetCollection
                                 (ConfigurationManager.AppSettings["DataStoreCollection"]);

                var valuesForBSON = new Dictionary<string, object>();

                for (int index = 0; index < features.Length; index++)
                {
                    valuesForBSON.Add(index.ToString(), features[index]);
                }

                valuesForBSON.Add("v", value);

                var document = new BsonDocument(valuesForBSON);
                collection.InsertOne(document);

            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Database exception.", exception);
            }
        }
    }
}
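Sorting every distance costs O(n log n) for n stored records. When the data store is large, a common alternative is to stream the records and keep only the k nearest seen so far in a bounded max-heap. This also copes with records at equal distances. The sketch below assumes .NET 6 or later (for `PriorityQueue<TElement, TPriority>`) and C# 9 top-level statements. It works on in-memory tuples rather than `BsonDocument`, and `KNearestLabels` and `samples` are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical labelled points
var samples = new (double[] Features, string Label)[]
{
    (new[] { 0.0, 0.0 }, "a"),
    (new[] { 1.0, 0.0 }, "b"),
    (new[] { 5.0, 5.0 }, "c"),
    (new[] { 0.0, 2.0 }, "d"),
};

Console.WriteLine(string.Join(", ", KNearestLabels(samples, new[] { 0.0, 0.0 }, 2))); // a, b

// Returns the labels of the k points nearest to `query`, nearest first, holding at most
// k candidates at a time. PriorityQueue is a min-heap, so distances are negated to keep
// the farthest current candidate at the root.
static List<string> KNearestLabels(
    IEnumerable<(double[] Features, string Label)> rows, double[] query, int k)
{
    if (k <= 0)
        throw new ArgumentOutOfRangeException(nameof(k));

    var heap = new PriorityQueue<(double Distance, string Label), double>();
    foreach (var (features, label) in rows)
    {
        double distance = Math.Sqrt(features.Zip(query, (a, b) => (a - b) * (a - b)).Sum());
        if (heap.Count < k)
        {
            heap.Enqueue((distance, label), -distance);
        }
        else if (distance < heap.Peek().Distance)
        {
            // Closer than the farthest candidate: drop that one and keep this one
            heap.Dequeue();
            heap.Enqueue((distance, label), -distance);
        }
    }

    // Dequeue returns the farthest candidate first, so reverse to get nearest first
    var nearest = new List<string>();
    while (heap.Count > 0)
        nearest.Add(heap.Dequeue().Label);
    nearest.Reverse();
    return nearest;
}
```

The labels this returns could then be passed to a majority vote, as in `Predict`.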

You might have noticed the MongoHelper class. It is a simple class with two methods: one gets a collection object from the database, and the other gets the first document that matches a given filter. Here is the complete MongoHelper code:

C#
using MongoDB.Bson;
using MongoDB.Driver;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Linq;

namespace MyLittleLearner.Helpers
{
    /// <summary>
    /// This class provides the methods to perform operations on MongoDB
    /// </summary>
    class MongoHelper
    {
        // MongoClient is thread-safe and manages its own connection pool, so one shared
        // instance is reused instead of creating a new client on every call
        private static readonly MongoClient Client =
            new MongoClient(ConfigurationManager.AppSettings["URL"]);

        /// <summary>
        /// This method returns the MongoDB collection with given name
        /// </summary>
        /// <param name="collectionName">Collection name to search</param>
        /// <returns>MongoDB collection</returns>
        internal static IMongoCollection<BsonDocument> GetCollection(string collectionName)
        {
            var database = Client.GetDatabase(ConfigurationManager.AppSettings["Database"]);
            return database.GetCollection<BsonDocument>(collectionName);
        }

        /// <summary>
        /// This method gets the first document from the given collection based on given filter
        /// </summary>
        /// <param name="collection">Collection to search</param>
        /// <param name="filterBSON">Filter BSON; optional</param>
        /// <returns>Dictionary with first document returned. 
        /// Key: Field name Value: Field value</returns>
        internal static Dictionary<string, object> 
                 FetchFirstDocument(string collection, BsonDocument filterBSON)
        {

            Dictionary<string, object> result = new Dictionary<string, object>();
            try
            {
                var document = GetCollection(collection).Find
                               (filterBSON ?? new BsonDocument()).First();

                foreach (var element in document.Elements)
                {
                    if (element.Name != "_id")
                        result.Add(element.Name, element.Value);
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);
                File.AppendAllText("Logs.txt", exception.ToString());
                File.AppendAllText("Logs.txt", Environment.NewLine);

                throw new Exception("Database exception.", exception);
            }
            return result;
        }
    }
}

Points of Interest

I hope to update this code with other classifiers as and when I read and understand them.

History

  • Initial: k-NN classifier

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)