Created December 25, 2017 21:54
Save mongoose54/8d47be1359691bae3b2470dfab60fc00 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(net, train_iter, val_iter, batch_size, epochs, ctx,
          learning_rate=0.01, wd=0.0):
    """Train ``net`` with Adam on a sigmoid binary cross-entropy loss.

    For every epoch the training iterator is reset and consumed once.
    Each batch is split across ``ctx`` and optimized with a single trainer step.
    Afterwards the epoch's wall-clock time and mean training loss are logged,
    and ``AUC`` is called on the validation iterator.

    Args:
        net: Gluon block whose parameters are already initialized on ``ctx``.
        train_iter: data iterator supporting ``reset()`` whose batches expose
            ``data[0]`` and ``label[0]``.
        val_iter: validation iterator, passed through to ``AUC``.
        batch_size: forwarded to ``AUC``. The optimizer step is normalized by
            the actual leading dimension of each batch instead.
        epochs: number of passes over ``train_iter``.
        ctx: a single ``mx.Context`` or a list of contexts. Each batch is
            split across them along axis 0.
        learning_rate: Adam learning rate (default 0.01, the previous
            hard-coded value).
        wd: weight decay handed to the optimizer. The default 0.0 matches the
            previous behaviour, where a ``wd`` local existed but was never
            passed to the Trainer.
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate, 'wd': wd})
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    for epoch in range(epochs):
        tic = time.time()
        train_iter.reset()
        # Collected per epoch so the logged mean describes this epoch only.
        # It previously accumulated over all epochs.
        epoch_losses = []
        for batch in train_iter:
            # NOTE(review): batches are fed exactly as the iterator yields them.
            # No mean/std normalization is applied here (it was commented out
            # in the original); confirm this is intended for the model-zoo net.
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            with autograd.record():
                # Run the forward pass on every device before any backward,
                # so that multi-GPU work can overlap.
                # NOTE: no per-class re-weighting of the loss is done here.
                losses = [loss(net(x), y) for x, y in zip(data, label)]
            for L in losses:
                L.backward()
            trainer.step(batch.data[0].shape[0])
            # Bookkeeping happens outside the recording scope; one mean per device slice.
            epoch_losses.extend(np.mean(L.asnumpy()) for L in losses)
        mean_loss = np.mean(epoch_losses) if epoch_losses else float('nan')
        logging.info('[Epoch %d] time cost: %f', epoch, time.time() - tic)
        logging.info('[Epoch %d] mean loss: %s', epoch, mean_loss)
        # AUC is defined outside the visible snippet. It is assumed to evaluate
        # `net` on `val_iter`; confirm against its definition.
        AUC(net, val_iter, ctx, batch_size=batch_size)
# Training script: NIH ChestX-ray14 multi-label classification with DenseNet-121.
batch_size = 128
# Shared by the iterators' label_width and the network's output size,
# so the two cannot drift apart.
num_classes = 15
image_shape = (1, 224, 224)

# load dataset
train_iter = mx.io.ImageRecordIter(path_imgrec="train.bin",
                                   path_imglist="nih_chest14_train.lst",
                                   label_width=num_classes,
                                   min_img_size=256,
                                   data_shape=image_shape,
                                   rand_crop=True,
                                   shuffle=True,
                                   batch_size=batch_size,
                                   max_random_scale=1.5,
                                   min_random_scale=0.75,
                                   rand_mirror=True)
val_iter = mx.io.ImageRecordIter(path_imgrec="val.bin",
                                 path_imglist="nih_chest14_val.lst",
                                 label_width=num_classes,
                                 min_img_size=256,
                                 data_shape=image_shape,
                                 batch_size=batch_size)

# Number of GPUs to train on; 0 falls back to a single CPU context.
gpus = 4
contexts = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
net = gluon.model_zoo.vision.densenet121(classes=num_classes)
# Parenthesized form works on both Python 2 and Python 3.
print(net)
net.collect_params().initialize(mx.init.Xavier(), ctx=contexts)
train(net, train_iter, val_iter, batch_size, epochs=100, ctx=contexts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.