This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Define a Wide + Deep model for classification on structured data.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import multiprocessing | |
import tensorflow as tf | |
from tensorflow.python.lib.io import file_io | |
import json |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def data_reshape_for_model(data_in,n_timesteps,n_features,print_info=True): | |
''' Function to reshape the data into model ready format, either for training or prediction. | |
''' | |
# get original data shape | |
data_in_shape = data_in.shape | |
# create a dummy row with desired shape and one empty observation | |
data_out = np.zeros((1,n_timesteps,n_features)) | |
# loop though each row of data and reshape accordingly | |
for row in range(len(data_in)): | |
# for each row look ahead as many timesteps as needed and then transpose the data to give shape keras wants |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from numpy import concatenate | |
from matplotlib import pyplot | |
from keras.models import Sequential | |
from keras.callbacks import Callback | |
from keras.layers import LSTM, Dense, Activation | |
import matplotlib.pyplot as plt | |
%matplotlib inline |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build "healthy" data that looks noisy but smooth:
#   1. draw uniform noise of shape (N_DATA_ORIG, N_FEATURES) and square-root it,
#   2. smooth each column with a rolling mean over N_ROLLING rows,
#   3. drop rows containing NaN (the rolling mean yields NaN until a full window exists),
#   4. keep at most the first N_DATA rows.
df_data = (
    pd.DataFrame(np.sqrt(np.random.rand(N_DATA_ORIG, N_FEATURES)))
    .rolling(window=N_ROLLING)
    .mean()
    .dropna()
    .head(N_DATA)
)
print(df_data.shape)
# plain numpy array of the smoothed data, used by the later cells
data = df_data.values
# plot the normal healthy data
fig, ax = plt.subplots(num=None, figsize=(14, 6), dpi=80, facecolor='w', edgecolor='k')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# make some random data with exactly the same shape as `data`, so that any
# row slice of it lines up with the same slice of `data_new`.
# (`data` can have fewer than N_DATA rows after the rolling-mean dropna/head,
# and a fixed (N_DATA, N_FEATURES) array would then make the slice assignment
# below fail with mismatched shapes near the end of the series.)
data_rand = np.random.rand(*data.shape)
data_new = np.copy(data)
# at a random point for a certain number of steps, swap out the smooth data with some random data
data_new[random_break_point:(random_break_point + BREAK_LEN)] = data_rand[random_break_point:(random_break_point + BREAK_LEN)]
# plot the new data, one line per feature column
fig, ax = plt.subplots(num=None, figsize=(14, 6), dpi=80, facecolor='w', edgecolor='k')
size = len(data_new)
for x in range(data_new.shape[1]):
    ax.plot(range(size), data_new[:, x], '-', linewidth=0.5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# build network: N_LAYERS stacked LSTM layers, each returning the full
# sequence, followed by a Dense layer. Keras applies Dense to the last axis of
# a 3-D input, so it maps every timestep back to N_FEATURES outputs.
model = Sequential()
# add number of layer specified
for layer in range(N_LAYERS):
    if layer == 0:
        # only the first layer of a Sequential model needs input_shape;
        # later layers infer their input from the previous layer
        model.add(LSTM(N_LSTM_UNITS, input_shape=(N_TIMESTEPS, N_FEATURES), return_sequences=True))
    else:
        model.add(LSTM(N_LSTM_UNITS, return_sequences=True))
model.add(Dense(N_FEATURES))
model.compile(loss='mae', optimizer='adam')
# print model summary -- summary() prints the table itself and returns None,
# so wrapping it in print() would emit a stray "None" line
model.summary()
# reshape data for training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Combine the training inputs and the model's predictions into a DataFrame.
# NOTE(review): the exact columns come from yhat_to_df_out (defined elsewhere);
# the next cell expects some of them to contain 'error_avg'.
df_out = yhat_to_df_out(
    data_train,
    yhat,
    N_TIMESTEPS,
    N_FEATURES,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pick out every column of df_out whose name contains 'error_avg'
plot_cols = [
    name
    for name in df_out.columns
    if 'error_avg' in name
]
print(plot_cols)
# plot each selected error-average column against its row position
fig, ax = plt.subplots(num=None, figsize=(14, 6), dpi=80, facecolor='w', edgecolor='k')
size = len(df_out)
for col in plot_cols:
    ax.plot(range(size), df_out[col], '-', linewidth=0.5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# now train on new data: data_new is the copy of `data` with a stretch of
# rows replaced by random values.
# Plain string literal: the message has no placeholders, so an f-string
# prefix was pointless (pyflakes F541); the printed text is unchanged.
print('... reshaping data for new data training ...')
data_train_new = data_reshape_for_model(data_new, N_TIMESTEPS, N_FEATURES)
print('... begin training on new data ...')
# NOTE(review): passes the existing `model` back into train(); presumably this
# continues training the already-fitted weights -- confirm against train().
model = train(model, data_train_new, n_epochs=1)
yhat_new = predict(model, data_train_new)
df_out_new = yhat_to_df_out(data_train_new, yhat_new, N_TIMESTEPS, N_FEATURES)
# columns of the new output that contain 'error_avg'
plot_cols = [col for col in df_out_new.columns if 'error_avg' in col]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PS C:\Users\amaguire\Documents\GitHub\ami-research\ami> pipenv install jupyterlab | |
Installing jupyterlab… | |
Adding jupyterlab to Pipfile's [packages]… | |
Installation Succeeded | |
Pipfile.lock (adb84f) out of date, updating to (ca72e7)… | |
Locking [dev-packages] dependencies… | |
Locking [packages] dependencies… | |
Success! | |
Updated Pipfile.lock (adb84f)! | |
Installing dependencies from Pipfile.lock (adb84f)… |
OlderNewer