Skip to content

Instantly share code, notes, and snippets.

@jamm1985
Created July 24, 2020 02:42
Show Gist options
  • Save jamm1985/5dfb95f9d052c9a1b1fde6399acf2122 to your computer and use it in GitHub Desktop.
Save jamm1985/5dfb95f9d052c9a1b1fde6399acf2122 to your computer and use it in GitHub Desktop.
Gets data from XLSX, does exploratory analysis and performs OLS regression
"""
File: rainfall_regression.py
Author: Andrey Stepnov
Email: [email protected], [email protected]
Github: https://github.com/jamm1985
Description: Gets data from XLSX, does exploratory analysis and performs OLS regression
"""
import matplotlib.pylab as plt
import pandas as pd
import seaborn as sns
import numpy as np
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
# Read the data and print pairwise correlation coefficients.
# Only numeric columns are correlated: DataFrame.corr() raises on
# non-numeric columns in pandas >= 2.0 (older versions silently dropped
# them), and select_dtypes gives the numeric-only result on every version.
data = pd.read_excel('rainfall.xlsx')
print('correlation matrix...\n', data.select_dtypes(include='number').corr())

# First look: USS and GORY plotted against YEAR.
data.plot(x='YEAR', y=['USS', 'GORY'])
plt.show()

# Scatter of GORY against USS with seaborn's fitted regression line and
# its error band.
sns.regplot(data=data, x='USS', y='GORY')
plt.show()

## OLS with scikit-learn
# Keep only the two regression columns and drop rows where either value is
# missing -- OLS needs complete (USS, GORY) pairs. Column selection with a
# list already returns a new frame, so the original `data` is untouched.
data_selected = data[['USS', 'GORY']].dropna()

# scikit-learn expects a 2-D feature matrix of shape (n_samples, n_features),
# so the single USS predictor is reshaped into one column.
X_uss = data_selected['USS'].to_numpy().reshape(-1, 1)
Y_gory = data_selected['GORY'].to_numpy()

# Fit OLS and print R^2 (computed on the fitting data), intercept and slope.
reg_uss_gory = LinearRegression().fit(X_uss, Y_gory)
print('R^2 = {}'.format(reg_uss_gory.score(X_uss, Y_gory)))
print('intercept = {}'.format(reg_uss_gory.intercept_))
print('coefficients = {}'.format(reg_uss_gory.coef_))

# The same regression with statsmodels, whose summary also reports
# t-tests, the F-test, etc. statsmodels does not add an intercept on its
# own, so a constant column is added explicitly.
X_uss_sm = sm.add_constant(data_selected['USS'])
Y_gory_sm = data_selected['GORY']
reg_uss_gory_sm = sm.OLS(Y_gory_sm, X_uss_sm)
print(reg_uss_gory_sm.fit().summary())

# Predict GORY from USS with the scikit-learn model and store the rounded
# values as a new column. Only rows with a known USS value are predicted:
# predict() rejects NaN input, and the number of such rows may be smaller
# than len(data). Rows without USS keep NaN in GORY_PRED.
uss_known = data['USS'].notna()
data['GORY_PRED'] = np.nan
if uss_known.any():  # predict() also rejects an empty input array
    data.loc[uss_known, 'GORY_PRED'] = np.round(
        reg_uss_gory.predict(
            data.loc[uss_known, 'USS'].to_numpy().reshape(-1, 1)))

# Keep the columns of interest in a fixed order.
data = data[['YEAR', 'USS', 'GORY', 'GORY_PRED', 'LESNOYE']]

# Plot observed and predicted series, then save the table.
data.plot(x='YEAR', y=['USS', 'GORY', 'GORY_PRED'])
plt.show()
data.to_excel('GORY_PRED.xlsx')
@jamm1985
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment