Understanding Cross-Validation

Introduction

If you've explored machine learning models, you've probably come across the term "cross-validation" at some point. But what exactly is it, and why is it important?

In this blog, we'll break cross-validation down into simple terms. Using a practical demonstration, we'll equip you with the knowledge to confidently use cross-validation in your machine learning projects.

Model Validation in Machine Learning

Figure: Model validation and cross-validation using training and testing datasets for machine learning models.

Machine learning validation methods give us a way to estimate generalization error, which is the error a model makes on data it was not trained on. This estimate is crucial for determining which model provides the best predictions for unobserved data.

When large amounts of data are available, model validation typically begins by splitting the data into three separate datasets:

  • A training set is used to train the machine learning model(s) during development.
  • A validation set is used to estimate the generalization error of the model created from the training set for the purpose of model selection.
  • A test set is used to estimate the generalization error of the final model.
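The three-way split described above can be sketched in a few lines of Python. This is an illustration of the idea, not part of the GAUSS workflow. The function name `train_val_test_split`, the 60/20/20 proportions and the seed are assumptions chosen for the example.

```python
import random


def train_val_test_split(rows, val_frac=0.2, test_frac=0.2, seed=0):
    """Shuffle a copy of the rows, then cut it into training, validation and test sets.

    The default fractions and seed are illustrative, not recommendations.
    """
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # shuffle a copy so the input is unchanged
    n = len(shuffled)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


train, val, test = train_val_test_split(range(100))
print(len(train), len(val), len(test))  # 60 20 20
```

Each observation lands in exactly one of the three sets. The test set should only be touched once, after model selection on the validation set is finished.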

Cross-Validation in Machine Learning

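The model validation process in the previous section works well when we have large datasets. When data is limited, setting aside separate validation and test sets leaves too few observations for training, and the results depend heavily on which observations land in each set. In this case, we can instead use a technique called cross-validation.

The purpose of cross-validation is to provide a better estimate of a model's ability to perform on unseen data. Every observation is used for both training and testing. As a result, cross-validation makes more efficient use of limited data than a single train/test split and gives a more reliable estimate of the generalization error.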

There are many reasons we may want to do this:

  • To have a clearer measure of how our model performs.
  • To tune hyperparameters.
  • To select between competing models.

The intuition behind cross-validation is simple: rather than training our model on one training set, we train and test it on multiple subsets of the data.

The basic steps of cross-validation are:

  1. Split the data into portions.
  2. Train our model on a subset of the portions.
  3. Test our model on the remaining portion(s) of the data.
  4. Repeat steps 2-3 until every portion has been used for testing.
  5. Average the model performance across all iterations of testing to get an overall estimate of model performance.
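The five steps can be written as one generic loop. This Python sketch is not GAUSS code. The names `cross_validate`, `fit_mean` and `mean_abs_error` are hypothetical. `splits` is any sequence of (training indices, testing indices) pairs, and the toy "model" that just predicts the training mean stands in for a real learner.

```python
from statistics import mean


def cross_validate(data, splits, fit, score):
    """Train on each training subset, test on the held-out subset,
    and average the test scores over all iterations."""
    scores = []
    for train_idx, test_idx in splits:                            # step 4: repeat for every split
        model = fit([data[i] for i in train_idx])                 # step 2: train
        scores.append(score(model, [data[i] for i in test_idx]))  # step 3: test
    return mean(scores), scores                                   # step 5: average


# Toy stand-ins: the "model" is the training mean, scored by mean absolute error.
def fit_mean(train):
    return mean(train)


def mean_abs_error(model, test):
    return mean(abs(x - model) for x in test)


data = [1.0, 2.0, 3.0, 4.0]
splits = [([2, 3], [0, 1]), ([0, 1], [2, 3])]  # step 1: two portions, each tested once
average, per_split = cross_validate(data, splits, fit_mean, mean_abs_error)
print(per_split, average)  # [2.0, 2.0] 2.0
```

Keeping the splitting strategy separate from `fit` and `score` means the methods in the next section differ only in how `splits` is generated.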

Common Cross-Validation Methods

Though the basic concept of cross-validation is fairly simple, there are a number of ways to go about each step. A few examples of cross-validation methods include:

  1. k-Fold Cross-Validation
    In k-fold cross-validation:

    • The dataset is divided into k equal-sized folds.
    • The model is trained on k-1 folds and tested on the remaining fold.
    • The process is repeated k times, with each fold serving as the test set exactly once.
    • The performance metrics are averaged over the k iterations.
  2. Stratified k-Fold Cross-Validation
    This process is similar to k-fold cross-validation, with minor but important differences:

    • Each fold preserves the class proportions of the full dataset.
    • It is useful for imbalanced datasets.
  3. Leave-One-Out Cross-Validation
    The leave-one-out cross-validation process:

    • Trains the model using all data observations except one.
    • Tests the model on the held-out observation.
    • Repeats this n times, where n is the number of observations, so that each observation is used exactly once as the test set.
  4. Time-Series Cross-Validation
    This cross-validation method, designed specifically for time-series data:

    • Splits the data into training and testing sets in chronological order, for example using sliding or expanding windows.
    • Trains the model on past data and tests it on future data relative to each splitting point.
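To make the differences concrete, here is a Python sketch (not GAUSS code) that generates (training indices, testing indices) pairs for each of the four methods. The function names are hypothetical. The k-fold folds are contiguous and unshuffled, and only an expanding window is shown for time series. Full implementations usually add options such as shuffling or sliding windows.

```python
def kfold(n, k):
    """k-fold: k contiguous folds whose sizes differ by at most one."""
    start = 0
    for f in range(k):
        size = n // k + (1 if f < n % k else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        yield train, test
        start += size


def stratified_kfold(labels, k):
    """Stratified k-fold: group indices by class, then deal them round-robin,
    so each fold gets roughly the class proportions of the full dataset."""
    folds = [[] for _ in range(k)]
    by_class = sorted(range(len(labels)), key=lambda i: labels[i])  # stable sort
    for position, i in enumerate(by_class):
        folds[position % k].append(i)
    for f in range(k):
        train = sorted(i for g in range(k) if g != f for i in folds[g])
        yield train, sorted(folds[f])


def leave_one_out(n):
    """LOOCV: n splits, each holding out exactly one observation."""
    for i in range(n):
        yield [j for j in range(n) if j != i], [i]


def expanding_window(n, n_splits, test_size):
    """Time series: every test window comes after all of its training data."""
    if n_splits * test_size >= n:
        raise ValueError("need at least one training observation before the first test window")
    for s in range(n_splits):
        test_start = n - (n_splits - s) * test_size
        yield list(range(test_start)), list(range(test_start, test_start + test_size))


# Prints three splits whose test folds are [0, 1], [2, 3] and [4, 5].
for train, test in kfold(6, 3):
    print(train, test)
```

In every generator, the training and testing indices of a split never overlap. Only the time-series version also guarantees that all training indices precede the test indices.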

Advantages and Disadvantages of Each Method

k-Fold Cross-Validation
  Advantages:
    • Provides a good estimate of the model's performance by using all the data for both training and testing.
    • Reduces the variance of the performance estimate compared to a single train/test split.
  Disadvantages:
    • Can be computationally expensive, because the model is trained k times, especially for large datasets or complex models.
    • May not work well for imbalanced datasets or when there is a specific order to the data.

Stratified k-Fold Cross-Validation
  Advantages:
    • Ensures that each fold has a representative distribution of classes, which can improve performance estimates for imbalanced datasets.
    • Reduces the variance of the performance estimate compared to a single train/test split.
  Disadvantages:
    • Can still be computationally expensive, especially for large datasets or complex models.
    • May not be necessary for balanced datasets, where the class distribution is already even.

Leave-One-Out Cross-Validation (LOOCV)
  Advantages:
    • Provides a nearly unbiased estimate of the model's performance, because each model is trained on n-1 observations, almost the full dataset.
    • Can be useful when dealing with very limited data.
  Disadvantages:
    • Can be computationally expensive, as it requires training and testing the model n times.
    • May have high variance in performance estimates, because each test set contains a single observation and the n training sets are nearly identical.

Time-Series Cross-Validation
  Advantages:
    • Accounts for temporal dependencies in time-series data.
    • Provides a realistic estimate of forecasting performance, since the model is never trained on data from after the test period.
  Disadvantages:
    • May not be applicable to non-time-series data.
    • Can be sensitive to the choice of window size and data-splitting strategy.

k-Fold Cross-Validation Example

Let's look at k-fold cross-validation in action, using the wine quality dataset included in the GAUSS Machine Learning (GML) library. This file is based on the Kaggle Wine Quality dataset.

Our objective is to classify wines into quality categories using 11 features:

  • Fixed acidity.
  • Volatile acidity.
  • Citric acid.
  • Residual sugar.
  • Chlorides.
  • Free sulfur dioxide.
  • Total sulfur dioxide.
  • Density.
  • pH.
  • Sulphates.
  • Alcohol.

We'll use k-fold cross-validation to examine the performance of a random forest classification model.

Data Loading and Organization

First, we load our data directly from the GML library:

/*
** Load and prepare data
*/
// Filename
fname = getGAUSSHome("pkgs/gml/examples/winequality.csv");

// Load wine quality dataset
dataset = loadd(fname);

After loading the data, we need to shuffle the data and extract our dependent and independent variables.

// Enable repeatable sampling
rndseed 754931;

// Shuffle the dataset (sample without replacement),
// because cvSplit does not shuffle.
dataset = sampleData(dataset, rows(dataset));

y = dataset[.,"quality"];
X = delcols(dataset, "quality");
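The shuffle matters because cvSplit does not shuffle. If the rows were stored in a systematic order, for example sorted by quality, unshuffled contiguous folds could each cover only part of the range of quality scores. The Python sketch below shows the same shuffle-then-extract step on a small set of hypothetical rows. The rows, the column layout (target in the last position) and the reuse of the seed value are assumptions for illustration. Python's `random.Random.sample` plays the role of sampling without replacement, and its seeds do not reproduce GAUSS's `rndseed` stream.

```python
import random

# Hypothetical rows sorted by the target, mimicking a file ordered by quality:
# (feature_1, feature_2, quality)
rows = [(0.1, 3.2, 4), (0.3, 2.9, 4), (0.2, 3.5, 5), (0.5, 3.1, 5),
        (0.4, 3.3, 6), (0.6, 2.8, 6), (0.7, 3.0, 7), (0.9, 3.4, 7)]

rng = random.Random(754931)              # fixed seed for repeatable sampling
shuffled = rng.sample(rows, len(rows))   # sampling every row without replacement = a permutation

y = [row[-1] for row in shuffled]        # dependent variable (quality)
X = [row[:-1] for row in shuffled]       # independent variables (all other columns)

# Without shuffling, the first of two contiguous folds would only contain qualities 4 and 5.
print(sorted(r[-1] for r in rows[:4]))   # [4, 4, 5, 5]
print(len(y), len(X[0]))                 # 8 2
```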

Setting Random Forest Hyperparameters

After loading our data, we will set the random forest hyperparameters using the dfControl structure.

// Enable GML library functions
library gml;

/*
** Model settings
*/
// The dfModel structure holds the trained model
struct dfModel dfm;

// Declare 'dfc' to be a dfControl
// structure and fill with default settings
struct dfControl dfc;
dfc = dfControlCreate();

// Create 200 decision trees
dfc.numTrees = 200;

// Stop splitting if impurity at
// a node is less than 0.15
dfc.impurityThreshold = 0.15;

// Only consider 2 features per split
dfc.featuresPerSplit = 2;
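The last two settings act on individual tree nodes. The Python sketch below makes them concrete. It computes Gini impurity, one common impurity measure for classification trees, and it draws the random subset of candidate features for a split. We have not confirmed which impurity measure or sampling scheme GML uses internally. The function names and details are illustrative assumptions and do not describe decForestCFit.

```python
import random
from collections import Counter


def gini(labels):
    """Gini impurity of a node: 1 - sum of squared class proportions (0 = pure node)."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def should_stop(labels, impurity_threshold=0.15):
    """Stop splitting a node whose impurity is below the threshold."""
    return gini(labels) < impurity_threshold


def candidate_features(n_features=11, features_per_split=2, rng=random):
    """Pick the random subset of feature indices considered at one split."""
    return rng.sample(range(n_features), features_per_split)


print(gini([5, 5, 6, 6]))                        # 0.5: two equally common classes
print(should_stop([5] * 19 + [6]))               # True: impurity is about 0.095
print(candidate_features(rng=random.Random(1)))  # two distinct indices from 0 to 10
```

Considering only a few features per split makes the trees less alike, and averaging less correlated trees is what lets a random forest reduce variance.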

k-Fold Cross-Validation

Now that we have loaded our data and set our hyperparameters, we are ready to fit our random forest model and implement k-fold cross-validation.

First, we set up the number of folds and pre-allocate a vector to store the accuracy of each fold.

// Specify number of folds;
// 5 or 10 folds are common choices
nfolds = 5;

// Pre-allocate vector to hold the results
accuracy = zeros(nfolds, 1);

Next, we use a GAUSS for loop that completes four steps for each fold:

  1. Select the training and testing data for the current fold using the cvSplit procedure.
  2. Fit our random forest classification model on the training data using the decForestCFit procedure.
  3. Make classification predictions on the testing data using the decForestPredict procedure.
  4. Compute and store the model accuracy for the fold.

for i(1, nfolds, 1);
    { y_train, y_test, X_train, X_test } = cvSplit(y, X, nfolds, i);

    // Fit model using this fold's training data
    dfm = decForestCFit(y_train, X_train, dfc);

    // Make predictions using this fold's test data
    predictions = decForestPredict(dfm, X_test);

    accuracy[i] = meanc(y_test .== predictions);
endfor;
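The loop body can be mirrored in Python to show what each line computes. In GAUSS, `meanc(y_test .== predictions)` averages a vector of 1s and 0s that marks correct predictions, which gives the share of test observations classified correctly. In the sketch below, `cv_split` is a hypothetical stand-in that assumes cvSplit gives each fold a contiguous block of rows. We have not verified how the real procedure sizes uneven folds. A majority-class rule replaces the random forest, so the accuracies only illustrate the mechanics.

```python
from collections import Counter


def cv_split(y, X, nfolds, i):
    """Hypothetical stand-in for cvSplit: fold i (1-based) of nfolds contiguous
    blocks is the test set; the rest is the training set."""
    n = len(y)
    start, stop = (i - 1) * n // nfolds, i * n // nfolds
    y_test, X_test = y[start:stop], X[start:stop]
    y_train, X_train = y[:start] + y[stop:], X[:start] + X[stop:]
    return y_train, y_test, X_train, X_test


def accuracy(y_true, y_pred):
    """Equivalent of meanc(y_test .== predictions): the share of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Toy data; a majority-class rule stands in for decForestCFit/decForestPredict.
y = [5, 6, 5, 5, 6, 5, 7, 5, 6, 5]
X = [[v] for v in range(10)]
nfolds = 5
acc = []
for i in range(1, nfolds + 1):
    y_train, y_test, X_train, X_test = cv_split(y, X, nfolds, i)
    majority = Counter(y_train).most_common(1)[0][0]  # "fit": most common training class
    predictions = [majority] * len(X_test)            # "predict": same class for every row
    acc.append(accuracy(y_test, predictions))
print(acc)
```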

Results

Let's print the accuracy results and the total model accuracy:

/*
** Print Results
*/
sprintf("%7s %10s", "Fold", "Accuracy");;
sprintf("%7d %10.2f", seqa(1,1,nfolds), accuracy);
sprintf("Total model accuracy           : %10.2f", meanc(accuracy));
sprintf("Accuracy variation across folds: %10.3f", stdc(accuracy));

   Fold   Accuracy
      1       0.70
      2       0.73
      3       0.65
      4       0.71
      5       0.71
Total model accuracy           :       0.70
Accuracy variation across folds:      0.028
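The two summary lines are the mean and standard deviation of the five fold accuracies. As a cross-check, this Python sketch recomputes them from the rounded accuracies printed above. It uses `statistics.stdev`, the sample standard deviation with an n - 1 denominator, on the assumption that `stdc` uses the same denominator. From the rounded inputs the spread comes out at 0.030 rather than 0.028, presumably because GAUSS worked with the unrounded accuracies.

```python
from statistics import mean, stdev

# Rounded fold accuracies copied from the output above (folds 1 to 5).
fold_accuracy = [0.70, 0.73, 0.65, 0.71, 0.71]

total = mean(fold_accuracy)
spread = stdev(fold_accuracy)  # sample standard deviation, n - 1 denominator
best_fold = max(range(len(fold_accuracy)), key=fold_accuracy.__getitem__) + 1
worst_fold = min(range(len(fold_accuracy)), key=fold_accuracy.__getitem__) + 1

print(f"Total model accuracy: {total:.2f}")     # 0.70
print(f"Std. dev. across folds: {spread:.3f}")  # 0.030 from the rounded inputs
print(best_fold, worst_fold)                    # 2 3
```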

Our results provide some important insights into why we conduct cross-validation:

  • The model accuracy varies across folds, with a standard deviation of 0.028.
  • The maximum accuracy, using fold 2, is 0.73.
  • The minimum accuracy, using fold 3, is 0.65.

Depending on how we split the data into training and testing sets, a single split could give us a noticeably different picture of model performance.

The total model accuracy of 0.70, the average across all five folds, gives a better overall measure of model performance. The standard deviation of the accuracy gives us some insight into how much our prediction accuracy might vary from one split of the data to another.

Conclusion

If you're looking to reliably assess how well your models will perform on new data, cross-validation is a crucial technique to learn. In today's blog, we've provided a guide to getting started with cross-validation.

Our step-by-step practical demonstration using GAUSS should prepare you to confidently implement cross-validation in your own data analysis projects.

Further Machine Learning Reading

  1. Predicting Recessions with Machine Learning Techniques
  2. Applications of Principal Components Analysis in Finance
  3. Predicting The Output Gap With Machine Learning Regression Models
  4. Fundamentals of Tuning Machine Learning Hyperparameters
  5. Machine Learning With Real-World Data
  6. Classification with Regularized Logistic Regression

© Aptech Systems, Inc. All rights reserved.
