K-nearest Neighbor Classification

Using the k-nearest neighbor algorithm to classify data

This tutorial explores the use of the k-nearest neighbor algorithm to classify data. The k-nearest neighbor (KNN) method is one of the simplest non-parametric techniques for classification and regression. An observation is classified by a majority vote of its neighbors: it is assigned to the class most common among its k nearest neighbors, as measured by a distance function. This tutorial examines how to:

  1. Load data from a dataset using loadd.
  2. Fit a k-nearest neighbor model to a dataset using knnClassifyFit.
  3. Predict classification for a testing feature set using knnClassifyPredict.
  4. Plot classified data using plotClasses.
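
To make the majority-vote rule described above concrete, here is a minimal sketch written in Python rather than GAUSS, purely to illustrate the idea. It uses a brute-force Euclidean distance search. The helper name knn_predict and the toy data are hypothetical, and this is not a description of how knnClassifyFit works internally.

```python
import math
from collections import Counter

def knn_predict(x_train, y_train, x_new, k):
    """Classify x_new by majority vote among its k nearest training points."""
    # Euclidean distance from x_new to every training observation
    dists = [math.dist(row, x_new) for row in x_train]
    # Indices of the k smallest distances
    nearest = sorted(range(len(dists)), key=lambda i: dists[i])[:k]
    # Most common label among those neighbors
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Toy data (made up for illustration): two features, two classes
x_train = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
           [5.0, 5.0], [5.2, 4.9], [4.8, 5.1]]
y_train = ["a", "a", "a", "b", "b", "b"]

label_near_a = knn_predict(x_train, y_train, [1.1, 1.0], 3)
label_near_b = knn_predict(x_train, y_train, [5.1, 5.0], 3)
```

In this sketch, ties are broken in favor of whichever tied label appears first among the distance-sorted neighbors. Other implementations may resolve ties differently.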

Load data

The data for this tutorial is stored in the file iris.csv. This file contains the Iris flower dataset, first introduced in 1936 by Ronald Fisher. The dataset includes observations of four features: the length and width of the sepals and of the petals. It covers three species of Iris: setosa, virginica, and versicolor. This tutorial uses loadd to load the dataset into GAUSS prior to fitting the model. The loadd function uses the GAUSS formula string format, which allows data to be loaded and transformed in a single line. Detailed information on using formula strings is available in the formula string tutorials.

The formula string syntax in loadd uses two specifications:

  1. The dataset specification.
  2. A formula that specifies how to load the data. This is optional if the complete dataset is to be loaded.

In the first step, we will load all the features except the species into a matrix.

new;
cls;
library gml;

//Name of the iris dataset file
fname = getGAUSSHome() $+ "pkgs/gml/examples/iris.csv";

//Load all variables except Species
x = loadd(fname, ". -Species");

Next, we will load the species data, stored as strings, into a string array using csvReadSA. In this case, we give csvReadSA three inputs: the file name, the row number at which to start loading, and the column number to load:

//Load string labels
species = csvReadSA(fname, 2, 5);

Construct training and testing subsets

The GAUSS Machine Learning application includes the trainTestSplit function for splitting full datasets into randomly drawn training and testing subsets. The function also supports the GAUSS formula string format, which allows the subsets to be created, and the data to be loaded and transformed, in a single line without first loading the full dataset. Detailed information on using formula strings is available in the formula string tutorials.

Using the matrix syntax in trainTestSplit requires the specification of three inputs:

  1. The predictor data matrix.
  2. The response or target data matrix.
  3. The proportion of data to include in the training dataset.

As called below, trainTestSplit returns four outputs: x_train, x_test, y_train, and y_test. These contain the predictors for the training and testing subsets, followed by the corresponding targets for the training and testing subsets:

//Split data set
{ x_train, x_test, y_train, y_test } = trainTestSplit(x, species, 0.7);
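
For readers who want to see what a random train/test split does conceptually, the following is a small sketch in Python, not GAUSS. The helper random_split, its seed argument, and the toy data are hypothetical and chosen only for illustration. It is not the implementation of trainTestSplit. The key point is that rows of the predictors and the labels are shuffled together, so each observation keeps its own label.

```python
import random

def random_split(x, y, train_pct, seed=None):
    """Randomly split paired rows of x and y into training and testing subsets."""
    rng = random.Random(seed)
    # Shuffle row indices so that x and y stay aligned
    idx = list(range(len(x)))
    rng.shuffle(idx)
    # Number of rows assigned to the training subset
    n_train = int(round(train_pct * len(x)))
    train_idx, test_idx = idx[:n_train], idx[n_train:]
    x_train = [x[i] for i in train_idx]
    x_test = [x[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return x_train, x_test, y_train, y_test

# Toy data (made up): 10 observations, 2 features, string labels
x = [[float(i), float(i) * 2.0] for i in range(10)]
y = ["even" if i % 2 == 0 else "odd" for i in range(10)]

x_train, x_test, y_train, y_test = random_split(x, y, 0.7, seed=1)
```

With 10 observations and a training proportion of 0.7, this sketch places 7 rows in the training subset and 3 in the testing subset.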

Fitting the knn model

The KNN model is fit using the GAUSS procedure knnClassifyFit. As called below, knnClassifyFit takes three required inputs: the training target labels, the corresponding training feature matrix, and the number of neighbors.

knnClassifyFit returns all of its output in a knnModel structure. An instance of the knnModel structure must be declared prior to calling knnClassifyFit. Each instance of the knnModel structure contains the following members:

Member         Description
opaqueModel    Column vector, containing the kd-tree in opaque form.
classIndices   Px1 matrix, where P is the number of classes in the target vector 'y'.
classNames     Px1 string array, where P is the number of classes in the target vector 'y', containing the class names if the target vector was a string array.
k              Scalar, the number of neighbors to search.

The code below fits the KNN model to the training feature matrix, x_train, using the training labels, y_train:

//Specify number of neighbors
k = 3;

//Declare the knnModel structure
struct knnModel mdl;

//Call knnClassifyFit
mdl = knnClassifyFit(y_train, x_train, k);

Make predictions

The knnClassifyPredict function is used after knnClassifyFit to make predictions from the KNN classification model. The function requires a filled knnModel structure and a test set of predictors. The code below computes the predictions and prints the prediction accuracy:

//Predict classes
y_hat = knnClassifyPredict(mdl, x_test);

print "prediction accuracy = " meanc(y_hat .$== y_test);
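
The accuracy expression above compares the predicted and actual labels element by element and then averages the resulting 0/1 indicators. The same calculation can be sketched in Python (not GAUSS) as follows. The function name accuracy and the label vectors are made up for illustration and are not results from the iris data.

```python
def accuracy(y_hat, y_test):
    """Share of predictions that exactly match the true labels."""
    # 1 where the predicted label equals the true label, 0 otherwise
    matches = [1 if p == t else 0 for p, t in zip(y_hat, y_test)]
    # The mean of the indicators is the prediction accuracy
    return sum(matches) / len(matches)

# Hypothetical labels: three of the four predictions are correct
y_test = ["setosa", "versicolor", "virginica", "virginica"]
y_hat = ["setosa", "versicolor", "versicolor", "virginica"]

acc = accuracy(y_hat, y_test)
```

Here three of the four made-up predictions match, so acc is 0.75.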

Plotting the assigned classes

The GAUSS plotClasses function provides a convenient tool for plotting the assigned classes. It produces a 2-D scatter plot of the data matrix with each class plotted in a different color. The procedure requires two inputs: a two-column data matrix, x, and a vector of class labels, labels. The label vector may be either a string array or a numeric vector. The plot can also be formatted by including an optional plotControl structure.

To start, let's set up the plotControl structure to add a title and axis labels to our graph and to turn off the grid. This is done in five steps:

  1. Declare an instance of the plotControl structure.
  2. Fill the structure with the default settings for a scatter plot using plotGetDefaults.
  3. Use plotSetTitle to specify the wording, font, font size, and font color for the graph title.
  4. Use plotSetXLabel and plotSetYLabel to specify the axis labels.
  5. Use plotSetGrid to turn the grid off.

//Declare plotControl structure
struct plotControl myPlot;

//Fill with default settings for a scatter plot
myPlot = plotGetDefaults("scatter");

//Set up title
plotSetTitle(&myPlot, "Knn Classification", "Arial", 16, "Black");

//Set up axis labels
plotSetXlabel(&myPlot, "Petal Length", "Arial", 14, "Black");
plotSetYlabel(&myPlot, "Petal Width", "Arial", 14, "Black");

//Turn grid off
plotSetGrid(&myPlot, "off");

Next, we will use plotClasses to plot the class assignments predicted by knnClassifyPredict, y_hat, against petal length and petal width:

//Plot results
plotClasses(x_test[.,3:4], y_hat, myPlot);

[Figure knn_iris: scatter plot of the predicted classes, Petal Length (x-axis) versus Petal Width (y-axis)]

© Aptech Systems, Inc. All rights reserved.