Using random forests to predict salary


This tutorial explores the use of random forests (also called decision forests) to predict baseball players' salaries.

This example builds on the examples in Chapter 8 of G. James et al. (2013), An Introduction to Statistical Learning. The model will include 16 predictors: AtBat, Hits, HmRun, Runs, RBI, Walks, Years, CAtBat, CHits, CHmRun, CRuns, CRBI, CWalks, PutOuts, Assists, Errors.

This tutorial shows how to:

  1. Load the data and transform the target variable.
  2. Use trainTestSplit to split a dataset into random training and testing subsets.
  3. Specify parameters for random forest models using the dfControl structure.
  4. Fit a random forest regression model from training data using decForestRFit.
  5. Plot variable importance using plotVariableImportance.
  6. Use decForestPredict to make predictions from a random forest model.

Load the Data

The data for this tutorial is stored in the file hitters.xlsx, which ships with the GML examples. The model will use the natural log of salary as the response variable and the 16 previously mentioned variables as predictors.

We will use loadd with GAUSS's formula string syntax, which allows for loading and transforming data in a single line.

library gml;

// Load hitters dataset
dataset = getGAUSSHome() $+ "pkgs/gml/examples/hitters.xlsx";

// Load salary and perform natural log transform
y = loadd(dataset, "ln(salary)");

// Load all variables except 'salary'
X = loadd(dataset, ". - salary");
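Because the response is ln(salary), all model predictions and error measures later in this tutorial are on the log scale. The short sketch below uses hypothetical salary values, not values from hitters.xlsx, to show that exp() maps log-scale values back to the original salary scale. Note that in GAUSS, ln is the natural logarithm, while log is base 10.

```gauss
// Hypothetical salary values, NOT taken from hitters.xlsx
salary = { 475, 480, 500, 91.5, 750 };

// Natural log transform, the same operation as loadd(dataset, "ln(salary)")
ln_salary = ln(salary);

// exp() maps log-scale values (for example, model predictions)
// back to the salary scale
salary_back = exp(ln_salary);

print ln_salary~salary_back;
```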

Construct Training and Test Sets

The GAUSS Machine Learning (GML) module includes the trainTestSplit function for splitting full datasets into randomly drawn train and test subsets.

trainTestSplit requires three inputs when using matrix inputs:

  1. The dependent variable.
  2. The independent variables.
  3. The proportion of data to include in the training dataset.

The procedure trainTestSplit returns four outputs, in this order: y_train, y_test, X_train, and X_test. These contain the response and predictors for the training and test datasets:

// Set seed for repeatable sampling
rndseed 234234;

// Split data into training and test sets
{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, 0.7);
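trainTestSplit takes care of the sampling for you. To make the idea concrete, the sketch below performs a 70/30 random split using only base GAUSS on simulated data. simpleSplit is a hypothetical helper written for illustration. It is not how GML implements trainTestSplit, which may, for example, round the training size differently.

```gauss
// Simulated data standing in for y and X
rndseed 42;
y_sim = rndn(10, 1);
X_sim = rndn(10, 3);

{ y_tr, y_te, X_tr, X_te } = simpleSplit(y_sim, X_sim, 0.7);
print "Training rows:" rows(y_tr);
print "Test rows:    " rows(y_te);

// Hypothetical helper: random train/test split by shuffling row indices
proc (4) = simpleSplit(y, X, pct_train);
    local n, n_train, idx, tr, te;
    n = rows(y);
    n_train = floor(pct_train * n);

    // Sorting uniform draws yields a random permutation of 1..n
    idx = sortind(rndu(n, 1));
    tr = idx[1:n_train];
    te = idx[n_train+1:n];

    retp(y[tr], y[te], X[tr, .], X[te, .]);
endp;
```

The training and test sets partition the rows, so each observation lands in exactly one of them.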

Specify Model Parameters

Since "Random Forests" is a trademarked term, the GAUSS functions use the term decision forest. The related structures use the prefix df (dfControl, dfModel), and the procedures use the prefix decForest (decForestRFit, decForestPredict).

The decision forest model parameters are specified using the dfControl structure, which contains the following members:

Member                    Description
numTrees                  Scalar, number of trees (must be an integer). Default = 100.
obsPerTree                Scalar, proportion of observations sampled for each tree. Default = 1.0.
featuresPerNode           Scalar, number of features considered at each node. Default = nvars/3, where nvars is the number of predictors.
maxTreeDepth              Scalar, maximum tree depth. Default = unlimited.
minObsNode                Scalar, minimum number of observations per node. Default = 1.
oobError                  Scalar, 1 to compute the out-of-bag (OOB) error, 0 otherwise. Default = 0.
variableImportanceMethod  Scalar, method used to compute variable importance. Default = 0.
                            0 = none,
                            1 = mean decrease in impurity (MDI),
                            2 = mean decrease in accuracy (MDA),
                            3 = scaled MDA.
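For regression trees, a common impurity measure is the variance of the response within a node. The impurity decrease of a split is the parent node's impurity minus the size-weighted impurities of its two children. Mean decrease in impurity accumulates these decreases over every split on a given variable and averages them across trees. The sketch below computes the decrease for a single split. The node data, the threshold, and the helper nodeImpurity are hypothetical, and GML's exact impurity definition and normalization are not documented here.

```gauss
// Hypothetical node: 6 responses and one candidate predictor
y_node = { 5.2, 5.8, 6.1, 6.9, 7.3, 7.0 };
x_node = { 1, 2, 3, 8, 9, 10 };
thresh = 5;

// Split rows into left (x <= thresh) and right (x > thresh) children
y_left = selif(y_node, x_node .<= thresh);
y_right = selif(y_node, x_node .> thresh);

// Parent impurity minus the size-weighted child impurities
n = rows(y_node);
decrease = nodeImpurity(y_node)
    - (rows(y_left) / n) * nodeImpurity(y_left)
    - (rows(y_right) / n) * nodeImpurity(y_right);
print "Impurity decrease for this split:" decrease;

// Hypothetical helper: regression node impurity as the
// mean squared deviation from the node mean
proc (1) = nodeImpurity(v);
    retp(meanc((v - meanc(v)).^2));
endp;
```

A split that separates low and high responses, as this one does, produces a large decrease. Variables that repeatedly produce such splits receive high MDI scores.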

Using the dfControl structure to change the model parameters requires three steps:

  1. Declare an instance of the dfControl structure:
    struct dfControl dfc;
  2. Fill the members in the dfControl structure with default values using dfControlCreate:
    dfc = dfControlCreate();
  3. Change the desired members from their default values:
    dfc.oobError = 1;

The code below puts these three steps together to turn on both the out-of-bag error and the variable importance computation:

// Declare 'dfc' to be a dfControl structure
// and fill with default settings.
struct dfControl dfc;
dfc = dfControlCreate();

// Turn on variable importance
dfc.variableImportanceMethod = 1;

// Turn on OOB error
dfc.oobError = 1;
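The same three steps work for any member listed in the dfControl table. As an illustration only, the sketch below creates a second control structure that grows more trees and requires larger terminal nodes. The name dfc2 and the particular values are hypothetical choices, not tuned recommendations.

```gauss
library gml;

// Declare a second control structure and fill it with defaults
struct dfControl dfc2;
dfc2 = dfControlCreate();

// Override selected members (illustrative values only)
dfc2.numTrees = 500;      // grow 500 trees instead of the default 100
dfc2.minObsNode = 5;      // require at least 5 observations per node
dfc2.oobError = 1;        // still compute the out-of-bag error
```

Passing dfc2 to decForestRFit in place of dfc would fit a forest with these settings. Members that are not overridden keep their defaults.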

Fitting the Random Forest Regression Model

Random forest (or decision forest) regression models are fit using the GAUSS procedure decForestRFit.

The decForestRFit procedure takes two required inputs: the training response vector and the training predictor matrix. A dfControl structure may optionally be passed as a third input to specify model parameters.

The decForestRFit procedure returns its results in a dfModel structure. An instance of the dfModel structure must be declared before calling decForestRFit. Each dfModel structure contains the following members:

Member              Description
variableImportance  Matrix, 1 x p, variable importance measure if variable importance computation is specified, zero otherwise.
oobError            Scalar, out-of-bag error if OOB error computation is specified, zero otherwise.
numClasses          Scalar, number of classes for a classification model, zero otherwise.
opaqueModel         Matrix, model details for internal use only.

The code below fits the random forest model to the training data, y_train and X_train, which were generated earlier using trainTestSplit. In addition, the inclusion of dfc, the instance of the previously created dfControl structure, results in the computation of both the out-of-bag error and the variable importance.

// Structure to hold model results
struct dfModel mdl;

// Fit training data using random forest
mdl = decForestRFit(y_train, X_train, dfc);

// OOB Error
print "Out-of-bag error:" mdl.oobError;

The output from the code above:

Out-of-bag error:      0.22886297 
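The out-of-bag error is computed by predicting each training observation using only the trees that did not see that observation during training. Suppose each tree is grown on a bootstrap sample the same size as the training set, which the default obsPerTree = 1.0 suggests, although GML's exact sampling scheme is not spelled out here. Under that assumption, roughly 1/e, or about 36.8%, of the rows are out-of-bag for any single tree. The sketch below checks this share by simulation in base GAUSS.

```gauss
rndseed 12345;
n = 1000;

// Draw n row indices with replacement (one bootstrap sample of size n)
idx = 1 + floor(n * rndu(n, 1));

// Flag each row that was drawn at least once
in_bag = zeros(n, 1);
for i(1, n, 1);
    in_bag[idx[i]] = 1;
endfor;

// Rows never drawn are out-of-bag for this tree
oob_share = meanc(in_bag .== 0);
print "Simulated out-of-bag share:" oob_share;
print "Limit (1 - 1/n)^n -> 1/e:  " exp(-1);
```

Because every tree leaves out a different subset of rows, most observations are out-of-bag for many trees. This is why the OOB error can serve as a built-in estimate of test error without a separate validation set.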

Plotting Variable Importance

A useful aspect of the random forest model is the variable importance measure. This measure provides a tool for understanding the relative importance of each predictor in the model. The procedure plotVariableImportance plots a pre-formatted bar graph of the variable importance.

The procedure takes a dfModel structure as its only input.

// Plot variable importance

// Load variable names from dataset
// and assign to dfModel structure
mdl.varNames = getHeaders(dataset);

// Draw variable importance plot
plotVariableImportance(mdl);

Random forest variable importance plot.
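If you prefer a printed ranking to a plot, you can sort the importance scores yourself. The sketch below uses placeholder names and scores, not results from this tutorial's model. With a fitted model, the scores would come from mdl.variableImportance, transposed to a column because it is stored as 1 x p.

```gauss
// Placeholder names and scores, NOT results from this tutorial's model
names = "AtBat" $| "Hits" $| "CRBI" $| "Walks";
imp = { 0.8, 1.1, 3.2, 0.9 };

// sortind returns indices that sort imp in ascending order;
// rev flips them so the most important predictor comes first
ord = rev(sortind(imp));
names_sorted = names[ord];
imp_sorted = imp[ord];

for i(1, rows(ord), 1);
    print names_sorted[i];; print imp_sorted[i];
endfor;
```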

Make Predictions

The decForestPredict procedure is used after decForestRFit to make predictions from the random forest regression model. It requires a filled dfModel structure and a matrix of test-set predictors. The code below computes the predictions, prints the first 5 predictions next to the observed test values, and compares the random forest test MSE to the test MSE of an OLS regression fit on the same training data:

// Make predictions using test data
predictions = decForestPredict(mdl, X_test);

// Print predictions and decision forest test MSE
print predictions[1:5,.]~y_test[1:5,.];
print "";
print "random forest test MSE: " meanc((predictions - y_test).^2);

// Estimate OLS (with intercept) on the training data, then print its test MSE
b_hat = y_train / (ones(rows(X_train), 1)~X_train);

alpha_hat = b_hat[1];
b_hat = trimr(b_hat, 1, 0);
y_hat = alpha_hat + X_test * b_hat;
print "OLS test MSE          :" meanc((y_hat - y_test).^2);
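The OLS benchmark relies on the GAUSS slash operator. When the right-hand matrix is not square, y / Z returns the least-squares solution of Z*b = y. The sketch below checks this on simulated data by comparing it with the normal-equations solution. All variable names are hypothetical.

```gauss
rndseed 777;
n = 50;

// Simulated predictors and response: intercept 1, slopes 2 and -0.5
X_sim = rndn(n, 2);
beta = { 2, -0.5 };
y_sim = 1 + X_sim * beta + 0.1 * rndn(n, 1);

// Design matrix with a leading column of ones for the intercept
Z = ones(n, 1) ~ X_sim;

// Least squares via the slash operator
b_slash = y_sim / Z;

// Least squares via the normal equations (Z'Z)^(-1) Z'y
b_normal = invpd(Z' * Z) * (Z' * y_sim);

print b_slash~b_normal;
```

The slash operator avoids forming the inverse explicitly, which is generally the more numerically stable choice.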

The output (first column: predicted ln(salary); second column: observed ln(salary)):

       6.8781181        6.0402547
       6.3337887        6.3630281
       6.0366998        5.7838252
       6.2387501        6.6200732
       5.5766902        5.2522734

random forest test MSE:      0.23044959
OLS test MSE          :      0.47283085  


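In this tutorial, you've seen how to:

  1. Load and transform data.
  2. Create randomly sampled training and test sets.
  3. Estimate a basic random forest model.
  4. Draw a variable importance plot.
  5. Make predictions and compare the test MSE against OLS.

This run used the default tuning parameters, with only the out-of-bag error and variable importance computations turned on. The random forest achieved a test MSE of 0.230, roughly half the 0.473 test MSE of the OLS model.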

