Skip to content

Instantly share code, notes, and snippets.

@sagelywizard
Last active August 30, 2017 21:42
Show Gist options
  • Save sagelywizard/ef6d46b147507971e9938e1cc4ce25ce to your computer and use it in GitHub Desktop.
Save sagelywizard/ef6d46b147507971e9938e1cc4ce25ce to your computer and use it in GitHub Desktop.
A script for classifying a single sample using my model for the ml4seti competition.
#!/usr/bin/env python3
"""A script for classifying a single sample using my model for the ml4seti competition.
e.g. python class_prob.py /path/to/sample.dat /path/to/model.pth
"""
import sys
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import ibmseti
from model import DenseNet
def get_spectrogram(filename):
    """Load one ml4seti sample file and return its spectrogram as model input.

    Args:
        filename: path to the raw sample file; it is read in binary mode and
            handed to ``ibmseti.compamp.SimCompamp``.

    Returns:
        A ``Variable`` wrapping a float32 tensor of shape (1, 1, 384, 512),
        created with ``volatile=True`` so no autograd history is recorded
        (legacy PyTorch API; PyTorch 0.4+ ignores the flag with a warning).
    """
    # Close the handle as soon as the bytes are read instead of leaving it
    # open until the file object is garbage-collected.
    with open(filename, 'rb') as raw_file:
        raw_bytes = raw_file.read()
    aca = ibmseti.compamp.SimCompamp(raw_bytes)
    # NOTE(review): the 4-D shape is presumably (batch, channel, height, width)
    # as the model expects; view() raises if the spectrogram does not contain
    # exactly 384 * 512 values.
    tensor = torch.from_numpy(aca.get_spectrogram()).float().view(1, 1, 384, 512)
    return Variable(tensor, volatile=True)
def get_densenet(model_path):
    """Build a DenseNet and load trained weights from a checkpoint file.

    Args:
        model_path: path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry is the network's state dict.

    Returns:
        The ``DenseNet`` switched to evaluation mode with the weights loaded.

    Raises:
        KeyError: if the checkpoint has no ``'model'`` entry.
    """
    # NOTE(review): the meaning of the positional ``False`` depends on
    # model.DenseNet's signature — confirm against model.py.
    dense = DenseNet(False)
    dense.eval()
    # This script runs inference on CPU (nothing is moved to a GPU), so map
    # every stored tensor to CPU. Without this, a checkpoint saved from GPU
    # training cannot be loaded on a machine without CUDA.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state = torch.load(model_path, map_location=lambda storage, loc: storage)
    dense.load_state_dict(state['model'])
    return dense
def main(filename, model_path):
    """Classify one sample and print its per-class probabilities.

    Args:
        filename: path to the raw sample file to classify.
        model_path: path to the trained model checkpoint.

    Prints a flat Python list of softmax probabilities, one per output class.
    """
    spec = get_spectrogram(filename)
    model = get_densenet(model_path)
    logits = model(spec)
    # No explicit ``dim`` is passed so the call still works on old PyTorch
    # releases whose softmax lacks that argument; newer releases choose dim=1
    # for a 2-D input (with a deprecation warning).
    # NOTE(review): assumes the model output is (1, num_classes) — confirm.
    probs = F.softmax(logits)
    # Flatten with -1 instead of a hard-coded 7 so a model with a different
    # number of classes also works; output is unchanged for the 7-class model.
    print(probs.data.view(-1).tolist())
if __name__ == '__main__':
    # Two positional arguments are required (sample file, model checkpoint).
    # Report a usage message on stderr (exit status 1) rather than crashing
    # with an IndexError; any extra arguments are ignored as before.
    if len(sys.argv) < 3:
        sys.exit('usage: {} SAMPLE_FILE MODEL_FILE'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment