Created
July 24, 2020 02:42
-
-
Save jamm1985/5dfb95f9d052c9a1b1fde6399acf2122 to your computer and use it in GitHub Desktop.
Reads data from XLSX, performs exploratory analysis, and fits an OLS regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
File: rainfall_regression.py
Author: Andrey Stepnov
Email: [email protected], [email protected]
Github: https://github.com/jamm1985
Description: Reads data from XLSX, performs exploratory analysis, and fits an OLS regression
""" | |
import matplotlib.pylab as plt | |
import pandas as pd | |
import seaborn as sns | |
import numpy as np | |
from sklearn.linear_model import LinearRegression | |
import statsmodels.api as sm | |
# Load the rainfall table and print the pairwise correlation coefficients.
data = pd.read_excel('rainfall.xlsx')
correlations = data.corr()
print('correlation matrix...\n', correlations)

# First look: the USS and GORY series plotted against YEAR.
data.plot(x='YEAR', y=['USS', 'GORY'])
plt.show()

# Scatter of GORY against USS with the fitted regression line drawn by
# seaborn.
sns.regplot(x='USS', y='GORY', data=data)
plt.show()
## OLS with scikit-learn
# Keep only the two series of interest and drop every row in which either
# one is missing, so the regression is fitted on complete pairs only.
data_selected = data[['USS', 'GORY']].dropna()

# scikit-learn expects a 2-D feature matrix, so USS becomes a single column.
X_uss = data_selected['USS'].to_numpy().reshape(-1, 1)
Y_gory = data_selected['GORY'].to_numpy()

# Fit GORY ~ USS and report R^2 together with the fitted parameters.
reg_uss_gory = LinearRegression()
reg_uss_gory.fit(X_uss, Y_gory)
print(f'R^2 = {reg_uss_gory.score(X_uss, Y_gory)}')
print(f'intercept = {reg_uss_gory.intercept_}')
print(f'coefficients = {reg_uss_gory.coef_}')
# Same regression with statsmodels, whose summary table is printed below.
# sm.add_constant adds the constant (intercept) column to the regressors.
X_uss_sm = sm.add_constant(data_selected['USS'])
Y_gory_sm = data_selected['GORY']
reg_uss_gory_sm = sm.OLS(Y_gory_sm, X_uss_sm)

# Fit the model and print its summary.
ols_results = reg_uss_gory_sm.fit()
print(ols_results.summary())
# Predict GORY from USS with the fitted scikit-learn model and store the
# rounded predictions in a new GORY_PRED column.
# Only rows that have a USS reading are predicted: Series.count() skips NaN,
# so sizing the reshape with it (as was done before) raises a shape mismatch
# as soon as USS has a gap, and the model cannot score a missing input
# anyway. Rows without a USS value keep NaN in GORY_PRED.
uss_known = data['USS'].notna()
data['GORY_PRED'] = np.nan
if uss_known.any():
    # Selecting with a list of columns yields the 2-D (n, 1) array that
    # predict() expects, matching the shape used when fitting.
    data.loc[uss_known, 'GORY_PRED'] = np.round(
        reg_uss_gory.predict(data.loc[uss_known, ['USS']].to_numpy()))

# Keep only the columns that go into the plot and the exported workbook.
data = data[['YEAR', 'USS', 'GORY', 'GORY_PRED', 'LESNOYE']]

# Plot the observed series together with the predicted one.
data.plot(x='YEAR', y=['USS', 'GORY', 'GORY_PRED'])
plt.show()

# Write the table, including the predictions, to an Excel workbook.
data.to_excel('GORY_PRED.xlsx')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
rainfall.xlsx