# Load a saved checkpoint (symbol plus arg/aux parameters), build a
# FeedForward model initialised from those parameters, then call fit on it.
# NOTE: the `...` entries stand in for arguments elided from this snippet, so
# it is illustrative only and will not run as written.
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
# The loaded symbol and parameters are passed straight to the constructor;
# `devices` is supplied as the model's ctx.
model = mx.model.FeedForward(symbol=sym, arg_params=arg_params, aux_params=aux_params, ctx=devices, ...)
# Training data and options are elided here.
model.fit(...)
# Functions in fine_tune.py
# Step 1: load the pre-trained model saved under `prefix` (second argument 0).
model = mx.model.FeedForward.load(prefix, 0)

# Step 2: obtain the symbol at which the pre-trained network is cut off;
# the new layers are attached to it.
cutoff = get_feature_symbol(model)

# Step 3: stack the new head on the cutoff symbol -- a fully connected layer
# with 10 hidden units ("new_fc") feeding a softmax output ("softmax").
new_net = mx.sym.SoftmaxOutput(
    data=mx.sym.FullyConnected(data=cutoff, num_hidden=10, name="new_fc"),
    name="softmax",
)

# Step 4: hand the new network, the pre-trained model and the data to finetune.
# NOTE(review): num_epoch=0 -- if finetune treats this as the number of
# training epochs, no training would actually happen; confirm against
# fine_tune.py.
new_model = finetune(
    symbol=new_net,
    model=model,
    X=X,
    y=y,
    num_epoch=0,
)