Causal Attention Layer TensorFlow.js
//---------------------------------------------------
//setup
var dModel = 256; //dimension of embedding
var INPUTSIZE = 1024; //length of input
var BLOCKS = 3;
var LEARNING_RATE = 0.003;
var OPTIMIZER = tf.train.adam(LEARNING_RATE, 0.85, 0.9, 1e-9); //tf.train.adam(learningRate, beta1, beta2, epsilon) - JavaScript has no named arguments
var LOSS = 'categoricalCrossentropy';
var INITIALIZER = 'glorotUniform';
/******************************************************************************
* tensorflow.js lambda layer
* written by twitter.com/benjaminwegener
* license: MIT
*/
class lambdaLayer extends tf.layers.Layer {
    constructor(config) {
        super(config);
        if (config.name === undefined) {
            config.name = ((+new Date) * Math.random()).toString(36); //random name from timestamp in case name hasn't been set
        }
        this.name = config.name;
        this.lambdaFunction = config.lambdaFunction;
        this.lambdaOutputShape = config.lambdaOutputShape;
    }
    call(input) {
        return tf.tidy(() => {
            let result = null;
            eval(this.lambdaFunction); //the lambda body is expected to assign its output tensor to `result`
            return result;
        });
    }
    computeOutputShape(inputShape) {
        if (this.lambdaOutputShape === undefined) { //if no output shape is provided, fall back to the shape of the first input
            return inputShape[0];
        } else {
            return this.lambdaOutputShape;
        }
    }
    getConfig() {
        const config = super.getConfig();
        Object.assign(config, {
            lambdaFunction: this.lambdaFunction,
            lambdaOutputShape: this.lambdaOutputShape
        });
        return config;
    }
    static get className() {
        return 'lambdaLayer';
    }
}
tf.serialization.registerClass(lambdaLayer);
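//---------------------------------------------------
//minimal usage sketch for lambdaLayer (illustrative, not part of the original gist):
//a hypothetical layer that adds two inputs elementwise. Inside lambdaFunction, `input`
//is the array of tensors passed to call(); with no lambdaOutputShape the layer falls
//back to the shape of its first input.
let addExample = new lambdaLayer({name: 'addExample', lambdaFunction: `
    result = tf.add(input[0], input[1]);
`});
let exampleA = tf.input({shape: [8]});
let exampleB = tf.input({shape: [8]});
let exampleSum = addExample.apply([exampleA, exampleB]); //symbolic tensor with shape [null, 8]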
//---------------------------------------------------
//Attention mechanism with causal masking as a separate model:
let queries = tf.input({name: 'queries', shape: [INPUTSIZE, dModel]});
let keys = tf.input({name: 'keys', shape: [INPUTSIZE, dModel]});
let values = tf.input({name: 'values', shape: [INPUTSIZE, dModel]});
let outputs = new lambdaLayer({name: 'causalAtt', lambdaFunction: `
    //scaled dot-product attention: scores = Q * K^T / sqrt(dModel)
    result = tf.matMul(input[0], input[1], false, true);
    result = tf.div(result, tf.sqrt(tf.scalar(dModel)));
    //causal mask: position h may only attend to positions w <= h;
    //future positions are pushed to -1e9 so softmax assigns them near-zero weight
    let causalMask = Array(result.shape[1]).fill().map(() => Array(result.shape[2]).fill(-1e9));
    for (let h = 0; h < result.shape[1]; h++){
        for (let w = 0; w < result.shape[2]; w++){
            if (w <= h){
                causalMask[h][w] = 0;
            }
        }
    }
    result = tf.add(result, tf.tensor(causalMask));
    result = tf.softmax(result, -1);
    result = tf.matMul(result, input[2]); //weighted sum of the values
`, lambdaOutputShape: values.shape}).apply([queries, keys, values]);
let attention_model = tf.model({inputs: [queries, keys, values], outputs: outputs});
attention_model.compile({optimizer: OPTIMIZER, loss: LOSS}); //compile is only needed if this sub-model is trained on its own; applying it inside another model does not require it
attention_model.summary();
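//---------------------------------------------------
//quick sanity check (illustrative, not part of the original gist): run random tensors
//through the attention model and confirm the output shape is [batch, INPUTSIZE, dModel];
//the batch size of 2 is arbitrary.
tf.tidy(() => {
    const q = tf.randomNormal([2, INPUTSIZE, dModel]);
    const k = tf.randomNormal([2, INPUTSIZE, dModel]);
    const v = tf.randomNormal([2, INPUTSIZE, dModel]);
    const out = attention_model.predict([q, k, v]);
    console.log(out.shape); //expected: [2, 1024, 256]
});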
//---------------------------------------------------
//use it in your model as follows (x denotes the output of the preceding layer, e.g. an embedding of shape [batch, INPUTSIZE, dModel]):
queries = tf.layers.dense({units: dModel}).apply(x);
keys = tf.layers.dense({units: dModel}).apply(x);
values = tf.layers.dense({units: dModel}).apply(x);
x = attention_model.apply([queries, keys, values]); //HERE THE MODEL FOR ATTENTION IS USED
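//---------------------------------------------------
//end-to-end sketch (one possible wiring, not part of the original gist): stack the
//attention sub-model BLOCKS times on top of an embedding and compile with the constants
//defined above. VOCAB_SIZE, the residual connections and the layer normalization are
//assumptions made for illustration; the attention sub-model itself has no trainable
//weights, so reusing it across blocks is harmless.
var VOCAB_SIZE = 256; //assumed vocabulary size for this toy example
let modelInput = tf.input({shape: [INPUTSIZE]});
let h = tf.layers.embedding({inputDim: VOCAB_SIZE, outputDim: dModel, embeddingsInitializer: INITIALIZER}).apply(modelInput);
for (let block = 0; block < BLOCKS; block++){
    let q = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let k = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let v = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let att = attention_model.apply([q, k, v]);
    h = tf.layers.add().apply([h, att]); //residual connection
    h = tf.layers.layerNormalization().apply(h);
}
let modelOutput = tf.layers.dense({units: VOCAB_SIZE, activation: 'softmax', kernelInitializer: INITIALIZER}).apply(h);
let model = tf.model({inputs: modelInput, outputs: modelOutput});
model.compile({optimizer: OPTIMIZER, loss: LOSS});
model.summary();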