@jrzaurin
Created February 19, 2021 22:54
TabResnet
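# tab_preprocessor (a fitted TabPreprocessor) and cont_cols (the list of
# continuous column names) are assumed to be defined in an earlier step.
from pytorch_widedeep.models import TabResnet, WideDeep

tabresnet = TabResnet(
    column_idx=tab_preprocessor.column_idx,  # column name -> position in the input array
    embed_input=tab_preprocessor.embeddings_input,  # embedding specs for the categorical columns
    continuous_cols=cont_cols,
    batchnorm_cont=True,  # batch-normalise the continuous columns
    blocks_dims=[200, 100, 100],  # widths of the dense residual blocks
    mlp_hidden_dims=[100, 50],  # MLP stacked on top of the residual blocks
)
# Use TabResnet as the deeptabular component of a WideDeep model
model = WideDeep(deeptabular=tabresnet)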
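The `blocks_dims` and `mlp_hidden_dims` arguments set the layer sizes of the model. The sketch below reflects my reading of pytorch_widedeep's implementation around this release and may not match other versions. It assumes that the input (categorical embeddings concatenated with the continuous columns) is first projected to `blocks_dims[0]` when the sizes differ, that each consecutive pair of widths becomes one residual block, and that the MLP starts from `blocks_dims[-1]`. `resnet_layer_shapes` is a hypothetical helper, not a library function, and the 50-feature input is made up for illustration.

```python
from typing import List, Optional, Tuple


# Hypothetical helper (NOT part of pytorch_widedeep): lists the (in, out)
# sizes of the layers a dense ResNet built from blocks_dims and
# mlp_hidden_dims would contain, under the assumptions stated above.
def resnet_layer_shapes(
    input_dim: int,
    blocks_dims: List[int],
    mlp_hidden_dims: Optional[List[int]] = None,
) -> List[Tuple[str, int, int]]:
    shapes: List[Tuple[str, int, int]] = []
    # Assumed: project the concatenated input to the first block width
    if input_dim != blocks_dims[0]:
        shapes.append(("projection", input_dim, blocks_dims[0]))
    # One residual block between each pair of consecutive widths
    for i in range(1, len(blocks_dims)):
        shapes.append((f"block_{i - 1}", blocks_dims[i - 1], blocks_dims[i]))
    # Optional MLP on top, starting from the last block width
    if mlp_hidden_dims:
        dims = [blocks_dims[-1]] + mlp_hidden_dims
        for j in range(1, len(dims)):
            shapes.append((f"mlp_{j - 1}", dims[j - 1], dims[j]))
    return shapes


# Hypothetical input: 3 categorical columns with 16-dim embeddings plus
# 2 continuous columns -> 3 * 16 + 2 = 50 input features.
for name, d_in, d_out in resnet_layer_shapes(50, [200, 100, 100], [100, 50]):
    print(f"{name}: {d_in} -> {d_out}")
```

With the settings used in the gist this gives a 50 -> 200 projection, two residual blocks (200 -> 100 and 100 -> 100) and an MLP ending in 50 output features.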