Created
October 30, 2019 20:37
-
-
Save janosh/dfce72a9722d6b3ecff1eebcb5cb5df8 to your computer and use it in GitHub Desktop.
Dabbling with TensorFlowJS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import * as tf from '@tensorflow/tfjs' | |
// Build the probability density function of a multivariate normal distribution.
//   mu:    mean vector (array of length dim)
//   sigma: covariance matrix (nested array, every dimension of size dim)
// Returns a function x => density at x, where x is an array of length dim.
// Throws if any dimension of sigma differs from the length of mu, if sigma is
// not square (raised by tfDet), or if sigma is singular.
export const tfMultivariateNormal = (mu, sigma) => {
  const mean = tf.tensor(mu)
  const cov = tf.tensor(sigma)
  const [dim] = mean.shape
  if (!cov.shape.every(d => d === dim))
    throw new Error(`dimension mismatch in tfMultivariateNormal()`)
  // Normalisation constant: ((2 pi)^dim * det(sigma))^(-1/2).
  const Z = ((2 * Math.PI) ** dim * tfDet(cov).arraySync()) ** -0.5
  // Invert once up front so each density evaluation only pays for two dot products.
  const precision = tfInverse(cov)
  // tfInverse yields undefined for a zero determinant; fail here with a clear message
  // instead of letting tf.dot choke on undefined during evaluation.
  if (precision === undefined)
    throw new Error(`singular covariance matrix in tfMultivariateNormal()`)
  return x =>
    // tidy disposes the per-call intermediates; mean and precision live outside it.
    tf.tidy(() => {
      const diff = tf.tensor(x).sub(mean)
      // Quadratic form (x - mu)^T sigma^-1 (x - mu).
      const quad = tf.dot(diff, tf.dot(precision, diff)).arraySync()
      return Z * Math.exp(-0.5 * quad)
    })
}
// Build the probability density function of a multivariate normal distribution
// whose covariance matrix is diagonal.
//   mu:    mean vector (array of length dim)
//   sigma: diagonal of the covariance matrix, i.e. the per-component variances
//          (array of length dim)
// Returns a function x => density at x, where x is an array of length dim.
// Throws if sigma is not a 1-D vector with the same length as mu.
export const tfMultivariateNormalDiag = (mu, sigma) => {
  const mean = tf.tensor(mu)
  const variances = tf.tensor(sigma)
  const [dim] = mean.shape
  if (variances.shape.length !== 1 || variances.shape[0] !== dim)
    throw new Error(`dimension mismatch in tfMultivariateNormalDiag()`)
  // The determinant of a diagonal matrix is the product of its diagonal.
  // arraySync() is required: multiplying a number by a Tensor object gives NaN.
  const Z = ((2 * Math.PI) ** dim * variances.prod().arraySync()) ** -0.5
  return x =>
    // tidy disposes the per-call intermediates; mean and variances live outside it.
    tf.tidy(() => {
      const diff = tf.tensor(x).sub(mean)
      // Quadratic form for a diagonal covariance: sum_i (x_i - mu_i)^2 / sigma_i.
      const quad = diff
        .square()
        .div(variances)
        .sum()
        .arraySync()
      return Z * Math.exp(-0.5 * quad)
    })
}
// Calculate the determinant of a matrix or of a batch of matrices.
// The last two dimensions must form square matrices; any leading dimensions
// are treated as batch dimensions and are kept, in order, in the result
// (a plain 2-D matrix therefore yields a scalar tensor).
// Throws if the last two dimensions are missing or not equal.
// Adapted from https://github.com/tensorflow/tfjs/issues/1516#issuecomment-545120751.
const tfDet = tnsr =>
  // tidy frees every intermediate tensor and keeps only the returned one.
  tf.tidy(() => {
    // slice() copies, so the caller's tnsr.shape is left untouched
    // (Array#reverse would have reversed it in place).
    const batchDims = tnsr.shape.slice(0, -2)
    const [m, n] = tnsr.shape.slice(-2)
    if (tnsr.shape.length < 2 || m !== n)
      throw new Error(`det(): Received non-square matrix.`)
    const dets = tnsr
      .reshape([-1, n, n])
      .unstack()
      .map(mat => {
        const [, r] = tf.linalg.qr(mat)
        // The diagonal of r sits at every (n + 1)-th entry of its flattened form.
        const diagR = r.flatten().stridedSlice([0], [n * n], [n + 1])
        const detR = diagR.prod()
        // q is product of n Householder reflections, i.e. det(q) = (-1)^n
        const detQ = n % 2 === 0 ? 1 : -1
        return tf.mul(detQ, detR)
      })
    return tf.stack(dets).reshape(batchDims)
  })
// Matrix inversion using the adjugate method: the inverse is the transposed
// cofactor matrix divided by the determinant.
// Expects a single square 2-D tensor. Returns the inverse as a 2-D tensor, or
// undefined when the determinant is exactly 0.
// Adapted from https://stackoverflow.com/a/51808271.
export const tfInverse = tnsr =>
  tf.tidy(() => {
    const d = tfDet(tnsr).arraySync()
    if (d === 0) return
    const [r] = tnsr.shape
    // A 1x1 matrix has no minors to cut out; its inverse is simply 1 / d.
    if (r === 1) return tf.tensor2d([1 / d], [1, 1])
    const rows = [...Array(r).keys()]
    // Index tensor selecting every row/column except `skip`.
    const allBut = skip => tf.tensor1d(rows.filter(e => e !== skip), `int32`)
    const cofactors = []
    for (let i = 0; i < r; i++) {
      // Removing row i does not depend on j, so do it once per row.
      const withoutRow = tnsr.gather(allBut(i))
      for (let j = 0; j < r; j++) {
        // Removing column j as well leaves the (i, j) minor.
        const minor = withoutRow.gather(allBut(j), 1)
        cofactors.push((-1) ** (i + j) * tfDet(minor).arraySync())
      }
    }
    // Adjugate = transpose of the cofactor matrix.
    return tf
      .tensor2d(cofactors, [r, r])
      .transpose()
      .div(tf.scalar(d))
  })
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment