@erap129
Last active September 28, 2021 07:30
NASA RUL project - get lstm results
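import torch
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# Assumes `trained_model` and `data_module` are defined earlier in the project.
def get_predictions_and_labels_lstm(dataset):
    """Run the trained LSTM over every (X, y) item and collect predictions and labels."""
    predictions = []
    labels = []
    trained_model.eval()  # inference mode: disables dropout and similar layers
    with torch.no_grad():  # gradients are not needed for evaluation
        for item in tqdm(dataset):  # iterate over the dataset argument, not a global
            X, y = item
            # Add a batch dimension; the model returns a tuple whose second element is the prediction.
            _, output = trained_model(X.unsqueeze(dim=0))
            predictions.append(output.item())
            labels.append(y.item())
    return predictions, labels

# Take the square root of the MSE to get RMSE; this avoids the `squared=False`
# argument, which recent scikit-learn releases deprecate.
print(f'RMSE on train set: {mean_squared_error(*get_predictions_and_labels_lstm(data_module.train_dataset)) ** 0.5}')
print(f'RMSE on test set: {mean_squared_error(*get_predictions_and_labels_lstm(data_module.test_dataset)) ** 0.5}')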
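For reference, the RMSE printed above is the square root of the mean squared difference between predictions and labels. A minimal dependency-free sketch follows; the `rmse` helper name is an illustrative assumption and not part of the gist, and it can serve as a sanity check on small hand-made inputs:

```python
import math


def rmse(predictions, labels):
    """Root mean squared error between two equal-length sequences (hypothetical helper)."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if len(predictions) == 0:
        raise ValueError("need at least one prediction")
    # Squared difference for each (prediction, label) pair.
    squared_errors = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    # Mean of the squared errors, then the square root.
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Example with made-up numbers, not results from the project.
print(rmse([3.0, 5.0], [0.0, 1.0]))
```

Because the squared error is symmetric, swapping the two arguments gives the same value, which is why the `(predictions, labels)` order returned by `get_predictions_and_labels_lstm` can be passed straight through.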