Skip to content

Instantly share code, notes, and snippets.

@janosh
Created October 30, 2019 20:37
Show Gist options
  • Save janosh/dfce72a9722d6b3ecff1eebcb5cb5df8 to your computer and use it in GitHub Desktop.
Save janosh/dfce72a9722d6b3ecff1eebcb5cb5df8 to your computer and use it in GitHub Desktop.
Dabbling with TensorFlow.js
import * as tf from '@tensorflow/tfjs'
/**
 * Build the probability density function of a multivariate normal distribution.
 *
 * Setup work (determinant, covariance inverse) is done once here, so repeated
 * density evaluations only compute the quadratic form.
 *
 * @param {number[]} mu - mean vector of length dim
 * @param {number[][]} sigma - dim × dim covariance matrix; it should be positive
 *   definite, otherwise the normalization constant can be NaN
 * @returns {(x: number[]) => number} function returning the density at x
 * @throws {Error} if mu / sigma shapes do not match, or sigma is singular
 */
export const tfMultivariateNormal = (mu, sigma) => {
  const muT = tf.tensor(mu)
  const sigmaT = tf.tensor(sigma)
  const [dim] = muT.shape
  if (
    muT.rank !== 1 ||
    sigmaT.rank !== 2 ||
    !sigmaT.shape.every(d => d === dim)
  )
    throw new Error(`dimension mismatch in tfMultivariateNormal()`)
  // tidy disposes the determinant tensor; only the JS number escapes.
  const det = tf.tidy(() => tfDet(sigmaT).arraySync())
  const Z = ((2 * Math.PI) ** dim * det) ** -0.5
  // Invert once up front rather than on every density evaluation.
  const sigmaInv = tfInverse(sigmaT)
  if (sigmaInv === undefined)
    throw new Error(`singular covariance matrix in tfMultivariateNormal()`)
  return x =>
    // tidy frees the per-call tensors (x, x - mu, products).
    tf.tidy(() => {
      const diff = tf.tensor(x).sub(muT)
      const mahalanobis = tf.dot(diff, tf.dot(sigmaInv, diff)).arraySync()
      return Z * Math.exp(-0.5 * mahalanobis)
    })
}
/**
 * Build the probability density function of a multivariate normal distribution
 * with a diagonal covariance matrix.
 *
 * @param {number[]} mu - mean vector of length dim
 * @param {number[]} sigma - diagonal of the covariance matrix (variances, not
 *   standard deviations: their product is used as det(Σ)), length dim
 * @returns {(x: number[]) => number} function returning the density at x
 * @throws {Error} if mu / sigma shapes do not match
 */
export const tfMultivariateNormalDiag = (mu, sigma) => {
  const muT = tf.tensor(mu)
  const varT = tf.tensor(sigma)
  const [dim] = muT.shape
  if (muT.rank !== 1 || varT.rank !== 1 || !varT.shape.every(d => d === dim))
    throw new Error(`dimension mismatch in tfMultivariateNormalDiag()`)
  // Pull det(Σ) = prod(sigma) out as a JS number; multiplying a Tensor by a
  // number directly coerces it to NaN.
  const detSigma = tf.tidy(() => varT.prod().arraySync())
  const Z = ((2 * Math.PI) ** dim * detSigma) ** -0.5
  return x =>
    tf.tidy(() => {
      const diff = tf.tensor(x).sub(muT)
      // Quadratic form (x-mu)^T Σ^-1 (x-mu) for diagonal Σ: sum((x-mu)^2 / sigma).
      const mahalanobis = diff.square().div(varT).sum().arraySync()
      return Z * Math.exp(-0.5 * mahalanobis)
    })
}
// Calculate the determinant of a matrix or a batch of matrices.
// The last two dimensions are assumed to be square matrices; any leading
// dimensions are batch dimensions and are preserved, in order, in the output.
// Adapted from https://github.com/tensorflow/tfjs/issues/1516#issuecomment-545120751.
//
// Uses det(A) = det(Q) * det(R) from a QR decomposition, where det(R) is the
// product of R's diagonal.
// @param {tf.Tensor} tnsr - tensor of shape [...batch, n, n]
// @returns {tf.Tensor} determinants of shape [...batch] (a scalar for 2-D input)
// @throws {Error} if tnsr has fewer than 2 dims or its last two dims differ
const tfDet = tnsr =>
  tf.tidy(() => {
    // slice() instead of reverse(): reverse() mutates tnsr.shape in place and
    // would also flip the order of the batch dimensions.
    const { shape } = tnsr
    if (shape.length < 2) throw new Error(`det(): Received non-square matrix.`)
    const [m, n] = shape.slice(-2)
    if (m !== n) throw new Error(`det(): Received non-square matrix.`)
    const batchDims = shape.slice(0, -2)
    const batchSize = batchDims.reduce((acc, d) => acc * d, 1)
    // A 0×0 matrix has determinant 1 by convention, and an empty batch has
    // nothing to factor (reshape cannot infer a -1 dimension for 0 elements).
    if (n === 0 || batchSize === 0) return tf.ones(batchDims)
    const mats = tnsr.reshape([batchSize, n, n]).unstack()
    const dets = mats.map(mat => {
      const r = tf.linalg.qr(mat)[1]
      // Every (n+1)-th element of the flattened n×n R lies on its diagonal.
      const diagR = r.flatten().stridedSlice([0], [n * n], [n + 1])
      const detR = diagR.prod()
      // q is a product of n Householder reflections (det -1 each), so
      // det(q) = (-1)^n. NOTE(review): relies on tfjs's qr implementation.
      const detQ = n % 2 === 0 ? 1 : -1
      return tf.mul(detQ, detR)
    })
    return tf.stack(dets).reshape(batchDims)
  })
// Tensor inversion using the matrix adjoint (adjugate) method:
// inv(A) = adj(A) / det(A), where adj(A) is the transposed cofactor matrix.
// Adapted from https://stackoverflow.com/a/51808271.
// @param {tf.Tensor} tnsr - a single square 2-D matrix
// @returns {tf.Tensor|undefined} the inverse, or undefined if tnsr is singular
export const tfInverse = tnsr =>
  tf.tidy(() => {
    const d = tfDet(tnsr).arraySync()
    // A NaN determinant gives no usable inverse either. NOTE(review): the
    // QR-based det may yield NaN rather than 0 for exactly singular input.
    if (d === 0 || Number.isNaN(d)) return undefined
    const [r] = tnsr.shape
    // adj of a 1×1 matrix is [[1]]; handled directly because its minors
    // would be empty 0×0 matrices.
    if (r === 1) return tf.tensor2d([[1 / d]])
    const indices = [...Array(r).keys()]
    // Index tensor selecting every row/column except k.
    const allBut = k => tf.tensor1d(indices.filter(e => e !== k), `int32`)
    const cofactors = []
    for (let i = 0; i < r; i++) {
      // The rows without i are shared by all minors in row i of the cofactor matrix.
      const rowsKept = tnsr.gather(allBut(i), 0)
      for (let j = 0; j < r; j++) {
        const minor = rowsKept.gather(allBut(j), 1)
        cofactors.push((-1) ** (i + j) * tfDet(minor).arraySync())
      }
    }
    const adjugate = tf.tensor2d(cofactors, [r, r]).transpose()
    return adjugate.div(tf.scalar(d))
  })
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment