Last active
December 25, 2019 21:44
-
-
Save alik604/a571b8560f759800b1710204dd1fae06 to your computer and use it in GitHub Desktop.
EEGNet in Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def EEGNet(nb_classes, Chans = 64, Samples = 128,
           dropoutRate = 0.5, kernLength = 64, F1 = 8,
           D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Build the EEGNet classifier as a Keras ``Sequential`` model.

    Reference: http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta

    Inputs:
      nb_classes     : int, number of classes to classify
      Chans, Samples : number of channels and time points in the EEG data
      dropoutRate    : dropout fraction
      kernLength     : length of temporal convolution in first layer. We found
                       that setting this to be half the sampling rate worked
                       well in practice. For the SMR dataset in particular
                       since the data was high-passed at 4Hz we used a kernel
                       length of 32.
      F1, F2         : number of temporal filters (F1) and number of pointwise
                       filters (F2) to learn. Default: F1 = 8, F2 = F1 * D.
      D              : number of spatial filters to learn within each temporal
                       convolution. Default: D = 2
      norm_rate      : max-norm bound applied to the kernel of the final
                       Dense layer.
      dropoutType    : Either SpatialDropout2D or Dropout, passed as a string.

    Returns:
      An uncompiled keras ``Sequential`` model declared with input shape
      (1, Chans, Samples) and ending in a softmax over ``nb_classes``.

    Raises:
      ValueError: if ``dropoutType`` is not 'SpatialDropout2D' or 'Dropout'.

    NOTE(review): the input shape puts the single feature map on axis 1 and
    every BatchNormalization normalizes axis=1, so this appears to require
    Keras' image_data_format to be 'channels_first' -- confirm before use.
    """
    # Validate before importing Keras so a bad argument fails fast and does
    # not pay for (or depend on) the Keras import.
    if dropoutType not in ('SpatialDropout2D', 'Dropout'):
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # Explicit imports: `from ... import *` is a SyntaxError inside a
    # function in Python 3, and Sequential / max_norm live outside
    # keras.layers anyway.
    from keras.constraints import max_norm
    from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
                              Conv2D, Dense, DepthwiseConv2D, Dropout, Flatten,
                              SeparableConv2D, SpatialDropout2D)
    from keras.models import Sequential

    dropout_layer = (SpatialDropout2D if dropoutType == 'SpatialDropout2D'
                     else Dropout)

    model = Sequential()

    # Block 1: temporal convolution, then a depthwise convolution whose
    # (Chans, 1) kernel spans all channels, producing D maps per temporal filter.
    model.add(Conv2D(F1, (1, kernLength), padding='same',
                     input_shape=(1, Chans, Samples), use_bias=False))
    model.add(BatchNormalization(axis=1))
    model.add(DepthwiseConv2D((Chans, 1), use_bias=False,
                              depth_multiplier=D,
                              depthwise_constraint=max_norm(1.)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(AveragePooling2D((1, 4)))
    model.add(dropout_layer(dropoutRate))

    # Block 2: separable convolution mapping the F1 * D feature maps to F2.
    model.add(SeparableConv2D(F2, (1, 16), use_bias=False, padding='same'))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(AveragePooling2D((1, 8)))
    model.add(dropout_layer(dropoutRate))

    # Classification head: max-norm constrained Dense followed by softmax.
    model.add(Flatten(name='flatten'))
    model.add(Dense(nb_classes, name='dense',
                    kernel_constraint=max_norm(norm_rate)))
    model.add(Activation('softmax', name='softmax'))
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment