A Pen by Shanqing Cai on CodePen.
Last active
August 12, 2024 09:02
-
-
Save caisq/33ed021e0c7b9d0e728cb1dce399527d to your computer and use it in GitHub Desktop.
Custom Layers in TensorFlow.js (Stateful, Configurable, and Serializable)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A custom, serializable TensorFlow.js layer.
 *
 * Computes: output = input * (x ^ alpha), where
 * - x is a trainable scalar weight, and
 * - alpha is a constant supplied through the layer config.
 *
 * The layer implements getConfig() and a static className getter so
 * that it survives a save/load round trip.
 */
class TimesXToThePowerOfAlphaLayer extends tf.layers.Layer {
  constructor(config) {
    super(config);
    // Constant exponent; restored from the serialized config on load.
    this.alpha = config.alpha;
  }

  /**
   * Creates the layer's weights. Invoked the first time the layer is
   * connected to an upstream layer.
   * @param {Array} inputShape Shape of the incoming tensor (unused here).
   */
  build(inputShape) {
    // Scalar ([] shape) trainable weight, initialized to 1.
    this.x = this.addWeight('x', [], 'float32', tf.initializers.ones());
  }

  /**
   * Forward pass: tensor(s) in, tensor out.
   *
   * Wrapped in tidy() so intermediate tensors are disposed and WebGL
   * memory does not leak.
   * @param {Array} input Array of input tensors; only input[0] is used.
   * @returns {tf.Tensor} input[0] scaled by x ** alpha.
   */
  call(input) {
    return tf.tidy(() => {
      const scale = tf.pow(this.x.read(), this.alpha);
      return tf.mul(input[0], scale);
    });
  }

  /**
   * Serializes the layer's configuration, including alpha, so the
   * layer can be reconstructed when a saved model is loaded.
   * @returns {Object} The JSON-serializable config object.
   */
  getConfig() {
    return Object.assign(super.getConfig(), {alpha: this.alpha});
  }

  /**
   * Class name used by the serialization registry; required for
   * tf.serialization.registerClass().
   */
  static get className() {
    return 'TimesXToThePowerOfAlphaLayer';
  }
}
/**
 * Register the custom layer, so TensorFlow.js knows which class
 * constructor to call when deserializing a saved instance of the
 * custom layer.
 */
tf.serialization.registerClass(TimesXToThePowerOfAlphaLayer);
(async function main() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [4]}));
  // Here comes an instance of the custom layer.
  model.add(new TimesXToThePowerOfAlphaLayer({alpha: 1.5}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
  model.summary();

  // Train the model using some random data.
  const xs = tf.randomNormal([2, 4]);
  const ys = tf.randomNormal([2, 1]);
  await model.fit(xs, ys, {
    epochs: 5,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        // Fix: original read `Epoch {epoch}` (missing `$`), so the
        // epoch number was printed literally, never interpolated.
        console.log(`Epoch ${epoch}: loss = ${logs.loss}`);
      }
    }
  });

  // Save the model and load it back.
  await model.save('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');
  console.log('Model saved.');
  // Fix: tf.loadModel was removed in TensorFlow.js 1.x; with the
  // `@latest` script tag, tf.loadLayersModel is the current API for
  // loading layers-format models.
  const model2 = await tf.loadLayersModel('indexeddb://codepen-tfjs-model-example-jdBgwB-v1');
  console.log('Model2 loaded.');
  console.log('The two predict() outputs should be identical:');
  model.predict(xs).print();
  model2.predict(xs).print();
})();
I've done a PR to tfjs for an AttentionLayer if it's helpful to you.
Nice, thank you.
Thanks for the example.
Just a small extension: if you are working with 2 input layers (like a + b),
...
computeOutputShape(inputShape)
{
return inputShape[0];
}
call(input)
{
return tf.tidy(() =>
{
var a = input[0], b=input[1];
return a.add(b);
}
}
...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I've done a PR to tfjs for an AttentionLayer if it's helpful to you.