import mxnet as mx

def get_feature_symbol(model, top_layer=None):
    """Get the feature symbol from a model.

    .. note::

        If ``top_layer`` is not given, the symbol of the second-to-last
        layer is returned.

    Parameters
    ----------
    model : mx.model.FeedForward
        Model from which the feature symbol is extracted.
    top_layer : str, optional
        Name of the layer to use as the top (feature) layer.

    Returns
    -------
    internals[top_layer] : mx.symbol.Symbol
        Feature symbol.
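
    Examples
    --------
    A minimal sketch (not from the original docstring); it assumes ``model``
    is a trained ``mx.model.FeedForward`` whose internal layer names end
    with ``"output"``.

    >>> feature = get_feature_symbol(model)
    >>> candidates = [name for name in model.symbol.get_internals().list_outputs()
    ...               if name.endswith("output")]
    >>> feature = get_feature_symbol(model, top_layer=candidates[-2])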
""" |
    # Accept either a bare symbol or a FeedForward model.
    if isinstance(model, mx.symbol.Symbol):
        internals = model.get_internals()
    else:
        internals = model.symbol.get_internals()

    # Collect layer output names, topmost first.
    tmp = internals.list_outputs()[::-1]
    outputs = [name for name in tmp if name.endswith("output")]

    if top_layer is not None and not isinstance(top_layer, str):
        error_msg = "top_layer must be a string from the following candidates:\n%s" \
            % "\n".join(outputs)
        raise TypeError(error_msg)

    if top_layer is None:
        # Default to the second-to-last layer (outputs[0] is typically the loss output).
        assert len(outputs) > 3
        top_layer = outputs[2]
    else:
        if top_layer not in outputs:
            error_msg = "%s does not exist in symbol. Possible choices:\n%s" \
                % (top_layer, "\n".join(outputs))
            raise ValueError(error_msg)

    return internals[top_layer]


def finetune(symbol, model, **kwargs):
    """Get a FeedForward model for fine-tuning.

    .. note::

        Layers that do not exist in ``model`` are initialized with uniform
        random weights.

    Parameters
    ----------
    symbol : mx.symbol.Symbol
        Symbol of the new network to be fine-tuned.
    model : mx.model.FeedForward
        Model whose parameters are used to initialize the fine-tuned network.
    kwargs : kwargs
        Parameters passed on to ``mx.model.FeedForward.create``.

    Returns
    -------
    new_model : mx.model.FeedForward
        Fine-tuned model.

    Examples
    --------
    Load a model:

    >>> sym, arg_params, aux_params = mx.model.load_checkpoint("model", 9)
    >>> model = mx.model.FeedForward(sym, arg_params=arg_params, aux_params=aux_params)

    Make a new symbol for fine-tuning:

    >>> feature = mx.model.get_feature_symbol(model)
    >>> net = mx.sym.FullyConnected(data=feature, num_hidden=10, name="new_fc")
    >>> net = mx.sym.SoftmaxOutput(data=net, name="softmax")

    Fine-tune the model:

    >>> new_model = mx.model.finetune(symbol=net, model=model, num_epoch=2, learning_rate=1e-3,
    ...                               X=train, eval_data=val,
    ...                               batch_end_callback=mx.callback.Speedometer(100))
    """
    # Parameters whose names exist in model.arg_params are loaded from the old
    # model; any other layer falls back to Uniform(0.001) initialization.
    initializer = mx.init.Load(param=model.arg_params, default_init=mx.init.Uniform(0.001))
    new_model = mx.model.FeedForward.create(symbol=symbol, initializer=initializer, **kwargs)
    return new_model
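

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the API above). Because
# get_feature_symbol checks isinstance(model, mx.symbol.Symbol), it also
# accepts a bare symbol loaded straight from a checkpoint. The prefix
# "model", epoch 9, the layer name "flatten_output", and the iterators
# `train`/`val` below are assumptions made for the sake of the example.
#
# >>> sym, arg_params, aux_params = mx.model.load_checkpoint("model", 9)
# >>> old_model = mx.model.FeedForward(sym, arg_params=arg_params,
# ...                                  aux_params=aux_params)
# >>> feature = get_feature_symbol(sym, top_layer="flatten_output")
# >>> net = mx.sym.FullyConnected(data=feature, num_hidden=10, name="new_fc")
# >>> net = mx.sym.SoftmaxOutput(data=net, name="softmax")
# >>> new_model = finetune(symbol=net, model=old_model, X=train, eval_data=val,
# ...                      num_epoch=2, learning_rate=1e-3)
# ---------------------------------------------------------------------------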