This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import game | |
import pickle | |
import numpy as np | |
# Load the pre-fitted PCA and LDA models once, at import time.
# `with` ensures each file handle is closed right after unpickling; the
# original `pickle.load(open(...))` leaked both handles.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these model files from a trusted source.
with open('model/pca.pckl', 'rb') as _pca_file:
    PCA_MODEL = pickle.load(_pca_file)
with open('model/lda.pckl', 'rb') as _lda_file:
    LDA_MODEL = pickle.load(_lda_file)

# Size of the reduced observation space produced by each model.
# Presumably these are scikit-learn estimators (`n_components_` and
# `explained_variance_ratio_` are their fitted attributes) — TODO confirm.
ReducedObsSpacePCA = PCA_MODEL.n_components_
ReducedObsSpaceLDA = len(LDA_MODEL.explained_variance_ratio_)
class State: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def evaluate_reward(self,game,action): | |
init_points = game.points | |
if ((len(game.dice) > 0) & (len([i for i in game.board if i.isnumeric()]) > 0)) | (game.jokers > 0): | |
if game.board[action].isnumeric(): | |
next_die = game.dice[0] | |
game.place_die(next_die,action+1) | |
elif len([i for i in game.board if i.isnumeric()]) > 0: | |
next_die = game.dice[0] | |
adjusted_pos = min([int(i) for i in game.board if i.isnumeric()],key = lambda x: abs(x-(action+1))) | |
game.place_die(next_die,adjusted_pos) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _train(self): | |
batch = np.asarray(random.sample(self.memory, BATCH_SIZE)) | |
if len(batch) < BATCH_SIZE: | |
return | |
current_states = [] | |
q_values = [] | |
max_q_values = [] | |
for entry in batch: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
expand_grid = lambda d: pd.DataFrame([row for row in product(*d.values())],columns=d.keys()) | |
y = pd.Series(self.y) | |
y_train = y.values[:-test_length] | |
y_test = y.values[-test_length:] | |
dates = pd.to_datetime(self.current_dates) | |
scores = [] # lower is better | |
if y.min() > 0: | |
grid = expand_grid({ | |
'trend':[None,'add','mul'], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Forecaster import Forecaster

# Build one Forecaster per column of `df`, keyed by the column name.
# The names `c` and `f` are kept because, as top-level loop variables,
# they remain visible to any later script code.
forecasts = dict()
for c in df.columns:
    f = Forecaster(
        y=df[c].to_list(),
        current_dates=df.index.to_list(),
        name=c,
    )
    # 'MS' is the pandas frequency alias for month start.
    # 24 is presumably the number of future periods to generate — TODO confirm.
    f.generate_future_dates(24, 'MS')
    forecasts[c] = f
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Forecaster import Forecaster | |
forecasts = {} | |
test_length = 3 | |
for sym in ('UTPHCI','FLPHCI'): | |
f = Forecaster() | |
f.get_data_fred(sym) | |
f.generate_future_dates(24,'MS') | |
f.forecast_auto_arima(test_length=test_length) | |
f.forecast_ets(test_length=test_length) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def visualize_results(): | |
def on_button_clicked(b): | |
ts_selection = ts_dropdown.value | |
mo_selection = list(mo_select.value) | |
with output: | |
clear_output() | |
forecasts[ts_selection].plot(models=mo_selection, print_mapes=True, plot_fitted=True) | |
models = ('auto_arima','tbats','ets','average') | |
ts_dropdown = widgets.Dropdown(options=forecasts.keys(), description = 'Time Series:') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def accuracy_trend_vis(): | |
""" visualize the model error metrics | |
leverages Jupyter widgets | |
""" | |
def display_user_selections(ts_selection): | |
""" displays graphs with seaborn based on what user selects from dropdown menus | |
""" | |
f = forecasts[ts_selection] | |
k = Counter(f.mape) | |
mapes = [h[1] for h in k.most_common()] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data_state.fillna(method='ffill',inplace=True) | |
max_perc = lambda x: np.max([i/100 for i in x]) | |
data_state_piv = pd.pivot_table(data_state, | |
index='location', | |
values=['total_vaccinations_per_hundred','people_vaccinated','people_fully_vaccinated_per_hundred','people_fully_vaccinated'], | |
aggfunc={'total_vaccinations_per_hundred':max_perc, | |
'people_fully_vaccinated_per_hundred':max_perc, | |
'people_vaccinated':np.max, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sns.set(rc={'figure.figsize':(14,8)}) | |
ax3 = sns.lineplot(x=data_cumsum_doses.index,y='total_distributed',data=data_cumsum_doses,color='purple',linestyle='--',label='Total Distributed') | |
sns.lineplot(x=data_cumsum_doses.index,y='people_vaccinated',data=data_cumsum_doses,color='red',label='Received at Least One Dose') | |
sns.lineplot(x=data_cumsum_doses.index,y='people_fully_vaccinated',data=data_cumsum_doses,color='orange',label='Fully Vaccinated') | |
above_val = 5000000 | |
plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['total_distributed'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['total_distributed'].max()),size=14,color = 'purple') | |
plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['people_vaccinated'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['people_vaccinated'].max()),size=14,color = 'red') | |
plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['people_fully_vaccinated'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['people_fully_vaccina |
OlderNewer