Random Forest Salary Prediction Code

new;
library gml;

//Build the full path to the hitters dataset under the GAUSS home directory
dataset = getGAUSSHome() $+ "pkgs/gml/examples/islr_hitters.xlsx";

//Model formula: ln(salary) as the response with ten predictors.
//The string is assembled with $+ so the source lines stay short without
//placing a line break inside the formula string itself.
formula = "ln(salary) ~ AtBat + Hits + HmRun + Runs + RBI + Walks"
    $+ " + Years + PutOuts + Assists + Errors";

//Split data into training and test sets.
//0.7 is passed as the split proportion argument of trainTestSplit.
{ y_train, y_test, x_train, x_test } = trainTestSplit(dataset, formula, 0.7);

/**********************************
Set model parameters
**********************************/
//Declare the control structure that carries the random forest settings
//and initialise it through rfControlCreate
struct rfControl rfc;
rfc = rfControlCreate();

//Turn on the out-of-bag error
rfc.oobError = 1;

//Turn on variable importance
rfc.variableImportanceMethod = 1;

/**********************************
Fit model
**********************************/
//Declare the structure that receives the fitted model
struct rfModel out;

//Fit the random forest to the training split using the settings in rfc
out = rfRegressFit(y_train, x_train, rfc);

//Report the out-of-bag error stored on the fitted model
oob_error = out.oobError;
print "Out-of-bag error:" oob_error;

/************************************
Variable importance plot
*************************************/
//Predictor labels, stacked into a string array with $| and listed in the
//same order as the predictors appear in the model formula
names = "AtBat" $| "Hits" $| "HmRun" $| "Runs" $| "RBI"
    $| "Walks" $| "Years" $| "PutOuts" $| "Assists" $| "Errors";

//Plot variable importance for the fitted model, labelled with these names
plotVariableImportance(out, names);

/***********************************
Predictions
************************************/
//Make predictions using test data
predictions = rfRegressPredict(out, x_test);

//Print the first rows of predictions side by side with the observed values.
//The preview is capped at 10 rows but never exceeds the rows available,
//so a small test set does not trigger an index-out-of-range error.
n_show = minc(10 | rows(predictions));
print predictions[1:n_show, .]~y_test[1:n_show, .];

//Test-set mean squared error of the random forest predictions
rf_mse = meanc((predictions - y_test).^2);
print "random forest MSE: " rf_mse;

//OLS benchmark: fit on the training data with an intercept column prepended.
//With a rectangular right-hand operand, the / operator returns the
//least squares solution, i.e. the OLS coefficient vector.
b_hat = y_train / (ones(rows(x_train), 1)~x_train);

//Apply the OLS coefficients to the test data (same intercept column layout)
y_hat = (ones(rows(x_test), 1)~x_test) * b_hat;

//Test-set mean squared error of the OLS predictions
ols_mse = meanc((y_hat - y_test).^2);
print "OLS MSE using test data  : " ols_mse;

Have a Specific Question?

Get a real answer from a real person

Need Support?

Get help from our friendly experts.

REQUEST A FREE QUOTE

Thank you for your interest in the GAUSS family of products.

© Aptech Systems, Inc. All rights reserved.

Privacy Policy | Sitemap