@jgoodie
Created May 29, 2024 03:00
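import torch

# IoTMultiClassModel, training_loop and the X_train, X_val, y_train, y_val
# tensors are not part of this gist and must be defined beforehand.

torch.manual_seed(101)  # seed PyTorch's RNG so runs are reproducible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
# input_features = 91
input_features = 81   # number of input features per sample
output_features = 4   # number of target classes
hidden_units = 128
dropout = 0.0
lr = 0.001
weight_decay = 0.0
epochs = 2000

# Create an instance of the IoT multi-class model and move it to the target device
model = IoTMultiClassModel(input_features=input_features,
                           output_features=output_features,
                           hidden_units=hidden_units,
                           dropout=dropout).to(device)
print(model)

# Train the model; the helper returns training/validation losses and accuracies
train_losses, train_accs, val_losses, val_accs = training_loop(model, X_train, X_val, y_train, y_val,
                                                                epochs=epochs, weight_decay=weight_decay, lr=lr)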
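The gist does not include `IoTMultiClassModel` or `training_loop`. What follows is a hypothetical, minimal version of each. It only shows how the constructor arguments and the four returned lists used above could fit together. It assumes the model is a plain feed-forward classifier with two hidden layers of `hidden_units` neurons, and that the loop does full-batch training with Adam and cross-entropy loss. The layer layout, the optimizer, the loss and the `accuracy` helper are all assumptions, not the author's code.

```python
import torch
from torch import nn


# HYPOTHETICAL sketch: the real IoTMultiClassModel is not shown in the gist.
# Assumed architecture: two hidden ReLU layers with optional dropout.
class IoTMultiClassModel(nn.Module):
    def __init__(self, input_features, output_features, hidden_units=128, dropout=0.0):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, output_features),  # one raw logit per class
        )

    def forward(self, x):
        return self.layers(x)


def accuracy(logits, targets):
    """Hypothetical helper: fraction of rows whose arg-max logit equals the label."""
    return (logits.argmax(dim=1) == targets).float().mean().item()


# HYPOTHETICAL sketch of training_loop: full-batch Adam + cross-entropy.
# The real helper's optimizer, batching and metrics are unknown.
def training_loop(model, X_train, X_val, y_train, y_val,
                  epochs=100, weight_decay=0.0, lr=0.001):
    # Move the data to whatever device the model's parameters live on
    device = next(model.parameters()).device
    X_train, y_train = X_train.to(device), y_train.to(device)
    X_val, y_val = X_val.to(device), y_val.to(device)

    loss_fn = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for _ in range(epochs):
        # Training step on the whole training set
        model.train()
        logits = model(X_train)
        loss = loss_fn(logits, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Training metrics come from the forward pass made before this update
        train_losses.append(loss.item())
        train_accs.append(accuracy(logits, y_train))

        # Validation pass without gradient tracking
        model.eval()
        with torch.no_grad():
            val_logits = model(X_val)
            val_losses.append(loss_fn(val_logits, y_val).item())
            val_accs.append(accuracy(val_logits, y_val))

    return train_losses, train_accs, val_losses, val_accs
```

With definitions like these, the snippet above runs once `X_train` and `X_val` (float tensors with 81 columns) and `y_train` and `y_val` (integer class labels from 0 to 3) exist.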