
@lgray
Created November 19, 2021 17:48
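import torch
from torch_cmspepr.gravnet_model import GravnetModelWithNoiseFilter

# Rebuild the architecture with the same hyperparameters used in training.
model = GravnetModelWithNoiseFilter(input_dim=9, output_dim=6, k=50, signal_threshold=.05)
# Load the checkpoint on CPU so the export also works on machines without a GPU.
weights = torch.load("ckpt_train_taus_integrated_noise_Oct20_212115_best_397.pth.tar", map_location="cpu")
model.load_state_dict(weights["model"])  # the state dict is stored under the "model" key
model.eval()  # export in inference mode (dropout/batch-norm off)
# Compile to TorchScript and save a standalone file that can be loaded without the Python class.
jitted = torch.jit.script(model)
torch.jit.save(jitted, "ckpt_train_taus_integrated_noise_Oct20_212115_best_397.ptj")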
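The snippet stops once the `.ptj` file is written. One way to sanity-check an export like this is to reload it with `torch.jit.load` and compare its outputs with the eager model. The sketch below does that round trip, assuming a small built-in `nn.Sequential` (9 inputs, 6 outputs) as a hypothetical stand-in for `GravnetModelWithNoiseFilter`, since the real model's `forward` signature and inputs are not shown here. The helper name `export_and_reload` and the file name `tiny.ptj` are illustrative only.

```python
import os
import tempfile

import torch
import torch.nn as nn


def export_and_reload(model: nn.Module, path: str) -> torch.jit.ScriptModule:
    """Script `model`, save it to `path`, and load it back from disk."""
    model.eval()  # put dropout/batch-norm layers in inference mode before export
    torch.jit.save(torch.jit.script(model), path)
    # The reloaded module runs without the original Python class definitions.
    return torch.jit.load(path, map_location="cpu")


# Hypothetical stand-in for GravnetModelWithNoiseFilter: a small built-in
# module with the same 9-feature input and 6-feature output.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(9, 6), nn.ReLU())
x = torch.randn(4, 9)  # 4 rows of 9 input features

with tempfile.TemporaryDirectory() as tmp:
    reloaded = export_and_reload(model, os.path.join(tmp, "tiny.ptj"))
    with torch.no_grad():
        # Scripted and eager models run the same ops, so outputs should agree.
        outputs_match = torch.allclose(model(x), reloaded(x))

print("outputs match:", outputs_match)
```

For the real model, the same comparison would need whatever inputs its `forward` actually expects, which this gist does not specify.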