Random Forest Salary Prediction Code

library gml;

//Load hitters dataset (file shipped under the GAUSS home directory with the gml examples)
dataset = getGAUSSHome $+ "pkgs/gml/examples/islr_hitters.xlsx";

//Model formula: ln(salary) regressed on ten batting/fielding statistics.
//Keep the formula on ONE line: a string literal broken across source
//lines is not a valid single formula string.
model_formula = "ln(salary) ~ AtBat + Hits + HmRun + Runs + RBI + Walks + Years + PutOuts + Assists + Errors";

//Split data into training and test sets
//NOTE(review): 0.7 is presumably the training-set share; confirm against trainTestSplit docs
{ y_train, y_test, x_train, x_test } = trainTestSplit(dataset, model_formula, 0.7);

//Set model parameters
//Use a control structure for the settings, initialised by rfControlCreate
struct rfControl rfc;
rfc = rfControlCreate;

//Turn on variable importance
rfc.variableImportanceMethod = 1;

//Turn on OOB error
rfc.oobError = 1;

//Fit model
//Output structure that receives the fitted random forest
struct rfModel out;

//Fit training data using random forest regression with the settings in rfc
out = rfRegressFit(y_train, x_train, rfc);

//OOB Error (requested via rfc.oobError = 1)
print "Out-of-bag error:" out.oobError;

//Variable importance plot
//Set up variable names, in the same order as the predictors in the model formula
names = "AtBat" $| "Hits" $| "HmRun" $| "Runs" $| "RBI" $|
        "Walks" $| "Years" $| "PutOuts" $| "Assists" $| "Errors";

//Plot variable importance
//TODO(review): the plotting call is missing from this listing. Add the gml
//variable-importance plot for `out` using `names` as labels; confirm the
//exact procedure name and rfModel field against the gml documentation.

//Make predictions using test data
predictions = rfRegressPredict(out, x_test);

//Print the first (up to) 10 predictions next to the observed test values.
//minc guards against a test set with fewer than 10 rows.
n_show = minc(10 | rows(y_test));
print predictions[1:n_show,.]~y_test[1:n_show,.];

//Test-set mean squared error (on the ln(salary) scale of the model's target)
print "random forest MSE: " meanc((predictions - y_test).^2);

//Benchmark: ordinary least squares fitted on the same train/test split.
//A leading column of ones gives the regression an intercept term.
X_ols_train = ones(rows(x_train), 1) ~ x_train;
X_ols_test = ones(rows(x_test), 1) ~ x_test;

//Matrix division solves for the OLS coefficients, which are then applied to the test design
b_hat = y_train / X_ols_train;
y_hat = X_ols_test * b_hat;

//Report the OLS mean squared error on the test data
ols_resid = y_hat - y_test;
print "OLS MSE using test data  : " meanc(ols_resid .^ 2);

Have a Specific Question?

Get a real answer from a real person

Need Support?

Get help from our friendly experts.

Try GAUSS for 30 days for FREE

See what GAUSS can do for your data

© Aptech Systems, Inc. All rights reserved.

Privacy Policy