Causal Attention Layer Tensorflow.js
//---------------------------------------------------
//setup
var dModel = 256; //dimension of embedding
var INPUTSIZE = 1024; //length of input
var BLOCKS = 3;
var LEARNING_RATE = 0.003;
var OPTIMIZER = tf.train.adam(LEARNING_RATE, 0.85, 0.9, 1e-9); //learning rate, beta1, beta2, epsilon
var LOSS = 'categoricalCrossentropy';
var INITIALIZER = 'GlorotUniform';
/******************************************************************************
 * tensorflow.js lambda layer
 * written by twitter.com/benjaminwegener
 * license: MIT
 */
class lambdaLayer extends tf.layers.Layer {
    constructor(config) {
        super(config);
        if (config.name === undefined) {
            config.name = ((+new Date) * Math.random()).toString(36); //random name from timestamp in case name hasn't been set
        }
        this.name = config.name;
        this.lambdaFunction = config.lambdaFunction;
        this.lambdaOutputShape = config.lambdaOutputShape;
    }
    call(input) {
        return tf.tidy(() => {
            let result = null;
            eval(this.lambdaFunction); //the function string sees the incoming tensors as the array `input` and must assign its output to `result`
            return result;
        });
    }
    computeOutputShape(inputShape) {
        if (this.lambdaOutputShape === undefined) { //if no outputshape provided, try to set as inputshape
            return inputShape[0];
        } else {
            return this.lambdaOutputShape;
        }
    }
    getConfig() {
        const config = super.getConfig();
        Object.assign(config, {
            lambdaFunction: this.lambdaFunction,
            lambdaOutputShape: this.lambdaOutputShape
        });
        return config;
    }
    static get className() {
        return 'lambdaLayer';
    }
}
tf.serialization.registerClass(lambdaLayer);
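//---------------------------------------------------
//illustrative sketch (not part of the original gist): a minimal standalone use of
//lambdaLayer, showing the contract described above — the eval'd string reads the
//inputs from the array `input` and writes its output to `result`.
//The names demoA/demoB/demoSum are hypothetical and exist only for this example.
var demoA = tf.input({shape: [4]});
var demoB = tf.input({shape: [4]});
var demoSum = new lambdaLayer({
    name: 'demoAdd',
    lambdaFunction: 'result = tf.add(input[0], input[1]);', //element-wise sum of both inputs
    lambdaOutputShape: [null, 4] //output shape including the batch dimension
}).apply([demoA, demoB]);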
//---------------------------------------------------
//Attention Mechanism with causal masking as separate model:
let queries = tf.input({name: 'queries', shape: [INPUTSIZE, dModel]});
let keys = tf.input({name: 'keys', shape: [INPUTSIZE, dModel]});
let values = tf.input({name: 'values', shape: [INPUTSIZE, dModel]});
let outputs = new lambdaLayer({name: 'causalAtt', lambdaFunction: `
    result = tf.matMul(input[0], input[1], false, true); //attention scores: queries x keys^T
    result = tf.div(result, tf.sqrt(tf.scalar(dModel))); //scaled dot-product: divide scores by sqrt(dModel)
    let causalMask = Array(result.shape[1]).fill().map(() => Array(result.shape[2]).fill(-1e9));
    for (let q = 0; q < result.shape[1]; q++){
        for (let k = 0; k <= q; k++){
            causalMask[q][k] = 0; //each query position may attend to itself and to earlier positions only
        }
    }
    result = tf.add(result, tf.tensor(causalMask)); //mask out future positions before the softmax
    result = tf.softmax(result, -1);
    result = tf.matMul(result, input[2]); //weighted sum of the values
`, lambdaOutputShape: [values.shape]}).apply([queries, keys, values]);
var attention_model = tf.model({inputs: [queries, keys, values], outputs: outputs});
attention_model.compile({optimizer: OPTIMIZER, loss: LOSS}); //Don't know if this is necessary
attention_model.summary();
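//---------------------------------------------------
//sanity check (illustrative, not part of the original gist): run the attention model
//once on random inputs to confirm the output keeps the [batch, INPUTSIZE, dModel] shape.
//The names testQ/testK/testV are hypothetical.
var testQ = tf.randomNormal([1, INPUTSIZE, dModel]);
var testK = tf.randomNormal([1, INPUTSIZE, dModel]);
var testV = tf.randomNormal([1, INPUTSIZE, dModel]);
console.log(attention_model.predict([testQ, testK, testV]).shape); //expected: [1, 1024, 256]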
//---------------------------------------------------
//use it in your model as follows (x is the output of the previous layer in your own network; a fuller wiring sketch follows below):
queries = tf.layers.dense({units: dModel}).apply(x);
keys = tf.layers.dense({units: dModel}).apply(x);
values = tf.layers.dense({units: dModel}).apply(x);
x = attention_model.apply([queries, keys, values]); //HERE THE MODEL FOR ATTENTION IS USED
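//---------------------------------------------------
//illustrative sketch (not part of the original gist): one possible way to wire the
//pattern above into a complete model. The token input, embedding, residual connections,
//VOCAB_SIZE and full_model are assumptions for this example, not the author's design.
var VOCAB_SIZE = 256; //hypothetical vocabulary size
var tokens = tf.input({shape: [INPUTSIZE], dtype: 'int32'}); //integer token ids
var h = tf.layers.embedding({inputDim: VOCAB_SIZE, outputDim: dModel}).apply(tokens);
for (let block = 0; block < BLOCKS; block++) {
    let q = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let k = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let v = tf.layers.dense({units: dModel, kernelInitializer: INITIALIZER}).apply(h);
    let att = attention_model.apply([q, k, v]); //shared causal attention block
    h = tf.layers.add().apply([h, att]); //residual connection around each attention block
}
var predictions = tf.layers.dense({units: VOCAB_SIZE, activation: 'softmax'}).apply(h);
var full_model = tf.model({inputs: tokens, outputs: predictions});
full_model.compile({optimizer: OPTIMIZER, loss: LOSS});
full_model.summary();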