Created October 25, 2015 00:44
-
-
Save Akramz/8d9d40800c7cc4eb537f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import statsmodels.api as sm | |
import sys | |
# Explore how ENTRIESn_hourly varies with one column of the turnstile data set:
# average ENTRIESn_hourly for each distinct value of that column, fit an OLS
# line (with intercept) through those averages, print the regression summary,
# and plot the averages together with the fitted line.
# Try whatever column you want: pass its name as the first command-line
# argument (default: 'hour').
DEFAULT_ELEMENT = 'hour'
DATA_PATH = '../improved_data_set/turnstile_weather_v2.csv'
TARGET = 'ENTRIESn_hourly'


def mean_entries_by(df, element):
    """Return the mean of ENTRIESn_hourly for each distinct value of ``element``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``element`` and ``ENTRIESn_hourly``.
    element : str
        Name of the column to group by.

    Returns
    -------
    pandas.DataFrame
        Two columns, ``element`` and ``ENTRIESn_hourly``, with one row per
        distinct value of ``element`` in ascending order (groupby sorts keys).
    """
    # reset_index() turns the group keys back into a column and already
    # returns a DataFrame, so no extra wrapping is needed.
    return df.groupby([element])[TARGET].mean().reset_index()


def fit_ols(dk, element):
    """Fit ``ENTRIESn_hourly ~ const + element`` by ordinary least squares.

    ``dk`` is the per-group table produced by ``mean_entries_by``.
    Returns the fitted statsmodels results object.
    """
    # has_constant='add' keeps the intercept column even if `element` happens
    # to be constant (e.g. a single group); the default 'skip' would silently
    # drop it and change the number of fitted parameters.
    X = sm.add_constant(dk[[element]], has_constant='add')
    return sm.OLS(dk[TARGET], X).fit()


def plot_fit(dk, est, element, n_points=100):
    """Scatter the per-group means and overlay the fitted regression line.

    The line is evaluated at ``n_points`` equally spaced values running from
    the smallest to the largest value of ``element`` in ``dk``.
    """
    x = dk[element]
    x_grid = np.linspace(x.min(), x.max(), n_points)
    # Prepend the constant the same way fit_ols() does, so the columns line up
    # with est.params; 'add' keeps it even when every grid value is equal.
    X_grid = sm.add_constant(x_grid[:, np.newaxis], has_constant='add')
    y_hat = est.predict(X_grid)
    plt.scatter(x, dk[TARGET], alpha=0.3)  # per-group means, not raw rows
    plt.xlabel(element)
    plt.ylabel("Hourly Entries")
    plt.plot(x_grid, y_hat, 'r', alpha=0.9)  # regression line, in red


def main(element=DEFAULT_ELEMENT, data_path=DATA_PATH):
    """Load the data set, print the OLS summary for ``element``, and show the plot."""
    df = pd.read_csv(data_path, index_col=0)
    dk = mean_entries_by(df, element)
    est = fit_ols(dk, element)
    # Call form of print works on both Python 2 and Python 3.
    print(est.summary())
    plot_fit(dk, est, element)
    plt.show()


if __name__ == '__main__':
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_ELEMENT)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.