Skip to content

Instantly share code, notes, and snippets.

@rezonn
Last active March 4, 2023 11:58
Show Gist options
  • Select an option

  • Save rezonn/34b1948a405720223ab75478b7ca04b4 to your computer and use it in GitHub Desktop.

Select an option

Save rezonn/34b1948a405720223ab75478b7ca04b4 to your computer and use it in GitHub Desktop.
MultiHeadAttention tfjs
/**
 * Numerically stable softmax over axis `ax`, restricted to positions where
 * `mask` is 1; masked-out positions get (near-)zero probability.
 *
 * @param {tf.Tensor} tensor - Attention logits.
 * @param {tf.Tensor} mask - 0/1 tensor broadcastable to `tensor`'s shape.
 * @param {number} ax - Axis along which to normalize.
 * @returns {tf.Tensor} Softmax weights, same shape as `tensor`.
 */
function maskedSoftmax(tensor, mask, ax) {
  return tf.tidy(() => {
    // Subtract the per-axis max before exp so large logits cannot overflow
    // to Infinity; softmax is invariant under this shift.
    const shifted = tf.sub(tensor, tf.max(tensor, ax, /* keepDims */ true));
    let weights = tf.mul(tf.exp(shifted), mask);
    // Tiny epsilon keeps the denominator non-zero when a row is fully masked.
    weights = tf.add(weights, 1e-9);
    const total = tf.expandDims(tf.sum(weights, ax), ax);
    return tf.div(weights, total);
  });
}
/**
 * Multi-head scaled dot-product attention as a custom TF.js layer.
 *
 * Inputs to `call` are `[query, value, key, mask]`, each of shape
 * (batch, seq, modelDim) for the tensors (mask broadcastable to the score
 * shape) — assumed from the rank-3 einsum patterns; confirm with callers.
 * Output has the same shape as `query` (see computeOutputShape).
 */
class MultiHeadAttention extends tf.layers.Layer {
  constructor(config) {
    super({});
    // Copies num_heads, key_dim (and any extra config fields) onto the layer,
    // matching the original behavior without relying on `arguments`.
    Object.assign(this, config);
  }

  getClassName() {
    return 'MultiHeadAttention';
  }

  computeOutputShape(inputShape) {
    // Attention output mirrors the query input's shape.
    return inputShape[0];
  }

  build(queryShape) {
    const qShape = queryShape[0];
    const modelDim = qShape[qShape.length - 1];
    const kernelShape = [modelDim, this.num_heads, this.key_dim];
    const biasShape = [this.num_heads, this.key_dim];
    // NOTE(review): zero-initialized kernels zero out the projections, so the
    // layer cannot train from scratch (gradients through them vanish). This is
    // only sensible if pretrained weights are loaded afterwards; otherwise use
    // tf.initializers.glorotUniform() for the kernels — confirm intent.
    this.q_kernel = this.addWeight('q_kernel', kernelShape, this.dtype, tf.initializers.zeros());
    this.q_bias = this.addWeight('q_bias', biasShape, this.dtype, tf.initializers.zeros());
    this.k_kernel = this.addWeight('k_kernel', kernelShape, this.dtype, tf.initializers.zeros());
    this.k_bias = this.addWeight('k_bias', biasShape, this.dtype, tf.initializers.zeros());
    this.v_kernel = this.addWeight('v_kernel', kernelShape, this.dtype, tf.initializers.zeros());
    this.v_bias = this.addWeight('v_bias', biasShape, this.dtype, tf.initializers.zeros());
    const outKernelShape = [this.num_heads, this.key_dim, modelDim];
    const outBiasShape = [modelDim];
    this.o_kernel = this.addWeight('o_kernel', outKernelShape, this.dtype, tf.initializers.zeros());
    this.o_bias = this.addWeight('o_bias', outBiasShape, this.dtype, tf.initializers.zeros());
  }

  call([query, value, key, mask]) {
    // tf.tidy disposes every intermediate tensor created below; the original
    // leaked all of them on each forward pass.
    return tf.tidy(() => {
      // Self-attention convention: key defaults to value.
      if (!key) key = value;
      // Project inputs to (batch, seq, heads, key_dim).
      const q = tf.add(tf.einsum("abc,cde->abde", query, this.q_kernel.read()), this.q_bias.read());
      const k = tf.add(tf.einsum("abc,cde->abde", key, this.k_kernel.read()), this.k_bias.read());
      const v = tf.add(tf.einsum("abc,cde->abde", value, this.v_kernel.read()), this.v_bias.read());
      // Standard 1/sqrt(d_k) scaling of the query.
      const scaledQ = tf.mul(q, tf.scalar(1.0 / Math.sqrt(this.key_dim)));
      // Scores shaped (batch, heads, q_seq, k_seq) per the einsum pattern.
      let scores = tf.einsum("aecd,abcd->acbe", k, scaledQ);
      if (mask) {
        // Insert leading dims so the mask broadcasts over the head axis.
        let broadcastMask = mask;
        for (let i = 0; i < scores.shape.length - mask.shape.length; i++) {
          broadcastMask = tf.expandDims(broadcastMask, -3);
        }
        scores = maskedSoftmax(scores, broadcastMask, 3);
      } else {
        scores = tf.softmax(scores);
      }
      // Weighted sum over values, then project back to modelDim.
      const context = tf.einsum("acbe,aecd->abcd", scores, v);
      return tf.add(tf.einsum("abcd,cde->abe", context, this.o_kernel.read()), this.o_bias.read());
    });
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment