Skip to content

Instantly share code, notes, and snippets.

@rozeappletree
Created January 22, 2020 10:58
Show Gist options
  • Save rozeappletree/7bfe09d136d62921e70967ecaaaa6b1a to your computer and use it in GitHub Desktop.
Save rozeappletree/7bfe09d136d62921e70967ecaaaa6b1a to your computer and use it in GitHub Desktop.
Transfer Learning
/**
 * Builds a transfer-learning classifier from the pre-trained MobileNet v1 (0.25, 224).
 *
 * The network is cut at its `global_average_pooling2d_1` layer. Layers after it are
 * not part of the new model. A fresh softmax dense layer named `denseModified` is
 * attached to the pooling output. Every layer is then frozen except the new head and
 * layers whose names start with one of the `conv_dw_13` / `conv_pw_13` prefixes, and
 * the model is compiled with categorical cross-entropy and Adam.
 *
 * NOTE(review): relies on a global `tf` (TensorFlow.js), presumably loaded via a
 * <script> tag — confirm against the hosting page.
 *
 * @param {object} [options]
 * @param {number} [options.numClasses=2] - Number of units (classes) in the new softmax head.
 * @param {number} [options.learningRate=1e-3] - Learning rate passed to `tf.train.adam`.
 * @returns {Promise<object>} The compiled, partially frozen `tf.LayersModel`.
 */
async function getModifiedMobilenet({ numClasses = 2, learningRate = 1e-3 } = {}) {
  // Name prefixes of layers that stay trainable; all other layers are frozen.
  // 'denseModified' is the head created below.
  const trainableLayers = ['denseModified', 'conv_pw_13_bn', 'conv_pw_13', 'conv_dw_13_bn', 'conv_dw_13'];

  // Marks each layer trainable iff its name starts with one of `trainablePrefixes`.
  // Deliberately synchronous: the previous `async` version returned a Promise,
  // which was then used as the model and broke the later `.compile` call.
  function freezeModelLayers(trainablePrefixes, model) {
    for (const layer of model.layers) {
      layer.trainable = trainablePrefixes.some((prefix) => layer.name.startsWith(prefix));
    }
    return model;
  }

  const mobilenet = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
  console.log('Mobilenet model is loaded');

  // Take the pooled features as input to a new softmax classification head.
  const pooling = mobilenet.getLayer('global_average_pooling2d_1');
  const predictions = tf.layers
    .dense({ units: numClasses, activation: 'softmax', name: 'denseModified' })
    .apply(pooling.output);
  const mobilenetModified = tf.model({ inputs: mobilenet.input, outputs: predictions, name: 'modelModified' });
  console.log('Mobilenet model is modified');

  // Trainable flags must be set before compile() for them to take effect in training.
  freezeModelLayers(trainableLayers, mobilenetModified);
  console.log('ModifiedMobilenet model layers are freezed');

  mobilenetModified.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(learningRate),
    metrics: ['accuracy', 'crossentropy'],
  });
  return mobilenetModified;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment