Skip to content

Instantly share code, notes, and snippets.

@rigogsilva
Created October 14, 2021 18:55
Show Gist options
  • Save rigogsilva/23f1dc49d050f05f510287a0c0b7f0e7 to your computer and use it in GitHub Desktop.
Save rigogsilva/23f1dc49d050f05f510287a0c0b7f0e7 to your computer and use it in GitHub Desktop.
# cd.display.markdown("## Polynomial Regression")
# y = transformed_shifts['deletedEmployeeCount']
# x = transformed_shifts['deletedEmployeeCount']
# model = np.poly1d(np.polyfit(x, y, 20))
# line = np.linspace(
# transformed_shifts['deletedEmployeeCount'].min(),
# transformed_shifts['deletedEmployeeCount'].max(),
# 200
# )
# plt.scatter(x, y, color='grey')
# plt.plot(line, model(line))
# # plt.show()
# cd.display.pyplot()
#
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import statsmodels.formula.api as sm
def predict_using_ols_result_coefficients(x: pd.Series) -> pd.Series:
    """Return y predicted from a fixed set of hard-coded OLS coefficients.

    Evaluates ``1.1294 + -0.6105 * x + 0.0714 * x`` element-wise.

    NOTE(review): both slope terms multiply ``x`` to the first power, so the
    expression is a straight line in ``x``. The OLS formula fitted in this
    file is ``y ~ x + np.power(x, 3)``; presumably the second term was meant
    to multiply a power of ``x``. Confirm against the fit these numbers were
    copied from before changing the arithmetic.

    Args:
        x: Predictor values.

    Returns:
        The predicted values, computed element-wise from ``x``.
    """
    # Alternative coefficient set left commented out by the author:
    # result = (13.6890 + -0.1334*x + 0.0006*x + 0.4685*x)
    intercept = 1.1294
    first_coefficient = -0.6105
    second_coefficient = 0.0714
    # Same left-to-right evaluation order as the original one-line expression.
    return intercept + first_coefficient * x + second_coefficient * x
def gen_curvilinear_x_and_y(
    n_points: int = 50,
    x_max: float = 8,
    noise_scale: float = 0.1,
    seed: int = 1,
):
    """Generate a reproducible noisy sine sample, sorted by x.

    x values are drawn uniformly from ``[0, x_max)`` and y is ``sin(x)`` plus
    Gaussian noise scaled by ``noise_scale``. With the default arguments the
    output is identical to the original hard-coded version (50 points,
    x in [0, 8), noise scale 0.1, seed 1).

    Args:
        n_points: Number of (x, y) pairs to generate.
        x_max: Upper bound (exclusive) of the uniform x range.
        noise_scale: Multiplier applied to the standard-normal noise.
        seed: Seed for the ``np.random.RandomState`` generator.

    Returns:
        A tuple ``(x, y)`` of 1-D NumPy arrays of length ``n_points``, with
        ``x`` in ascending order and ``y`` reordered to match.
    """
    rng = np.random.RandomState(seed)
    # Draw order matters for reproducibility: x first, then the noise.
    raw_x = x_max * rng.rand(n_points)
    y = np.sin(raw_x) + noise_scale * rng.randn(n_points)
    # rand() already returns a 1-D array, so no flattening is needed before
    # computing the permutation that sorts x.
    order = raw_x.argsort()
    return raw_x[order], y[order]
# Fit an OLS model with a linear and a cubic term to the synthetic data, print
# the fit results, and plot the raw points together with the fitted curve.
x, y = gen_curvilinear_x_and_y()
df = pd.DataFrame({'x': x, 'y': y})

# Plain string literal: the formula contains no placeholders, so an f-string
# prefix is unnecessary.
ols_model = sm.ols(formula='y ~ x + np.power(x, 3)', data=df).fit()
print(ols_model.params)
print(ols_model.summary())

fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode='markers'))

# Evenly spaced x values spanning the observed range. min()/max() give the
# same endpoints as first/last element when x is sorted, without depending on
# that ordering.
line = np.linspace(df['x'].min(), df['x'].max(), 200)

# The fitted model predicts from a frame that has the same 'x' column name as
# the training data.
fake_df = pd.DataFrame({'x': line})
fig.add_trace(
    go.Scatter(x=line, y=ols_model.predict(fake_df), mode='lines')
)
fig.show()
import plotly.graph_objects as go
# Degree-3 least-squares polynomial fit of y on x with NumPy, drawn over the
# raw data points.
coefficients = np.polyfit(x, y, 3)
model = np.poly1d(coefficients)

# Evaluate the fitted polynomial on 50 evenly spaced points between the first
# and last x values stored in df.
x_values = df['x'].values
line = np.linspace(x_values[0], x_values[-1], 50)
y_fit = model(line)

data_trace = go.Scatter(
    x=x,
    y=y,
    mode='markers',
    name='Data',
    marker={'color': 'grey'},
    showlegend=True,
)
fit_trace = go.Scatter(
    x=line,
    y=y_fit,
    mode='lines',
    name='Fit',
    marker={'color': 'blue'},
)

fig = go.Figure()
fig.add_trace(data_trace)
fig.add_trace(fit_trace)

# White background with grid lines disabled on both axes.
layout_settings = {
    'xaxis': {'showgrid': False, 'title': 'x'},
    'yaxis': {'showgrid': False, 'title': 'y'},
    'paper_bgcolor': 'white',
    'plot_bgcolor': 'white'
}
fig.update_layout(layout_settings)
# Rendering is intentionally left disabled here:
# fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment