
@mongoose54
Created December 25, 2017 21:54
import time
import logging

import numpy as np
import mxnet as mx
from mxnet import gluon, autograd

logging.basicConfig(level=logging.INFO)  # make the logging.info calls below visible


def train(net, train_iter, val_iter, batch_size, epochs, ctx):
    learning_rate = 0.01
    wd = 0.002          # defined but not passed to the trainer below
    log_interval = 100  # defined but unused in this loop
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate})
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    best_f1 = 0
    # logging.info('[Initial] validation: %s' % (metric_str(val_names, val_accs)))
    mean_loss = []
    for epoch in range(epochs):
        tic = time.time()
        train_iter.reset()
        btic = time.time()
        for i, batch in enumerate(train_iter):
            # the model zoo models expect normalized images
            # data = color_normalize(batch.data[0]/255,
            #                        mean=mx.nd.array([0.485, 0.456, 0.406]).reshape((1,3,1,1)),
            #                        std=mx.nd.array([0.229, 0.224, 0.225]).reshape((1,3,1,1)))
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with autograd.record():
                for x, y in zip(data, label):
                    z = net(x)
                    # rescale the loss based on class to counter the imbalance problem
                    L = loss(z, y)
                    mean_loss.append(np.mean(L.asnumpy()))
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
            for L in Ls:
                L.backward()
            trainer.step(batch.data[0].shape[0])
        # `metric`, `metric_str` and `AUC` are not defined in this gist; they are
        # assumed to come from elsewhere (a hedged sketch follows this function).
        # Note that `metric` is never updated inside the batch loop above.
        names, accs = metric.get()
        metric.reset()
        logging.info('[Epoch %d] training: %s' % (epoch, metric_str(names, accs)))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        logging.info('[Epoch %d] mean loss: %s' % (epoch, np.mean(mean_loss)))
        AUC(net, val_iter, ctx, batch_size=batch_size)
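
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original gist): train() references `metric`,
# `metric_str` and `AUC`, none of which are defined above. The placeholders
# below are assumptions about what they might look like -- a generic
# mx.metric instance, a log formatter, and a per-class ROC-AUC evaluation
# using scikit-learn -- not the author's implementation.
# ---------------------------------------------------------------------------
from sklearn.metrics import roc_auc_score  # assumed extra dependency

# placeholder metric; train() calls get()/reset() but never update()
metric = mx.metric.Accuracy()

def metric_str(names, accs):
    # format metric name/value pairs into a single log string
    if not isinstance(names, list):
        names, accs = [names], [accs]
    return ', '.join(['%s=%f' % (n, a) for n, a in zip(names, accs)])

def AUC(net, val_iter, ctx, batch_size):
    # run the net over the validation set and log the mean per-class ROC-AUC
    val_iter.reset()
    all_labels, all_scores = [], []
    for batch in val_iter:
        x = batch.data[0].as_in_context(ctx[0])
        scores = mx.nd.sigmoid(net(x))
        all_labels.append(batch.label[0].asnumpy())
        all_scores.append(scores.asnumpy())
    labels = np.concatenate(all_labels)
    scores = np.concatenate(all_scores)
    # skip classes with a single label value in the validation split
    aucs = [roc_auc_score(labels[:, c], scores[:, c])
            for c in range(labels.shape[1]) if len(np.unique(labels[:, c])) > 1]
    logging.info('validation mean AUC: %f' % np.mean(aucs))
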
batch_size = 128

# load dataset
train_iter = mx.io.ImageRecordIter(path_imgrec="train.bin",
                                   path_imglist="nih_chest14_train.lst",
                                   label_width=15,
                                   min_img_size=256,
                                   data_shape=(1, 224, 224),
                                   rand_crop=True,
                                   shuffle=True,
                                   batch_size=batch_size,
                                   max_random_scale=1.5,
                                   min_random_scale=0.75,
                                   rand_mirror=True)
val_iter = mx.io.ImageRecordIter(path_imgrec="val.bin",
                                 path_imglist="nih_chest14_val.lst",
                                 label_width=15,
                                 min_img_size=256,
                                 data_shape=(1, 224, 224),
                                 batch_size=batch_size)
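
# Optional sanity check (not in the original gist): pull one batch to confirm
# the record files load and the shapes match what the network will see.
check_batch = train_iter.next()
print(check_batch.data[0].shape)   # expected (batch_size, 1, 224, 224)
print(check_batch.label[0].shape)  # expected (batch_size, 15)
train_iter.reset()
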
gpus = 4
contexts = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
net = gluon.model_zoo.vision.densenet121(classes=15)
print(net)
# inputs = mx.sym.var('data')
# out = net(inputs)
# internals = out.get_internals()
# print internals
# outputs = internals['densenet0_flatten0_reshape0_output']
# feat_model = gluon.SymbolBlock(outputs, inputs, params=net.collect_params())
# net.collect_params().reset_ctx(contexts)
net.collect_params().initialize(mx.init.Xavier(), ctx=contexts)
train(net, train_iter, val_iter, batch_size, epochs=100, ctx=contexts)
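
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original gist): the comment inside train()
# mentions rescaling the loss per class to counter label imbalance, but the
# loop above calls the unweighted SigmoidBinaryCrossEntropyLoss. One way to do
# it is to pass a per-class sample_weight to the loss; `pos_frac` below is a
# hypothetical array of positive-label frequencies for the 15 classes (e.g.
# computed from nih_chest14_train.lst). This function is illustrative only and
# is not wired into train().
# ---------------------------------------------------------------------------
def weighted_bce_example(net, loss, x, y, pos_frac, ctx0):
    # up-weight rare classes: w_c proportional to 1 / freq_c, normalized to mean 1
    w = 1.0 / np.maximum(np.asarray(pos_frac, dtype='float32'), 1e-6)
    w = w / w.mean()
    class_weights = mx.nd.array(w, ctx=ctx0).reshape((1, -1))  # broadcasts over the batch
    z = net(x)
    # gluon losses accept an optional sample_weight that is multiplied into the per-element loss
    return loss(z, y, class_weights)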