@jdmichaud
Created October 2, 2018 16:49
TensorFlow.js tensor cross product shim
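```javascript
// Cross product of two column vectors of shape [3, 1], installed as a
// method on tf.Tensor so that it can be called as a.cross(b).
function cross(rhs) {
  const lhs = this;
  // Require rank-2 tensors of shape exactly [3, 1].
  if (lhs.rank !== 2 || rhs.rank !== 2 ||
      lhs.shape[0] !== 3 || lhs.shape[1] !== 1 ||
      rhs.shape[0] !== 3 || rhs.shape[1] !== 1) {
    throw new Error('cross product is only implemented for vectors of shape [3, 1]');
  }
  // Read both vectors back synchronously (blocks until the data is available).
  const u = lhs.dataSync();
  const v = rhs.dataSync();
  // u x v = (u1 v2 - u2 v1, u2 v0 - u0 v2, u0 v1 - u1 v0), as a [3, 1] column.
  return tf.tensor([
    [u[1] * v[2] - u[2] * v[1]],
    [u[2] * v[0] - u[0] * v[2]],
    [u[0] * v[1] - u[1] * v[0]]
  ]);
}

// Install the shim only when TensorFlow.js is loaded as the global `tf`.
// `typeof` avoids a ReferenceError when `tf` has not been declared at all.
if (typeof tf !== 'undefined' && tf.Tensor !== undefined) {
  tf.Tensor.prototype.cross = cross;
}
```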
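The component formula used by the shim does not depend on TensorFlow.js itself. As a minimal sketch, assuming plain 3-element arrays instead of tensors, the same computation can be checked on its own. `cross3` and `dot3` are hypothetical helper names introduced here for illustration; they are not part of the gist or of the TensorFlow.js API.

```javascript
// Hypothetical helper (not part of TensorFlow.js): cross product of two
// 3-element arrays, using the same component formula as the shim.
function cross3(u, v) {
  if (u.length !== 3 || v.length !== 3) {
    throw new Error('cross3 expects two vectors of length 3');
  }
  return [
    u[1] * v[2] - u[2] * v[1],
    u[2] * v[0] - u[0] * v[2],
    u[0] * v[1] - u[1] * v[0],
  ];
}

// Hypothetical helper: dot product, used to check that the cross product
// is orthogonal to both of its inputs.
function dot3(u, v) {
  return u[0] * v[0] + u[1] * v[1] + u[2] * v[2];
}

const x = [1, 0, 0];
const y = [0, 1, 0];
console.log(cross3(x, y)); // [ 0, 0, 1 ]
```

With the shim installed, the corresponding TensorFlow.js call would be `tf.tensor([[1], [0], [0]]).cross(tf.tensor([[0], [1], [0]]))`, which returns a tensor of shape [3, 1] holding the same three components.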