Random Forest Salary Prediction Code

new;
library gml;

/*
** Data preparation: read the hitters example workbook,
** build the response and the predictor set, then divide
** the observations into training and test subsets.
*/

// Full path of the hitters example workbook shipped with GML
dataset = getGAUSSHome $+ "pkgs/gml/examples/hitters.xlsx";

// Predictors: every column of the workbook except 'salary'
X = loadd(dataset, ". - salary");

// Response: natural logarithm of 'salary'
y = loadd(dataset, "ln(salary)");

// Fix the random seed so the sampling below is repeatable
rndseed 234234;

// Partition rows into training and test sets
// (the third argument, 0.7, is the training share)
{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, 0.7);

/*
** Random forest estimation
*/

// Control structure, initialised with the default settings
struct dfControl dfc;
dfc = dfControlCreate();

// Enable out-of-bag error and variable importance output
dfc.oobError = 1;
dfc.variableImportanceMethod = 1;

// Fit the random forest on the training subset;
// the fitted model is returned in a dfModel structure
struct dfModel mdl;
mdl = decForestRFit(y_train, X_train, dfc);

// Report the out-of-bag error stored on the fitted model
print "Out-of-bag error:" mdl.oobError;

/*
** Plot variable importance
*/

// Names used to label the importance plot. getHeaders()
// returns every column name in the workbook, including the
// response 'salary', which X excludes via ". - salary".
// Remove it (case-insensitively) so the labels correspond
// to the predictors the model was actually fit on, instead
// of being shifted or having an extra trailing entry.
var_names = getHeaders(dataset);
mdl.varNames = delif(var_names, lower(var_names) .$== "salary");

// Draw variable importance plot
plotVariableImportance(mdl);

/*
** Predictions
*/

// Predicted (log) salary for the held-out observations
predictions = decForestPredict(mdl, X_test);

// First five predictions side by side with the observed values
print predictions[1:5,.]~y_test[1:5,.];
print "";

// Mean squared prediction error of the random forest on the test set
rf_mse = meanc((predictions - y_test).^2);
print "random forest test MSE:" rf_mse;

// Benchmark: linear regression with an intercept. With a
// non-square right-hand matrix, '/' gives the least-squares fit.
b_hat = y_train / (ones(rows(X_train), 1)~X_train);

// Split the coefficient vector into intercept and slopes
alpha_hat = b_hat[1];
b_hat = b_hat[2:rows(b_hat)];

// Predict on the test set and report the OLS test MSE
y_hat = X_test * b_hat + alpha_hat;
print "OLS test MSE          :" meanc((y_hat - y_test).^2);

Have a Specific Question?

Get a real answer from a real person

Need Support?

Get help from our friendly experts.

Try GAUSS free for 14 days

See what GAUSS can do for your data

© Aptech Systems, Inc. All rights reserved.

Privacy Policy