K-nearest neighbours classification algorithm: working, problems and solutions

Classification is one of the most important and basic tasks of an intelligent device. Machine learning offers many ways of teaching a device to identify the class of an object. The k-nearest neighbour algorithm is probably one of the simplest, yet effective, methods of classification. We shall see how this algorithm works, its main problems, and some solutions.

The original algorithm

The algorithm uses the training data to get an idea of the region in which each class is concentrated, and makes predictions accordingly.

A simple illustration of KNN

In the figure above, the training data set is divided into two classes: red and blue. The black point is a test data point whose class is to be determined. The key inputs this algorithm uses are the training data and the parameter k. The algorithm looks for the training data points that are closest to the test data point and stops searching after the kth nearest point. Then, among the points found, we count the number of points belonging to each class. The class with the most points is assigned to the point in question. In this example, k is 5, and the data point is classified as blue because among its 5 nearest points, 3 are blue and 2 are red.
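The counting procedure just described can be sketched in a few lines of Python with NumPy. This is a minimal sketch assuming Euclidean distance; the function name `knn_predict` and the toy points are hypothetical and do not reproduce the figure.

```python
import numpy as np
from collections import Counter


def knn_predict(X_train, y_train, x_test, k):
    """Classify x_test by a majority vote among its k nearest training points."""
    # Euclidean distance from the test point to every training point
    distances = np.linalg.norm(X_train - x_test, axis=1)
    # Indices of the k closest training points; the search stops at the kth one
    nearest = np.argsort(distances)[:k]
    # Count how many of those neighbours belong to each class
    votes = Counter(y_train[i] for i in nearest)
    # The class with the most votes is assigned to the test point
    return votes.most_common(1)[0][0]


# Hypothetical toy data: 3 blue points and 4 red points (2 of them far away)
X_train = np.array([[0.5, 0.0], [0.0, 0.6], [-0.7, 0.0],   # blue
                    [0.0, -0.8], [0.9, 0.9],               # red, nearby
                    [3.0, 3.0], [-3.0, 3.0]])              # red, far away
y_train = np.array(["blue"] * 3 + ["red"] * 4)
x_test = np.array([0.0, 0.0])

print(knn_predict(X_train, y_train, x_test, k=5))  # blue (3 blue vs 2 red)
print(knn_predict(X_train, y_train, x_test, k=7))  # red (3 blue vs 4 red)
```

With k = 5 the test point is labelled blue even though red has more points overall; raising k to 7 pulls in the far-away red points and flips the answer to red, a first hint at how much the result depends on k.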

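The following code shows, very simply, the neighbour search at the heart of KNN, using the NearestNeighbors class from Python's scikit-learn library. Note that NearestNeighbors only finds the neighbours; it does not vote on a class.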

>>> from sklearn.neighbors import NearestNeighbors
>>> import numpy as np
>>> # Six 2-D training points forming two groups
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> # Find the 2 nearest neighbours of each point, using a ball tree for the search
>>> nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)
>>> distances, indices = nbrs.kneighbors(X)
>>> indices    # each point is its own nearest neighbour (distance 0)
array([[0, 1],
       [1, 0],
       [2, 1],
       [3, 4],
       [4, 3],
       [5, 4]])
>>> distances
array([[0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.41421356],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.41421356]])

>>> # 0/1 matrix: row i marks the 2 nearest neighbours of point i
>>> nbrs.kneighbors_graph(X).toarray()
array([[1., 1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 0.],
       [0., 0., 0., 1., 1., 0.],
       [0., 0., 0., 0., 1., 1.]])
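To actually assign classes, scikit-learn provides KNeighborsClassifier, which performs the majority vote on top of the neighbour search. A short sketch, reusing the same six points with labels that are assumed for illustration (0 for the left group, 1 for the right group):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Same six points as above; the class labels are hypothetical
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([0, 0, 0, 1, 1, 1])

# k = 3 neighbours, plain (unweighted) majority vote
clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X, y)

# Two new points, one near each group
X_new = np.array([[-0.8, -1.5], [2.5, 1.5]])
print(clf.predict(X_new))        # [0 1]
print(clf.predict_proba(X_new))  # fraction of the 3 neighbours in each class
```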



No process is perfect. While the k-nearest neighbour algorithm was a revelation when it was formulated in 1951, it has its disadvantages. To begin with, k is not known in advance and has to be tuned, typically by trying several values and checking the accuracy on held-out data. Moreover, the chosen k can produce ties in the vote, leaving some points unclassified. For example, if the optimal k for a given dataset is found to be 6 and it has 3 classes, then 2 of the 6 nearest points to a test data point may belong to each class, and the majority rule cannot pick a winner. Ties are not limited to values of k divisible by the number of classes: with 3 classes, k = 5 can also split 2, 2, 1. Only with two classes does an odd k rule ties out entirely.
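A tie of this kind is easy to detect when counting votes. The helper below is a hypothetical sketch, not part of any library: it returns the winning class, or None when the two largest counts are equal.

```python
from collections import Counter


def vote(neighbour_labels):
    """Return the majority class, or None if the top classes are tied."""
    # Classes sorted from most to least frequent among the neighbours
    counts = Counter(neighbour_labels).most_common()
    # A tie: the two largest counts are equal, so majority voting cannot decide
    if len(counts) > 1 and counts[0][1] == counts[1][1]:
        return None
    return counts[0][0]


# k = 6 with three classes: two neighbours from each class
print(vote(["red", "red", "blue", "blue", "green", "green"]))  # None
# k = 5 with three classes can also tie (2, 2, 1)
print(vote(["red", "red", "blue", "blue", "green"]))           # None
# k = 5 with a clear majority
print(vote(["red", "red", "red", "blue", "green"]))            # red
```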

A major problem that has not been answered completely is the limitation imposed by k. As the explanation above shows, the main factor the kNN model uses to make predictions is the distance between two points: if a point is close to its neighbour, that neighbour influences the point's class. Given that this is the underlying principle of the algorithm, distance should be one of the most important features of the method. Standard kNN does not fully respect this: distance only decides which k points get a vote, and once a point is inside that set, its vote counts the same however far away it is. Let us look at an example.

It is ironic that the parameter that makes this algorithm work (k) is the parameter that causes all these problems.

Here, to classify the points A, B and C, we take the 3 nearest neighbours of each. However, for C there is only one neighbour that should be taken into consideration; because of the constraint imposed by k, we also have to consider points that are far, far away. Further, the distance between y and B is much larger than that between A and x. In spite of that, y influences the classification of B, while x has no bearing on A's class.
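The effect can be reproduced with made-up coordinates (they are assumptions, not the points in the figure). A test point sits right next to a red training point, but with k = 3 the other two neighbours come from a distant blue group, and each of them counts as a full vote:

```python
import numpy as np
from collections import Counter
from sklearn.neighbors import NearestNeighbors

# Hypothetical training data (not taken from the figure): a tight blue
# group near the origin and a sparse red group further right.
X_train = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3], [0.3, 0.2],
                    [5.0, 0.0], [9.0, 4.0], [9.0, -4.0]])
y_train = np.array(["blue"] * 4 + ["red"] * 3)

# A test point sitting right next to the red point at (5, 0)
x_test = np.array([[5.2, 0.0]])

nn = NearestNeighbors(n_neighbors=3).fit(X_train)
distances, indices = nn.kneighbors(x_test)
neighbour_labels = y_train[indices[0]]

print(np.round(distances[0], 2))   # distances to the 3 nearest points
print(neighbour_labels)            # their classes, nearest first
# Plain majority vote: every neighbour counts once, however far away it is
print(Counter(neighbour_labels).most_common(1)[0][0])
```

The nearest point is 0.2 away and red, while the two blue neighbours are about 4.9 and 5.0 away, yet plain majority voting labels the test point blue.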

This anomaly is one of the things that has led to countless variants of kNN. They have not necessarily eliminated the problem altogether, because k cannot easily be removed from the algorithm: as mentioned earlier, it makes up half of our arsenal for classification. However, the variants below have greatly improved the results.


There are many versions of this algorithm. We are going to look at two that are widely used and address the problem described above.

  • Local mean k-nearest neighbour:

    This is a distance-based approach. It was proposed in 2012 by Jianping Gou, Lan Du, Yuhong Zhang and Taisong Xiong. The algorithm finds the k nearest neighbours of the test data point and computes the distance of each of them from the test point. Then, for each class, it takes the mean distance of that class's neighbours from the test point. The class with the smallest mean distance is assigned to the test data. This method not only bases the prediction on distance but also keeps the locality defined by traditional k-nearest neighbours.

  • Weighted k-nearest neighbour:

    For well-distributed data sets, this resolves most of the problems described above as far as results are concerned. In this method, we do not count the number of neighbours belonging to each class. Instead, each neighbour is given a weight based on its distance from the test data, the weights are added up per class, and the class with the largest total wins. The weight can be as simple as the reciprocal of the distance.
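Both rules can be sketched on top of the same neighbour search. The functions below follow only the descriptions given in this article (mean distance per class, and reciprocal-distance weights); the names are hypothetical, and the published versions of these methods may differ in detail. The toy data is also hypothetical: a tight blue group near the origin and a sparse red group, with the test point right next to a red point. With k = 3 a plain majority vote would pick blue here, since two of the three nearest points are distant blue ones, while both rules below pick red.

```python
import numpy as np
from collections import defaultdict


def k_nearest(X_train, x_test, k):
    """Return indices and distances of the k training points closest to x_test."""
    distances = np.linalg.norm(X_train - x_test, axis=1)
    nearest = np.argsort(distances)[:k]
    return nearest, distances[nearest]


def mean_distance_knn(X_train, y_train, x_test, k):
    """Rule as described above: among the k nearest neighbours, pick the
    class whose neighbours have the smallest mean distance to x_test."""
    idx, dist = k_nearest(X_train, x_test, k)
    per_class = defaultdict(list)
    for i, d in zip(idx, dist):
        per_class[y_train[i]].append(d)
    return min(per_class, key=lambda c: np.mean(per_class[c]))


def weighted_knn(X_train, y_train, x_test, k, eps=1e-12):
    """Distance-weighted vote: each neighbour adds 1/distance to its class."""
    idx, dist = k_nearest(X_train, x_test, k)
    scores = defaultdict(float)
    for i, d in zip(idx, dist):
        # eps avoids division by zero when x_test coincides with a training point
        scores[y_train[i]] += 1.0 / (d + eps)
    return max(scores, key=scores.get)


# Hypothetical data: blue group near the origin, sparse red group to the right
X_train = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.3], [0.3, 0.2],
                    [5.0, 0.0], [9.0, 4.0], [9.0, -4.0]])
y_train = np.array(["blue"] * 4 + ["red"] * 3)
x_test = np.array([5.2, 0.0])

print(mean_distance_knn(X_train, y_train, x_test, k=3))  # red
print(weighted_knn(X_train, y_train, x_test, k=3))       # red
```

scikit-learn's KNeighborsClassifier offers the same kind of inverse-distance weighting through its `weights="distance"` option.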

There are many more types of k-nearest neighbour classifiers, and more are being developed, as classification remains one of the most crucial tasks in building intelligent systems.
