Random forest GDP prediction code

//GDP forecasting using Random Forests
library gml,tsmt;

Load data from the GDP dataset
// Directory holding the tutorial dataset, and the dataset file itself
path = "C:/svn/apps/gml/examples/gdp_tutorial";
gdp_file = path $+ "/rf_gdp.dat";

// Target series: the quarterly "rgdp_pc" column
gdp_q = loadd(gdp_file, "rgdp_pc");

// Every other column of the file (formula ". -rgdp_pc" = all variables except rgdp_pc).
// Column 1 is later passed to dttostr, i.e. it is treated as the date column.
features_q = loadd(gdp_file, ". -rgdp_pc");

// Names of all columns in the file, in file order
vnames = getHeaders(gdp_file);

Split data for training and testing
Training : 1961Q2 to 1999Q4
Testing : 2000Q1 to 2017Q4
// Number of observations in the TRAINING sample (despite the name): rows 1..testT
// (1961Q2-1999Q4) are used to fit the model, and rows testT+1..end are held out for testing.
testT = 155;

// Print the date range of each sample. Column 1 of features_q holds the dates.
// The labels follow the slicing used for y_train/x_train (rows 1..testT) and
// y_test/x_test (rows testT+1..end).
print "Start date of training data:" dttostr(features_q[1,1], "YYYY-QQ");
print "End date of training data:" dttostr(features_q[testT,1], "YYYY-QQ");
print "Start date of test data:" dttostr(features_q[testT+1,1], "YYYY-QQ");
print "End date of test data:" dttostr(features_q[rows(features_q),1], "YYYY-QQ");

Set up data for random forest predictions
For training we will use observations 1 through testT (155); the remaining observations are held out for testing.
// Target: the first testT rows train the model; trimr drops those rows from the top,
// leaving the held-out test rows.
y_train = gdp_q[1:testT,.];
y_test = trimr(gdp_q, testT, 0);

// Features: columns 2-45 of features_q (column 1 is the date), split the same way.
x_train = features_q[1:testT,2:45];
x_test = trimr(features_q[.,2:45], testT, 0);

Set up parameters for fitting the model
// Random forest settings live in an rfControl structure, initialised by
// rfControlCreate (presumably to default values) and then overridden below.
struct rfControl rfc;
rfc = rfControlCreate;

// Ask for the out-of-bag error to be computed (reported later as out.oobError)
rfc.oobError = 1;

// Enable variable importance with method 2.
// NOTE(review): presumably permutation / mean-decrease-in-accuracy importance - confirm in GML docs.
rfc.variableImportanceMethod = 2;

Fit the model
// Regress y_train on x_train with a random forest using the settings in rfc.
// The fitted rfModel is reused for the importance plot and for prediction.
struct rfModel out;
out = rfRegressFit(y_train, x_train, rfc);

// Report the out-of-bag error requested via rfc.oobError
print "Out-of-bag error:" out.oobError;

Plot variable importance
// Plot variable importance labelled with feature names.
// NOTE(review): assumes vnames[3:46] lines up with feature columns 2:45 of
// features_q (file columns 1-2 being the date and rgdp_pc) - confirm if the file layout changes.
plotVariableImportance(out, vnames[3:46]);

// Predict the held-out test sample with the fitted forest
predictions = rfRegressPredict(out, x_test);

// Show the first few predictions next to the observed values. The row count is
// capped so a test sample with fewer than 10 rows does not index out of range.
nshow = minc(10|rows(y_test));
print predictions[1:nshow]~y_test[1:nshow];

// Test-sample mean squared error of the random forest
print "random forest MSE: " meanc((predictions - y_test).^2);

// Benchmark: OLS with an intercept, fitted on the training sample.
// With a non-square right operand, "/" returns the least-squares solution of X*b = y.
b_hat = y_train/(ones(rows(x_train), 1)~x_train);
y_hat = (ones(rows(x_test),1)~x_test) * b_hat;
print "OLS MSE using test data  : " meanc((y_hat - y_test).^2);

// Correlation matrix of the random forest predictions and the observed test values
print "Correlation of RF predictions and observed test data:";
print corrx(predictions~y_test);

Plot GDP data using plotTS
// Date of the first observation; start date of the plotted observed series
dtstart = features_q[1,1];

// Plot control structure initialised from the "XY" plot defaults
struct plotControl myPlot;
myPlot = plotGetDefaults("XY");

// Place the first 'X' tic mark at 1961 and draw one every 20 periods
// (presumably 20 quarters = 5 years for this quarterly series - confirm)
plotSetXTicInterval(&myPlot, 20, 1961);

// Label 'X' tics with year and quarter using the "YYYY-QQ" date format
plotSetXTicLabel(&myPlot, "YYYY-QQ");

// Legend for the observed series and the predictions, in the top-right corner
plotSetLegend(&myPlot, "Obs."$|"Predicted","TOP RIGHT", 1);

// Plot title
// NOTE(review): "Annual Percent Change" and "2009=100" describe different units;
// confirm which one actually applies to rgdp_pc.
plotSetTitle(&myPlot, "U.S. Real GDP (Annual Percent Change, 2009=100)");

// Plot the full observed series (frequency 4 = quarterly), scaled by 100
plotTS(myPlot, dtstart, 4, gdp_q*100);

// Overlay the predictions, starting at the first test-sample date
plotAddTS(features_q[testT+1,1], 4, predictions*100);

Have a Specific Question?

Get a real answer from a real person

Need Support?

Get help from our friendly experts.

Try GAUSS for 14 days for FREE

See what GAUSS can do for your data

© Aptech Systems, Inc. All rights reserved.

Privacy Policy