Introduction
This tip presents a very simple implementation of the k-Nearest Neighbor (k-NN) classifier used in machine learning.
Background
I am currently watching a YouTube playlist about machine learning by the Google Developers channel. This tip is my attempt to understand what is taught in the playlist. The playlist can be found here if you are interested.
The Algorithm
k-NN is a pattern recognition algorithm. To classify a new point, it identifies the k nearest neighbors of that point among the known records and makes a prediction based on the values of these neighbors. Here, the nearest neighbors are identified using Euclidean distance, although other distance measures can also be used.
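For reference, the Euclidean distance between two points p and q described by n features (in the flower example below, n = 2: petal length and petal width) is the square root of the sum of the squared per-feature differences, which is what CalculateDistance in the code computes:

```latex
d(\mathbf{p}, \mathbf{q}) = \sqrt{\sum_{i=1}^{n} \left(p_i - q_i\right)^2}
```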
For instance, let's say we have records containing the petal lengths and widths of different flowers. Based on this data, we can try to predict the type of a flower from its petal length and width. Here, length and width are referred to as features, and each feature can be thought of as a coordinate. When the user provides a length and a width, we find the k (user input) nearest points and then predict the flower based on the flowers these neighbors represent.
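The neighbor search described above can be sketched in a few lines of C#. This is a minimal in-memory illustration, not the article's MongoDB-backed class: the names NearestNeighbors, Euclidean and KNearestLabels are hypothetical, and each training point is assumed to be a double[] of feature values with a matching string label.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sketch: find the labels of the k training points closest to a query,
// using Euclidean distance over the feature coordinates.
public static class NearestNeighbors
{
    // Straight-line (Euclidean) distance between two feature vectors.
    public static double Euclidean(double[] a, double[] b)
    {
        if (a.Length != b.Length)
            throw new ArgumentException("Feature vectors must have the same length.");
        double sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.Sqrt(sum);
    }

    // Returns the labels of the k points nearest to the query, nearest first.
    // OrderBy is a stable sort, so points at equal distance keep their input order.
    public static string[] KNearestLabels(
        IReadOnlyList<double[]> points, IReadOnlyList<string> labels, double[] query, int k)
    {
        return Enumerable.Range(0, points.Count)
            .OrderBy(i => Euclidean(points[i], query))
            .Take(k)
            .Select(i => labels[i])
            .ToArray();
    }
}
```

For example, with made-up petal (length, width) points (1.4, 0.2) setosa, (4.7, 1.4) versicolor, (4.5, 1.5) versicolor, (6.0, 2.5) virginica and (5.1, 1.9) virginica, the query (4.8, 1.6) with k = 3 returns versicolor, versicolor, virginica.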
In the image below, if k were 3 and we had to predict which flower is represented by the point in black, we would find the three nearest points: A, B and C. Since A and B (the majority) are I. versicolor, we would predict the point in black to be I. versicolor.
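The vote itself can be sketched as follows. The class name MajorityVote is hypothetical, and the labels are assumed to be passed nearest first.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sketch: majority vote over the labels of the k nearest neighbors.
public static class MajorityVote
{
    // Labels are expected nearest first. GroupBy keeps groups in order of first
    // appearance and OrderByDescending is stable, so a tie in vote count goes
    // to the label whose first occurrence is nearest to the query.
    // Throws InvalidOperationException if no labels are given.
    public static string Predict(IEnumerable<string> nearestLabels)
    {
        return nearestLabels
            .GroupBy(label => label)
            .OrderByDescending(group => group.Count())
            .First()
            .Key;
    }
}
```

Passing the labels of A, B and C returns I. versicolor whatever C is; since the figure is not reproduced here, C is taken to be I. virginica purely for illustration. The article does not say how ties should be broken, so favoring the nearer neighbor is just one reasonable choice.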
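I was discussing this with a friend of mine, and he told me that it works better as the number of features increases (presumably as long as the added features are relevant: with many irrelevant features, distances between points become less informative). However, as the number of features grows, it also becomes harder for us to visualize the data. I was also told that this algorithm is often preferred for imputing missing values in real-world data.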
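As an illustration of the imputation use mentioned above, here is a minimal sketch of one common form of k-NN imputation: a missing value (marked as double.NaN) in one column is replaced by the average of that column over the k nearest complete rows, with distance measured on the other columns only. The class and method names are hypothetical, and averaging is only one option; for a categorical column a majority vote would be the usual choice.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sketch of k-NN imputation for a single missing numeric value.
public static class KnnImputer
{
    // completeRows: rows with no missing values.
    // row: the row to fill; row[missingColumn] is unknown (e.g. double.NaN).
    // Throws InvalidOperationException if completeRows is empty.
    public static double ImputeValue(
        IReadOnlyList<double[]> completeRows, double[] row, int missingColumn, int k)
    {
        // Euclidean distance that ignores the column being imputed.
        double Distance(double[] other)
        {
            double sum = 0;
            for (int i = 0; i < row.Length; i++)
            {
                if (i == missingColumn) continue; // the missing value cannot be compared
                double d = other[i] - row[i];
                sum += d * d;
            }
            return Math.Sqrt(sum);
        }

        // Average the missing column over the k nearest complete rows.
        return completeRows
            .OrderBy(r => Distance(r))
            .Take(k)
            .Average(r => r[missingColumn]);
    }
}
```

For example, with complete rows (1, 10), (2, 20), (3, 30) and (10, 100), the row (2.1, NaN) and k = 2, the two nearest rows are (2, 20) and (3, 30), so the imputed value is 25.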
The Classifier
The classifier is a simple class that implements the algorithm. It does the following:
- Save the feature and value names
- Save the training data - initial data load
- Perform the prediction
- Save the new prediction to the database if it was confirmed to be correct
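To make the four steps above concrete without a MongoDB server, here is a hypothetical in-memory stand-in. It mirrors the method names of the class below, but it is a sketch rather than the article's implementation: rows are held in a List, the exact-match check uses exact double equality (as a MongoDB equality filter does), and a tie in the vote goes to the value whose first vote comes from the nearer neighbor.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical in-memory stand-in for the MongoDB-backed classifier.
public class InMemoryKnn
{
    public string[] FeatureNames { get; private set; } = Array.Empty<string>();
    public string ValueName { get; private set; } = "";
    private readonly List<(double[] Features, double Value)> rows =
        new List<(double[] Features, double Value)>();

    // Step 1: remember the feature names and the name of the predicted value.
    public void SaveFeatureNames(string[] features, string value)
    {
        FeatureNames = (string[])features.Clone();
        ValueName = value;
    }

    // Step 2: replace the training data; the last column of each row is the value.
    public void SaveDataStore(double[,] data)
    {
        rows.Clear();
        int featureCount = data.GetLength(1) - 1;
        for (int r = 0; r < data.GetLength(0); r++)
        {
            var features = new double[featureCount];
            for (int c = 0; c < featureCount; c++) features[c] = data[r, c];
            rows.Add((features, data[r, featureCount]));
        }
    }

    // Step 3: return the stored value on an exact match, otherwise let the k nearest rows vote.
    public double Predict(double[] features, int k, out bool isExactMatch)
    {
        var exact = rows.FirstOrDefault(r => r.Features.SequenceEqual(features));
        isExactMatch = exact.Features != null;
        if (isExactMatch) return exact.Value;

        return rows
            .OrderBy(r => Math.Sqrt(r.Features.Zip(features, (a, b) => (a - b) * (a - b)).Sum()))
            .Take(k)
            .GroupBy(r => r.Value)
            .OrderByDescending(g => g.Count())
            .First()
            .Key;
    }

    // Step 4: add a record whose value was confirmed as correct to the training data.
    public void Confirm(double[] features, double value) =>
        rows.Add(((double[])features.Clone(), value));
}
```

After Confirm is called for an input, a later Predict for the same features is answered by the exact-match branch instead of the vote.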
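It works with MongoDB and uses two types of documents, each kept in its own collection (the collection names are read from the FeatureNameCollection and DataStoreCollection app settings). The first type holds the names of the features and of the value field, as in the following JSON:
{ "0" : "Sepal Length", "1" : "Sepal Width", "2" : "Petal Length", "3" : "Petal Width", "Value" : "Species" }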
There will be only one document of this kind in the database.
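The layout of the feature-names document can be reproduced without MongoDB. The sketch below assumes .NET Core 3.0 or later (for System.Text.Json) and uses a hypothetical FeatureNamesDocument class; it builds the same key/value pairs as SaveFeatureNames and serializes them to JSON.

```csharp
using System;
using System.Collections.Generic;
using System.Text.Json;

// Sketch: feature names go under their zero-based positions ("0", "1", ...)
// and the name of the predicted field goes under "Value".
public static class FeatureNamesDocument
{
    public static string ToJson(string[] features, string value)
    {
        var document = new Dictionary<string, string>();
        for (int index = 0; index < features.Length; index++)
        {
            document.Add(index.ToString(), features[index]);
        }
        document.Add("Value", value);
        return JsonSerializer.Serialize(document);
    }
}
```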
The second type of document holds the actual data stored in the database. Since I am fetching all the documents from the collection, MongoDB's _id field is left auto-generated.
{ "0" : 5.0, "1" : 3.6, "2" : 1.4, "3" : 0.3, "v" : 0.0 }
The feature values in these documents appear in the same order as in the features document, and "v" holds the value, that is, the species encoded as a number.
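The sketch below shows, under the same key conventions, how a row of training data maps to a data document and back. The DataDocument class is hypothetical and uses a plain dictionary in place of a BsonDocument; reading features by key ("0", "1", ...) rather than by position means extra fields such as _id do not change which value is read.

```csharp
using System;
using System.Collections.Generic;

// Sketch: convert one training row to the data-document layout and back.
public static class DataDocument
{
    // Features go under "0", "1", ... (same order as the feature-names
    // document) and the numeric value goes under "v".
    public static Dictionary<string, double> FromRow(double[] features, double value)
    {
        var document = new Dictionary<string, double>();
        for (int index = 0; index < features.Length; index++)
        {
            document.Add(index.ToString(), features[index]);
        }
        document.Add("v", value);
        return document;
    }

    // Reads the features back by key, ignoring any other fields.
    public static double[] FeaturesOf(IReadOnlyDictionary<string, double> document, int featureCount)
    {
        var features = new double[featureCount];
        for (int index = 0; index < featureCount; index++)
        {
            features[index] = document[index.ToString()];
        }
        return features;
    }
}
```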
The code is fairly trivial IMHO so I assume a code dump alone would suffice.
using MongoDB.Bson;
using MongoDB.Driver;
using MyLittleLearner.Helpers;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Linq;
namespace MyLittleLearner.Classifiers
{
    // A simple k-Nearest Neighbor classifier that keeps its data in MongoDB.
    internal class KNN
    {
        // Replaces the training data with the rows of data. Every column except
        // the last is a feature; the last column is the value (label).
        internal void SaveDataStore(double[,] data)
        {
            try
            {
                var collection = MongoHelper.GetCollection(
                    ConfigurationManager.AppSettings["DataStoreCollection"]);
                collection.DeleteMany(new BsonDocument());

                var rowCount = data.GetLength(0);
                var colCount = data.GetLength(1);
                var documents = new List<BsonDocument>(rowCount);
                for (var row = 0; row < rowCount; row++)
                {
                    // Features are stored under "0", "1", ... and the value under "v".
                    var document = new BsonDocument();
                    for (var column = 0; column < colCount - 1; column++)
                    {
                        document.Add(column.ToString(), data[row, column]);
                    }
                    document.Add("v", data[row, colCount - 1]);
                    documents.Add(document);
                }

                // A single InsertMany call instead of one insert per row
                // (InsertMany rejects an empty list, hence the check).
                if (documents.Count > 0)
                {
                    collection.InsertMany(documents);
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Database exception.", exception);
            }
        }
        // Replaces the single document that names the features ("0", "1", ...)
        // and the value field ("Value").
        internal void SaveFeatureNames(string[] features, string value)
        {
            try
            {
                var collection = MongoHelper.GetCollection(
                    ConfigurationManager.AppSettings["FeatureNameCollection"]);
                collection.DeleteMany(new BsonDocument());

                var document = new BsonDocument();
                for (int index = 0; index < features.Length; index++)
                {
                    document.Add(index.ToString(), features[index]);
                }
                document.Add("Value", value);
                collection.InsertOne(document);
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Database exception.", exception);
            }
        }
        // Predicts the value for features. A stored record with exactly the same
        // features is returned directly; otherwise the k nearest records vote.
        internal double Predict(double[] features, int k, out bool isExactMatch)
        {
            double returnValue = 0;
            try
            {
                isExactMatch = false;
                var collection = MongoHelper.GetCollection(
                    ConfigurationManager.AppSettings["DataStoreCollection"]);

                var filterBSON = new BsonDocument();
                for (int index = 0; index < features.Length; index++)
                {
                    filterBSON.Add(index.ToString(), features[index]);
                }
                var exactMatch = collection.Find(filterBSON).FirstOrDefault();
                if (exactMatch == null)
                {
                    // Read the collection once. Enumerating the query twice (Take, then Skip)
                    // runs it twice, and without a sort order the two passes may disagree.
                    var documents = collection.Find(new BsonDocument()).ToList();

                    // Sort by distance and keep the k nearest. A SortedDictionary keyed by
                    // distance throws when two records are equally far away, so it is not used.
                    // OrderBy is stable, so records at equal distance keep their order.
                    var nearest = documents
                        .Select(document => new
                        {
                            Distance = CalculateDistance(document, features),
                            Value = document["v"].ToDouble()
                        })
                        .OrderBy(neighbor => neighbor.Distance)
                        .Take(k);

                    // Majority vote; a tie goes to the value whose first vote is nearest.
                    returnValue = nearest
                        .GroupBy(neighbor => neighbor.Value)
                        .OrderByDescending(group => group.Count())
                        .Select(group => group.Key)
                        .First();
                }
                else
                {
                    isExactMatch = true;
                    returnValue = exactMatch["v"].ToDouble();
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Database exception.", exception);
            }
            return returnValue;
        }
        // Euclidean distance between a stored record and the input features.
        private double CalculateDistance(BsonDocument currentItem, double[] features)
        {
            double distance = 0;
            try
            {
                for (int index = 0; index < features.Length; index++)
                {
                    // Read each feature by its key ("0", "1", ...) rather than by
                    // position, so the auto-generated _id field cannot shift the index.
                    var difference = currentItem[index.ToString()].ToDouble() - features[index];
                    distance += difference * difference;
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Unable to calculate distance.", exception);
            }
            return Math.Sqrt(distance);
        }
        // Adds a record whose value has been confirmed as correct to the training data.
        internal void Confirm(double[] features, double value)
        {
            try
            {
                var collection = MongoHelper.GetCollection(
                    ConfigurationManager.AppSettings["DataStoreCollection"]);
                var document = new BsonDocument();
                for (int index = 0; index < features.Length; index++)
                {
                    document.Add(index.ToString(), features[index]);
                }
                document.Add("v", value);
                collection.InsertOne(document);
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Database exception.", exception);
            }
        }
    }
}
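For a collection too large to sort in memory, the k nearest records can instead be tracked in a single pass that keeps only k candidates at a time. This is a minimal sketch with a hypothetical StreamingNearest class, assuming each record has already been reduced to a (distance, value) pair. It keeps a sorted list rather than a SortedDictionary keyed by distance, because two records can be exactly the same distance from the input and a dictionary rejects duplicate keys.

```csharp
using System;
using System.Collections.Generic;

// Sketch: single-pass selection of the k nearest (distance, value) pairs
// using O(k) memory. The result is ordered nearest first.
public static class StreamingNearest
{
    public static List<(double Distance, double Value)> Select(
        IEnumerable<(double Distance, double Value)> candidates, int k)
    {
        if (k < 1) throw new ArgumentOutOfRangeException(nameof(k));

        var nearest = new List<(double Distance, double Value)>(k + 1);
        foreach (var candidate in candidates)
        {
            // Skip candidates that are no closer than the current k-th nearest.
            if (nearest.Count == k && candidate.Distance >= nearest[k - 1].Distance)
                continue;

            // Insert after any entries at the same distance, so earlier records win ties.
            int position = nearest.Count;
            while (position > 0 && nearest[position - 1].Distance > candidate.Distance)
                position--;
            nearest.Insert(position, candidate);

            // Drop the farthest entry once there are more than k.
            if (nearest.Count > k)
                nearest.RemoveAt(nearest.Count - 1);
        }
        return nearest;
    }
}
```

The selected pairs can then go through the same majority vote over their Value fields.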
You might have noticed the class MongoHelper. It is a simple class that provides one method to get a collection object from the database and another to get the first document matching a given filter. Here's all the code in MongoHelper.
using MongoDB.Bson;
using MongoDB.Driver;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Linq;
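namespace MyLittleLearner.Helpers
{
    class MongoHelper
    {
        // MongoClient is thread-safe and manages its own connection pool, so one
        // instance is created on first use and shared by all callers.
        private static readonly Lazy<MongoClient> client = new Lazy<MongoClient>(
            () => new MongoClient(ConfigurationManager.AppSettings["URL"]));

        // Returns the named collection from the configured database.
        internal static IMongoCollection<BsonDocument> GetCollection(string collectionName)
        {
            var database = client.Value.GetDatabase(ConfigurationManager.AppSettings["Database"]);
            return database.GetCollection<BsonDocument>(collectionName);
        }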
        // Returns the first document matching the filter (or any document when the
        // filter is null) as a dictionary, leaving out the _id field.
        internal static Dictionary<string, object> FetchFirstDocument(
            string collection, BsonDocument filterBSON)
        {
            var result = new Dictionary<string, object>();
            try
            {
                var document = GetCollection(collection)
                    .Find(filterBSON ?? new BsonDocument())
                    .First();
                foreach (var element in document.Elements)
                {
                    if (element.Name != "_id")
                        result.Add(element.Name, element.Value);
                }
            }
            catch (Exception exception)
            {
                File.AppendAllText("Logs.txt", DateTime.Now + Environment.NewLine
                    + exception + Environment.NewLine);
                throw new Exception("Database exception.", exception);
            }
            return result;
        }
    }
}
Points of Interest
I hope to update this code with other classifiers as and when I read and understand them.
History