This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import game | |
| import pickle | |
| import numpy as np | |
| PCA_MODEL = pickle.load(open('model/pca.pckl','rb')) | |
| LDA_MODEL = pickle.load(open('model/lda.pckl','rb')) | |
| ReducedObsSpacePCA = PCA_MODEL.n_components_ | |
| ReducedObsSpaceLDA = len(LDA_MODEL.explained_variance_ratio_) | |
| class State: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def evaluate_reward(self,game,action): | |
| init_points = game.points | |
| if ((len(game.dice) > 0) & (len([i for i in game.board if i.isnumeric()]) > 0)) | (game.jokers > 0): | |
| if game.board[action].isnumeric(): | |
| next_die = game.dice[0] | |
| game.place_die(next_die,action+1) | |
| elif len([i for i in game.board if i.isnumeric()]) > 0: | |
| next_die = game.dice[0] | |
| adjusted_pos = min([int(i) for i in game.board if i.isnumeric()],key = lambda x: abs(x-(action+1))) | |
| game.place_die(next_die,adjusted_pos) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def _train(self): | |
| batch = np.asarray(random.sample(self.memory, BATCH_SIZE)) | |
| if len(batch) < BATCH_SIZE: | |
| return | |
| current_states = [] | |
| q_values = [] | |
| max_q_values = [] | |
| for entry in batch: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| expand_grid = lambda d: pd.DataFrame([row for row in product(*d.values())],columns=d.keys()) | |
| y = pd.Series(self.y) | |
| y_train = y.values[:-test_length] | |
| y_test = y.values[-test_length:] | |
| dates = pd.to_datetime(self.current_dates) | |
| scores = [] # lower is better | |
| if y.min() > 0: | |
| grid = expand_grid({ | |
| 'trend':[None,'add','mul'], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from Forecaster import Forecaster | |
| forecasts = {} | |
| for c in df.columns: | |
| f = Forecaster(y=df[c].to_list(),current_dates=df.index.to_list(),name=c) | |
| f.generate_future_dates(24,'MS') # MS is from pandas, meaning month start | |
| forecasts[c] = f |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from Forecaster import Forecaster | |
| forecasts = {} | |
| test_length = 3 | |
| for sym in ('UTPHCI','FLPHCI'): | |
| f = Forecaster() | |
| f.get_data_fred(sym) | |
| f.generate_future_dates(24,'MS') | |
| f.forecast_auto_arima(test_length=test_length) | |
| f.forecast_ets(test_length=test_length) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def visualize_results(): | |
| def on_button_clicked(b): | |
| ts_selection = ts_dropdown.value | |
| mo_selection = list(mo_select.value) | |
| with output: | |
| clear_output() | |
| forecasts[ts_selection].plot(models=mo_selection, print_mapes=True, plot_fitted=True) | |
| models = ('auto_arima','tbats','ets','average') | |
| ts_dropdown = widgets.Dropdown(options=forecasts.keys(), description = 'Time Series:') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def accuracy_trend_vis(): | |
| """ visualize the model error metrics | |
| leverages Jupyter widgets | |
| """ | |
| def display_user_selections(ts_selection): | |
| """ displays graphs with seaborn based on what user selects from dropdown menus | |
| """ | |
| f = forecasts[ts_selection] | |
| k = Counter(f.mape) | |
| mapes = [h[1] for h in k.most_common()] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| data_state.fillna(method='ffill',inplace=True) | |
| max_perc = lambda x: np.max([i/100 for i in x]) | |
| data_state_piv = pd.pivot_table(data_state, | |
| index='location', | |
| values=['total_vaccinations_per_hundred','people_vaccinated','people_fully_vaccinated_per_hundred','people_fully_vaccinated'], | |
| aggfunc={'total_vaccinations_per_hundred':max_perc, | |
| 'people_fully_vaccinated_per_hundred':max_perc, | |
| 'people_vaccinated':np.max, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| sns.set(rc={'figure.figsize':(14,8)}) | |
| ax3 = sns.lineplot(x=data_cumsum_doses.index,y='total_distributed',data=data_cumsum_doses,color='purple',linestyle='--',label='Total Distributed') | |
| sns.lineplot(x=data_cumsum_doses.index,y='people_vaccinated',data=data_cumsum_doses,color='red',label='Received at Least One Dose') | |
| sns.lineplot(x=data_cumsum_doses.index,y='people_fully_vaccinated',data=data_cumsum_doses,color='orange',label='Fully Vaccinated') | |
| above_val = 5000000 | |
| plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['total_distributed'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['total_distributed'].max()),size=14,color = 'purple') | |
| plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['people_vaccinated'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['people_vaccinated'].max()),size=14,color = 'red') | |
| plt.text(data_cumsum_doses.index.to_list()[-1], data_cumsum_doses['people_fully_vaccinated'].max() + above_val, '{:,.0f}'.format(data_cumsum_doses['people_fully_vaccina |
OlderNewer