@jrzaurin
Last active February 19, 2021 22:54
Train a TabMlp model
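```python
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

# `model` (a WideDeep model with a TabMlp component), `X_tab` and `target` are
# assumed to be defined by earlier preprocessing and model-building steps.
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
# Train for 5 epochs with batches of 256 rows, holding out 20% for validation.
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256, val_split=0.2)
```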
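In the `fit` call, `val_split=0.2` holds out a fraction of the rows for validation, and `batch_size` together with `n_epochs` determines how many mini-batches are processed. The sketch below works through that arithmetic in plain Python for a hypothetical 10,000-row dataset. The helpers `holdout_split` and `steps_per_run` are illustrative names, not part of pytorch-widedeep; the split is a simple random shuffle (the library may split differently, for example stratified by the target), and the step count assumes the last partial batch is kept rather than dropped.

```python
import math
import random


def holdout_split(n_rows, val_split, seed=1):
    """Shuffle row indices and hold out a fraction of them for validation.

    Hypothetical helper: a plain random split used only for illustration.
    """
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)
    n_val = int(round(n_rows * val_split))
    return idx[n_val:], idx[:n_val]  # (train indices, validation indices)


def steps_per_run(n_train, batch_size, n_epochs):
    """Total mini-batches over all epochs, assuming the last partial batch is kept."""
    return math.ceil(n_train / batch_size) * n_epochs


# Hypothetical dataset size, chosen only to make the arithmetic concrete.
n_rows = 10_000
train_idx, val_idx = holdout_split(n_rows, val_split=0.2)
print(len(train_idx), len(val_idx))  # 8000 2000
print(steps_per_run(len(train_idx), batch_size=256, n_epochs=5))  # 160
```

With 8,000 training rows, each epoch takes ceil(8000 / 256) = 32 batches, so five epochs come to 160 batches.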
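With `objective="binary"`, the reported `Accuracy` compares thresholded predictions against the 0/1 target. pytorch-widedeep computes this internally; the sketch below is only a plain-Python illustration of the usual calculation, assuming the model emits one logit per row that is passed through a sigmoid and thresholded at 0.5. It is not the library's implementation, and `binary_accuracy` is a hypothetical helper.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def binary_accuracy(logits, targets, threshold=0.5):
    """Fraction of rows whose thresholded probability matches the 0/1 target.

    Assumed convention: a probability exactly equal to the threshold maps to 0.
    """
    preds = [1 if sigmoid(z) > threshold else 0 for z in logits]
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return correct / len(targets)


# Toy logits and labels: three of the four thresholded predictions are correct.
logits = [2.0, -1.0, 0.3, -0.2]
targets = [1, 0, 0, 0]
print(binary_accuracy(logits, targets))  # 0.75
```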