Created
January 19, 2019 13:07
-
-
Save wkcn/69f0f6d2ca467816dc481a00c225104f to your computer and use it in GitHub Desktop.
test_fcn_for_mxnet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import logging
import os
import random
import time

# The environment is adjusted ahead of the mxnet import below, presumably so
# that the library sees these values when it loads; keep this ordering.
os.environ['PYTHONUNBUFFERED'] = '1'
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
os.environ['MXNET_ENABLE_GPU_P2P'] = '0'

import mxnet as mx  # noqa: E402  (must follow the os.environ assignments)
from mxnet import gluon, autograd, nd  # noqa: E402
import numpy as np  # noqa: E402
from mxnet.gluon.model_zoo import vision  # noqa: E402
from mxnet.gluon import nn  # noqa: E402
# NOTE(review): the original file imports `vision` twice; the repeat is kept
# here so that this edit only regroups lines and removes nothing.
from mxnet.gluon.model_zoo import vision  # noqa: E402,F811
def get_backend(num_filters):
    """Assemble the dilated-convolution head of the network.

    For every entry of ``num_filters`` a 3x3 convolution (stride 1,
    padding 2, dilation 2, with bias and ReLU activation) is appended.
    A 1x1 convolution with a single output channel follows, and axis 1 of
    its output is then squeezed away.

    Args:
        num_filters: iterable with the output channel count of each
            dilated convolution, in the order the layers are stacked.

    Returns:
        The ``nn.HybridSequential`` (created with prefix ``'backend_'``)
        holding all of the layers above.
    """
    # Settings shared by every dilated 3x3 layer; only `channels` varies.
    dilated_conv_kwargs = dict(
        kernel_size=(3, 3),
        strides=(1, 1),
        padding=(2, 2),
        dilation=(2, 2),
        use_bias=True,
        activation='relu',
    )

    backend = nn.HybridSequential('backend_')
    with backend.name_scope():
        for channels in num_filters:
            backend.add(nn.Conv2D(channels=channels, **dilated_conv_kwargs))
        # Collapse to one channel, then drop that channel axis.
        single_channel_conv = nn.Conv2D(channels=1, kernel_size=(1, 1))
        backend.add(single_channel_conv)
        drop_channel_axis = nn.HybridLambda(lambda F, x: F.squeeze(x, axis=1))
        backend.add(drop_channel_axis)
    return backend
def main():
    """Train a VGG-16 based FCN on random, variably sized inputs on 4 GPUs.

    The network is the first 23 feature layers of a non-pretrained VGG-16
    followed by ``get_backend([512, 256, 128, 64])``. Its parameters are
    initialised from ``Normal(sigma=0.01)``, the net is hybridized with
    ``static_alloc=True`` and the parameters are placed on GPUs 0-3.

    In each iteration every device receives a uniform-random tensor of
    shape ``(9, 3, h, w)`` with ``h`` and ``w`` drawn independently from
    [300, 512]; the mean of the network output is that device's loss. All
    losses are back-propagated together, one Adam step is taken, pending
    work is waited for, and the throughput is printed.

    NOTE(review): ``epoch_end`` is never set to True, so the inner loop
    does not terminate and only the first epoch is ever entered. That is
    presumably intentional for a stress/reproduction script - confirm
    before bounding it.
    """
    devs = [mx.gpu(i) for i in [0, 1, 2, 3]]
    print('Using Device {}'.format(devs))

    # Network: truncated VGG-16 feature extractor + dilated-conv backend.
    vgg_part = vision.vgg16(pretrained=False).features[0:23]
    backend = get_backend([512, 256, 128, 64])
    net = nn.HybridSequential()
    net.add(vgg_part)
    net.add(backend)
    net.collect_params().initialize(mx.init.Normal(sigma=0.01))
    net.hybridize(static_alloc=True)
    net.collect_params().reset_ctx(devs)

    # Trainer
    trainer = gluon.Trainer(
        net.collect_params(), 'adam',
        dict(
            learning_rate=0.01,
            wd=1e-4,
        )
    )

    # One loss slot per device; they are back-propagated in a single call.
    losses = [None for _ in devs]

    begin_epoch = 0
    end_epoch = 120
    for epoch in range(begin_epoch, end_epoch + 1):
        batch_i = 0
        epoch_end = False
        # Kept from the original script: hybridize is re-issued per epoch.
        net.hybridize(static_alloc=True)
        while not epoch_end:
            batch_size_count = 0
            tic = time.time()
            for ctx_i, ctx in enumerate(devs):
                # A fresh random spatial size per device and per iteration.
                shape = (1, 9, 3, random.randint(300, 512), random.randint(300, 512))
                data = mx.nd.uniform(0, 1, shape)
                # Drop the leading singleton axis -> (9, 3, h, w) on `ctx`.
                x = data[0].as_in_context(ctx)
                batch_size_count += len(x)
                with autograd.record():
                    losses[ctx_i] = net(x).mean()

            # The device loop has no `break`, so these steps always run
            # once all devices have produced a loss.
            autograd.backward(losses)
            trainer.step(batch_size_count)
            # Block until queued work has finished so the timing is real.
            mx.nd.waitall()
            info = '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec'.format(epoch, batch_i, batch_size_count / (time.time() - tic))
            print(info)
            batch_i += 1
# Start the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment