Skip to content

Instantly share code, notes, and snippets.

@shnhrtkyk
Last active February 18, 2020 08:18
Show Gist options
  • Save shnhrtkyk/ca39c0b69ecfa42c0c80a2cf2bf07e3c to your computer and use it in GitHub Desktop.
Save shnhrtkyk/ca39c0b69ecfa42c0c80a2cf2bf07e3c to your computer and use it in GitHub Desktop.
pytorchのモデルをJITでコンパイル
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from models import modelName
import numpy as np
import cv2
# Export configuration. The dummy tracing input and the output filename are
# both derived from these values, so editing one keeps everything in sync.
batchsize = 1
input_channel = 3
output_channel = 3
input_height = input_width = 256

# Target device for both the model weights and the tracing input; its type
# ("cuda") is also used as the suffix of the exported filename.
device = torch.device("cuda")
mode = device.type

# Build the network and load the trained PyTorch weights.
model = modelName(input_channel, output_channel).to(device)
# map_location places the checkpoint tensors directly on `device`, regardless
# of which device they were saved from.
model.load_state_dict(torch.load('./model.pth', map_location=device))
# Switch to inference mode before tracing so the recorded graph uses
# eval-mode behaviour for layers that distinguish train/eval.
model.eval()

# Dummy input used only to record the graph during tracing. Its shape comes
# from the configuration above instead of being hard-coded.
example = torch.rand(batchsize, input_channel, input_height, input_width, device=device)

# torch.jit.trace runs the model once on `example` and records the executed ops
# as a TorchScript module.
model = torch.jit.trace(model, example)

export_path = "Net_h{}_w{}_{}.pt".format(input_height, input_width, mode)
model.save(export_path)
# Report the exact path that was written. The previous message passed
# output_channel as an extra leading argument and printed the wrong filename.
print("{} is exported".format(export_path))

# Zero-filled float32 array with the same (batch, channel, height, width) shape
# as `example`. It is not used in this script; presumably a placeholder for a
# follow-up inference check — TODO confirm.
input_x_np = np.zeros((batchsize, input_channel, input_height, input_width), dtype=np.float32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment