Created
April 23, 2020 20:51
-
-
Save gkorland/cd043506b7cf4f78c967cda760ee9797 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from sys import platform | |
from models import * | |
from utils.datasets import * | |
from utils.utils import * | |
def export(cfg=None, weights=None, outfile=None, img_size=(320, 192)):
    """Trace a Darknet model with TorchScript and save the traced module.

    Called with no arguments this reproduces the original script behaviour:
    ``cfg``, ``weights`` and ``outfile`` are taken from the module-level
    ``opt`` namespace that the ``__main__`` block fills from the command line.
    Passing them explicitly lets the function be used from other code without
    first populating ``opt``.

    Args:
        cfg: Path to the ``*.cfg`` model definition. Defaults to ``opt.cfg``.
        weights: Path to the weights file. Defaults to ``opt.weights``. A
            ``.pt`` file is read with ``torch.load`` and its ``'model'`` entry
            is loaded as the state dict; any other file is handed to
            ``load_darknet_weights``.
        outfile: Path the traced module is written to with
            ``torch.jit.save``. Defaults to ``opt.outfile``.
        img_size: Last two dimensions of the all-zeros ``(1, 3, *img_size)``
            example input used for tracing; it is also passed to ``Darknet``.
            Defaults to ``(320, 192)``, the value the original script used.

    Raises:
        NameError: if any of ``cfg``/``weights``/``outfile`` is omitted and no
            module-level ``opt`` exists (e.g. the module was imported rather
            than run as a script).
    """
    # Fall back to the parsed command-line options for anything not given.
    cfg = opt.cfg if cfg is None else cfg
    weights = opt.weights if weights is None else weights
    outfile = opt.outfile if outfile is None else outfile
    img_size = tuple(img_size)  # accept lists too; tuple concat below needs a tuple
    device = 'cpu'  # map_location used when reading a .pt checkpoint

    # Export needs no autograd bookkeeping. Entering no_grad here (not only
    # in __main__) keeps the function self-contained when called directly.
    with torch.no_grad():
        model = Darknet(cfg, img_size)

        # attempt_download is a project helper -- presumably it fetches the
        # weights file when it is not present locally; verify in utils.
        attempt_download(weights)
        if weights.endswith('.pt'):  # PyTorch checkpoint
            # NOTE: torch.load unpickles the file; only load trusted weights.
            checkpoint = torch.load(weights, map_location=device)
            model.load_state_dict(checkpoint['model'])
        else:  # Darknet-format weights
            load_darknet_weights(model, weights)

        model.eval()
        model.fuse()  # fuse Conv2d + BatchNorm2d layers before tracing

        example = torch.zeros((1, 3) + img_size)  # e.g. (1, 3, 320, 192)
        # torch.jit.trace documents example_inputs as a tensor or a tuple.
        traced = torch.jit.trace(model, (example,))
        torch.jit.save(traced, outfile)
if __name__ == '__main__':
    # Script entry point: build the CLI, store the parsed options in the
    # module-level name `opt` (export() reads it), echo them, then run the
    # export with autograd disabled.
    parser = argparse.ArgumentParser()
    cli_options = (
        # (flag, default value, help text)
        ('--cfg', 'cfg/yolov3-spp.cfg', '*.cfg path'),
        ('--weights', 'weights/yolov3-spp-ultralytics.pt', 'weights path'),
        ('--outfile', 'yolov3-spp-traced.pt', 'exported file path'),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    opt = parser.parse_args()
    print(opt)
    with torch.no_grad():
        export()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment