Skip to content

Instantly share code, notes, and snippets.

@khuangaf
Last active January 1, 2018 06:57
Show Gist options
  • Save khuangaf/55265b14b0af8b61392760a906509000 to your computer and use it in GitHub Desktop.
Save khuangaf/55265b14b0af8b61392760a906509000 to your computer and use it in GitHub Desktop.
def experiment(validation_datas, validation_labels, original_datas, ground_true,
               ground_true_times, validation_original_outputs,
               validation_output_times, nb_repeat, reg):
    """Train ``nb_repeat`` fresh models with regularizer ``reg`` and score each.

    For every repeat a new model is built with ``fit_lstm(reg)``, run on
    ``validation_datas``, un-scaled with the module-level ``scaler`` (refit on
    column 0 of ``original_datas``), and scored by mean squared error against
    channel 0 of ``validation_original_outputs``.

    Parameters
    ----------
    validation_datas
        Input passed unchanged to ``model.predict``.
    validation_labels, ground_true, ground_true_times, validation_output_times
        Not used by this function; accepted only so existing call sites keep
        working.
    original_datas : numpy.ndarray
        At least 2-D; ``original_datas[:, 0]`` is what the scaler is fit on.
    validation_original_outputs : numpy.ndarray
        3-D; ``[:, :, 0]`` is flattened and used as the MSE target.
    nb_repeat : int
        Number of independent train/evaluate runs.
    reg
        Regularizer forwarded to ``fit_lstm``.

    Returns
    -------
    list
        One MSE value per repeat, in run order.
    """
    # The target does not change between repeats, so flatten it once.
    targets = validation_original_outputs[:, :, 0].reshape(-1)

    error_scores = []
    for _ in range(nb_repeat):
        model = fit_lstm(reg)
        predicted = model.predict(validation_datas)
        # Refit immediately before inverting: ``scaler`` is shared module state
        # and fit_lstm (not visible here) might refit it, so this is
        # deliberately not hoisted out of the loop.
        scaler.fit(original_datas[:, 0].reshape(-1, 1))
        predicted_inverted = np.asarray(
            scaler.inverse_transform(predicted)).reshape(-1)
        error_scores.append(mean_squared_error(targets, predicted_inverted))
    return error_scores
import os

# Regularizers to compare: no penalty, then L1 and L2 at four strengths each
# (same order as before: l1(0), l1 strongest->weakest, l2 strongest->weakest).
_strengths = (0.1, 0.01, 0.001, 0.0001)
regs = ([regularizers.l1(0)]
        + [regularizers.l1(c) for c in _strengths]
        + [regularizers.l2(c) for c in _strengths])
nb_repeat = 30

# One column per regularizer, one row per repeat.
results = pd.DataFrame()
for reg in regs:
    # Column label such as "l1 0.0100,l2 0.0000".
    name = 'l1 %.4f,l2 %.4f' % (reg.l1, reg.l2)
    # Parenthesised single argument: valid on both Python 2 and Python 3.
    print('Training ' + name)
    results[name] = experiment(validation_datas, validation_labels,
                               original_datas, ground_true, ground_true_times,
                               validation_original_outputs,
                               validation_output_times, nb_repeat, reg)

summary = results.describe()
# to_csv does not create missing directories.
if not os.path.isdir('result'):
    os.makedirs('result')
summary.to_csv('result/lstm_bias_reg.csv')
# Bare expression so a notebook cell still displays the summary table.
summary
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment