GAUSS rfRegressFit example

This example uses the red wine quality dataset from Cortez et al. (2009) to fit a random forest regression model. The dataset contains 1599 observations of 12 variables: fixed acidity, volatile acidity, citric acid, residual sugar, chlorides, free sulfur dioxide, total sulfur dioxide, density, pH, sulphates, alcohol, and quality.

Split the dataset

Before fitting the random forest model, the testTrainSplit function is used to split the data into training and test sets. The testTrainSplit function accepts GAUSS formula string syntax and creates the training and test sets without loading the full dataset. In this regression model, quality is the response variable and all other variables are used as predictors:

// Specify dataset name with full path
dataset = getGAUSSHome() $+ "pkgs/gml/examples/winequality-red.csv";

// Split data into 70% training and 30% test sets
{ y_train, y_test, x_train, x_test } = testTrainSplit(dataset, "quality ~ .", 0.7);
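
A quick sanity check on the split can be sketched as follows. This assumes the four outputs above are returned as ordinary matrices; train_share is an illustrative name that is not part of the original example:

// Number of observations in each set and number of predictors
print "training rows : " rows(x_train);
print "test rows     : " rows(x_test);
print "predictors    : " cols(x_train);

// Share of observations assigned to the training set (expected to be close to 0.7)
train_share = rows(x_train) / (rows(x_train) + rows(x_test));
print "training share: " train_share;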

Estimate the model

The rfRegressFit function is called with the y_train and x_train matrices to fit a random forest regression model. All results are stored in an rfModel structure:

// Output structure
struct rfModel rfm;

// Fit training data using random forest
rfm = rfRegressFit(y_train, x_train);

Make predictions

Once the model is fit, predictions can be made from the x_test dataset using the rfRegressPredict function. The rfRegressPredict function requires two inputs: an rfModel structure and a data matrix of predictors:

// Make predictions using test data
predictions = rfRegressPredict(rfm, x_test);

// Print the first 10 predictions next to the observed test values
print predictions[1:10]~y_test[1:10];
print "random forest MSE: " meanc((predictions - y_test).^2);

// Fit OLS (with an intercept) on the training data and print its test MSE for comparison
b_hat = y_train / (ones(rows(x_train), 1)~x_train);
y_hat = (ones(rows(x_test),1)~x_test) * b_hat;
print "OLS MSE          : " meanc((y_hat - y_test).^2);


The output from the code above looks similar to:

       5.0929643        5.0000000
       5.1308175        5.0000000
       5.1799206        5.0000000
       5.1720873        5.0000000
       5.5779881        7.0000000
       5.3214921        5.0000000
       5.3608810        5.0000000
       5.7493413        7.0000000
       5.2857103        5.0000000
       5.1103095        4.0000000

random forest MSE:       0.35596292
OLS MSE          :       0.42207614 
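
To put the two MSE values in context, one option is to compare them against a naive baseline that always predicts the mean of the training responses, and to report the random forest error on the scale of the quality score. This is a minimal sketch, assuming y_train, y_test, and predictions from the code above are still in memory; baseline_mse and rf_rmse are illustrative names that are not part of the original example:

// Baseline: predict the training-set mean of quality for every test observation
baseline_mse = meanc((y_test - meanc(y_train)).^2);
print "baseline MSE      : " baseline_mse;

// Root mean squared error of the random forest, in units of the quality score
rf_rmse = sqrt(meanc((predictions - y_test).^2));
print "random forest RMSE: " rf_rmse;

A model that does not improve on the baseline MSE is explaining little beyond the average quality score.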

