Skip to content

Instantly share code, notes, and snippets.

@keicoon
Created July 31, 2018 00:36
Show Gist options
  • Save keicoon/c0e6fa201fa2c7f56babf5807dca276c to your computer and use it in GitHub Desktop.
Save keicoon/c0e6fa201fa2c7f56babf5807dca276c to your computer and use it in GitHub Desktop.
Defines a general `matMul` for TensorFlow.js (tf-js) that multiplies tensors of any rank ≥ 2, one batch entry at a time.
/**
 * General matrix multiplication for TensorFlow.js tensors of any rank >= 2.
 *
 * Both operands must have the same rank and identical leading ("batch")
 * dimensions. The trailing two dimensions of each operand are treated as a
 * 2-D matrix. Each batch entry is multiplied separately with
 * `tf.Tensor#matMul`, and the products are stacked back into one tensor.
 *
 * Relies on a `tf` (TensorFlow.js) binding being in scope.
 *
 * @param {tf.Tensor} a - Left operand, rank >= 2.
 * @param {tf.Tensor} b - Right operand, same rank and batch dims as `a`.
 * @param {boolean} [transposeA=false] - Transpose each inner 2-D matrix of `a`.
 * @param {boolean} [transposeB=false] - Transpose each inner 2-D matrix of `b`.
 * @returns {tf.Tensor} Tensor of shape `[...batchDims, rowsOut, colsOut]`.
 * @throws {Error} If rank < 2, the ranks differ, the batch dims differ, or the
 *   contracted (inner) dimensions do not agree. The message includes both shapes.
 */
function matMul(a, b, transposeA = false, transposeB = false) {
  const shapeA = a.shape;
  const shapeB = b.shape;
  const rank = shapeA.length;
  const [rowsA, colsA] = shapeA.slice(-2);
  const [rowsB, colsB] = shapeB.slice(-2);
  const batchDims = shapeA.slice(0, -2);

  // Evaluate each side on its own: `?:` binds looser than `===`, so comparing
  // two inline ternaries without parentheses silently nests them instead.
  const innerA = transposeA ? rowsA : colsA;
  const innerB = transposeB ? colsB : rowsB;
  const isValid =
    rank >= 2 &&
    shapeB.length === rank &&
    batchDims.every((dim, i) => dim === shapeB[i]) &&
    innerA === innerB;
  if (!isValid) {
    // Error() takes a single message argument, so the shapes are interpolated.
    throw new Error(`matMul shape not valid: a: [${shapeA}], b: [${shapeB}]`);
  }

  const rowsOut = transposeA ? colsA : rowsA;
  const colsOut = transposeB ? rowsB : colsB;
  const sizeA = rowsA * colsA;
  const sizeB = rowsB * colsB;
  const sizeOut = rowsOut * colsOut;
  // Number of stacked 2-D matrices (1 for rank-2 input).
  const batchCount = batchDims.reduce((acc, dim) => acc * dim, 1);

  const arrA = a.dataSync();
  const arrB = b.dataSync();
  // Preallocated output buffer. Spreading/apply-ing a large array into push()
  // can exceed the engine's argument-count limit.
  const out = new Float32Array(batchCount * sizeOut);
  for (let i = 0; i < batchCount; i++) {
    // tidy() releases the per-batch tensors (two inputs plus the product),
    // which would otherwise stay allocated on the backend.
    tf.tidy(() => {
      // Row-major layout: batch entry i is one contiguous slice of each array.
      const matA = tf.tensor2d(arrA.slice(i * sizeA, (i + 1) * sizeA), [rowsA, colsA]);
      const matB = tf.tensor2d(arrB.slice(i * sizeB, (i + 1) * sizeB), [rowsB, colsB]);
      out.set(matA.matMul(matB, transposeA, transposeB).dataSync(), i * sizeOut);
    });
  }
  return tf.tensor(out, [...batchDims, rowsOut, colsOut]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment