Created
October 24, 2015 12:11
-
-
Save Akramz/8738ed30cdcb84105ce7 to your computer and use it in GitHub Desktop.
Linear regression of mean hourly entries on temperature (predictor `tempi` -> response `ENTRIESn_hourly`)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Fit and plot a linear regression of mean hourly entries on temperature.

Reads the turnstile/weather CSV and averages ``ENTRIESn_hourly`` for each
distinct ``tempi`` value. It then fits an OLS model
``ENTRIESn_hourly ~ const + tempi`` with statsmodels, prints the fit summary
and the estimated parameters, and plots the averaged points together with
the fitted line.

Note: each distinct temperature contributes one (unweighted) point to the
fit, regardless of how many raw rows share that temperature.
"""
import numpy as np
import pandas as pd

DATA_PATH = '../improved_data_set/turnstile_weather_v2.csv'


def mean_entries_by_temp(df):
    """Return the mean ``ENTRIESn_hourly`` for each distinct ``tempi`` value.

    Args:
        df: DataFrame with at least the columns ``tempi`` and
            ``ENTRIESn_hourly``.

    Returns:
        DataFrame indexed by ``tempi`` (ascending, groupby's default order).
        It has two columns: ``ENTRIESn_hourly`` (the group mean) and
        ``tempi`` (a copy of the index, so it can be used as an ordinary
        predictor column). Rows whose ``tempi`` is missing are excluded by
        ``groupby``.
    """
    dk = df.groupby('tempi')['ENTRIESn_hourly'].mean().to_frame()
    dk['tempi'] = dk.index
    return dk


def fit_entries_on_temp(dk):
    """Fit OLS ``ENTRIESn_hourly ~ const + tempi`` and return the results.

    Args:
        dk: frame produced by :func:`mean_entries_by_temp`.

    Returns:
        The fitted statsmodels regression results object.
    """
    # Imported here so the aggregation helper above is usable without
    # statsmodels installed.
    import statsmodels.api as sm

    y = dk['ENTRIESn_hourly']  # response
    # has_constant='add' always prepends the intercept column. The default
    # 'skip' silently omits it when the predictor is itself constant.
    X = sm.add_constant(dk['tempi'], has_constant='add')
    return sm.OLS(y, X).fit()


def plot_fit(dk, est, n_points=100):
    """Scatter the averaged data and overlay the fitted regression line.

    Args:
        dk: frame produced by :func:`mean_entries_by_temp`.
        est: results returned by :func:`fit_entries_on_temp`.
        n_points: number of equally spaced temperatures used to draw the line.
    """
    import matplotlib.pyplot as plt
    import statsmodels.api as sm

    temps = dk['tempi']
    # Pick n_points equally spaced temperatures from the min to the max.
    x_line = np.linspace(temps.min(), temps.max(), n_points)
    # Same design layout as the fit: [const, tempi]. Force the constant so
    # the column count always matches the fitted parameters.
    X_line = sm.add_constant(x_line[:, np.newaxis], has_constant='add')
    y_hat = est.predict(X_line)

    plt.scatter(temps, dk['ENTRIESn_hourly'], alpha=0.3)  # raw (averaged) data
    plt.xlabel("temp")
    plt.ylabel("Hourly Entries")
    plt.plot(x_line, y_hat, 'r', alpha=0.9)  # regression line, in red
    plt.show()


def main(path=DATA_PATH):
    """Run the full load -> aggregate -> fit -> report -> plot pipeline."""
    df = pd.read_csv(path, index_col=0)
    dk = mean_entries_by_temp(df)
    est = fit_entries_on_temp(dk)
    print(est.summary())
    print(est.params)
    plot_fit(dk, est)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment