Created
August 31, 2018 04:57
-
-
Save JonnoFTW/f94f8d97e57f6796da83b834ce66aa45 to your computer and use it in GitHub Desktop.
Predicting Eye Open/Closed State Using Keras LSTM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from keras import Sequential | |
from keras.layers import LSTM, Dense, Dropout, Activation | |
import numpy as np | |
from matplotlib import pyplot | |
def get_data(path='eeg_data.csv', threshold=4):
    """Load the EEG CSV and drop every row that has an outlier in any column.

    A value counts as an outlier when its absolute distance from its column's
    median is at least ``threshold`` times that column's standard deviation.
    Rows containing NaN are dropped as well.

    Args:
        path: CSV file to read; defaults to ``eeg_data.csv`` in the working
            directory.
        threshold: outlier cut-off, in column standard deviations.

    Returns:
        A new, independent DataFrame holding only the retained rows. Callers
        can add columns to it without triggering pandas' SettingWithCopyWarning.
    """
    data = pd.read_csv(path)
    deviation = (data - data.median()).abs()
    # Compare by multiplying the std instead of dividing by it. A constant
    # column (std == 0) would otherwise give 0/0 = NaN and reject every row.
    # A value equal to its column median is never an outlier, whatever the spread.
    keep = (deviation < threshold * data.std()) | (deviation == 0)
    return data[keep.all(axis=1)].copy()
def show(df: pd.DataFrame):
    """Draw each column of *df* as its own stacked time-series subplot and display the figure."""
    column_names = list(df.columns)
    # squeeze=False keeps the axes as a 2-D (rows, 1) array even when there is one column.
    figure, axes_grid = pyplot.subplots(len(column_names), squeeze=False)
    for axis, column in zip(axes_grid[:, 0], column_names):
        # Title sits just to the right of the plot area.
        axis.set_title(column, loc='right', x=1.05)
        axis.plot(df[column])
    figure.subplots_adjust(bottom=0.2)
    pyplot.show()
def _build_model(batch_size, num_fields):
    """Build and compile a two-layer stateful LSTM binary classifier for single-step inputs."""
    model = Sequential()
    model.add(LSTM(128,
                   stateful=True,
                   return_sequences=True,
                   batch_input_shape=(batch_size, 1, num_fields)))
    model.add(Dropout(0.1))
    model.add(LSTM(128, stateful=True))
    model.add(Dropout(0.1))
    model.add(Dense(1, activation='sigmoid'))
    # Sigmoid output + binary cross-entropy is the standard pairing for a 0/1
    # target; MSE on a sigmoid gives vanishing gradients when it saturates.
    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    return model


def do_model(data, y_col='eyeDetection', batch_size=1, validation_split=0.25, epochs=5):
    """Train a stateful LSTM to predict *y_col* and write predictions back into *data*.

    Each row is fed as a single time step, so the window length is 1. Temporal
    context comes only from the stateful LSTM carrying its hidden state from one
    batch to the next, which is why training uses ``shuffle=False``.

    Args:
        data: DataFrame of numeric feature columns plus the *y_col* label column.
            It is modified in place: a ``'Predicted EyeState'`` column holding
            the predicted probabilities is added.
        y_col: name of the binary target column.
        batch_size: samples per batch. Stateful layers need the training and
            validation sample counts to be divisible by it.
        validation_split: fraction of rows held out for validation. Keras takes
            these from the end of the data.
        epochs: number of passes over the training data.
    """
    features = data.drop(y_col, axis=1)
    # Derive the feature count from the data rather than hard-coding it.
    num_rows, num_fields = features.shape
    xs = features.values.astype(np.float32).reshape(num_rows, 1, num_fields)
    ys = data[y_col].values.astype(np.float32)

    model = _build_model(batch_size, num_fields)
    for _ in range(epochs):
        model.fit(xs, ys, batch_size=batch_size, validation_split=validation_split,
                  epochs=1, shuffle=False)
        # Keras never resets stateful RNN state on its own. Reset between epochs
        # so each pass (and the final predict) starts from a clean state.
        model.reset_states()
    # predict returns shape (n, 1); flatten it to fit a single column.
    data['Predicted EyeState'] = model.predict(xs, batch_size=batch_size).ravel()
if __name__ == "__main__":
    # Load and filter the data, train/predict in place, then plot every column
    # (including the prediction column added by do_model).
    eeg = get_data()
    do_model(eeg)
    show(eeg)
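I think the window size is 1, since the batch input shape is (1, 1, 14), i.e. each batch is a single point in time with 14 fields. Any longer-range context comes only from the stateful LSTM carrying its state over from one batch to the next.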
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
What's the window size in your model?