Skip to content

Instantly share code, notes, and snippets.

@umbra-scientia
Last active April 7, 2020 16:50
Show Gist options
  • Save umbra-scientia/baea386865ff410180ff30aaa5741c6b to your computer and use it in GitHub Desktop.
Save umbra-scientia/baea386865ff410180ff30aaa5741c6b to your computer and use it in GitHub Desktop.
symbolic differential calculus
/*
 * mob: a small symbolic-math toolkit.
 *
 * Expressions are plain objects {op: <string>, args: <expression[]>}.
 * Leaves carry one extra field:
 *   {op: 'num', value: <number>, args: []}
 *   {op: 'sym', symbol: <string>, args: []}
 * Subtraction and division are encoded with 'neg' and 'recip', so only
 * 'add' and 'mul' are n-ary.
 */
var mob = {
  // ---- Constructors ---------------------------------------------------
  num: function(v) { return {op: 'num', value: v, args: []}; },
  sym: function(s) { return {op: 'sym', symbol: s, args: []}; },
  add: function(x, y) { return {op: 'add', args: [x, y]}; },
  mul: function(x, y) { return {op: 'mul', args: [x, y]}; },
  // n-ary forms; the given array is stored directly (not copied).
  sum: function(v) { return {op: 'add', args: v}; },
  prod: function(v) { return {op: 'mul', args: v}; },
  neg: function(x) { return {op: 'neg', args: [x]}; },
  recip: function(x) { return {op: 'recip', args: [x]}; },
  sub: function(x, y) { return mob.add(x, mob.neg(y)); },
  div: function(x, y) { return mob.mul(x, mob.recip(y)); },
  // Unary functions.
  sqrt: function(x) { return {op: 'sqrt', args: [x]}; },
  erf: function(x) { return {op: 'erf', args: [x]}; },
  exp: function(x) { return {op: 'exp', args: [x]}; },
  log: function(x) { return {op: 'log', args: [x]}; },
  sin: function(x) { return {op: 'sin', args: [x]}; },
  cos: function(x) { return {op: 'cos', args: [x]}; },
  gelu: function(x) { return {op: 'gelu', args: [x]}; },
  relu: function(x) { return {op: 'relu', args: [x]}; },
  tanh: function(x) { return {op: 'tanh', args: [x]}; },
  sigmoid: function(x) { return {op: 'sigmoid', args: [x]}; },
  // Generic application of function name `f` to one or two arguments;
  // a falsy `y` produces a unary node.
  fun: function(f, x, y) { return {op: f, args: y ? [x, y] : [x]}; },

  /**
   * Structural equality of two expressions.
   * @param {object} x
   * @param {object} y
   * @returns {boolean} true when ops match, leaf values/symbols match and
   *   all arguments are recursively equal.
   */
  eq: function(x, y) {
    if (x.op !== y.op) return false;
    if (x.op === 'num') return x.value === y.value;
    if (x.op === 'sym') return x.symbol === y.symbol;
    if (x.args.length !== y.args.length) return false;
    for (var i = 0; i < x.args.length; i++) {
      if (!mob.eq(x.args[i], y.args[i])) return false;
    }
    return true;
  },

  /**
   * Symbolic derivative of `f` with respect to `x`.
   * The result is not simplified; pass it through mob.reduce to tidy it.
   * @param {object} f - expression to differentiate
   * @param {object} x - expression to differentiate by (usually a 'sym');
   *   any sub-expression structurally equal to `x` has derivative 1.
   * @returns {object} derivative expression
   */
  diff: function(f, x) {
    if (mob.eq(f, x)) return mob.num(1);
    if (f.op === 'sym' || f.op === 'num') return mob.num(0);
    var u = f.args[0];
    var terms, term, i, j;
    switch (f.op) {
      case 'add':
        terms = [];
        for (i = 0; i < f.args.length; i++) terms.push(mob.diff(f.args[i], x));
        return mob.sum(terms);
      case 'mul':
        // Generalized product rule: sum_i (d arg_i) * prod_{j != i} arg_j.
        terms = [];
        for (i = 0; i < f.args.length; i++) {
          term = [mob.diff(f.args[i], x)];
          for (j = 0; j < f.args.length; j++) {
            if (i !== j) term.push(f.args[j]);
          }
          terms.push(mob.prod(term));
        }
        return mob.sum(terms);
      case 'neg':
        return mob.neg(mob.diff(u, x));
      case 'recip':
        // d(1/u) = -u' / u^2, written as -(u' * (1/u) * (1/u)) reusing f = 1/u.
        return mob.neg(mob.prod([mob.diff(u, x), f, f]));
      case 'exp':
        return mob.mul(f, mob.diff(u, x));
      case 'log':
        return mob.div(mob.diff(u, x), u);
      case 'sqrt':
        // d sqrt(u) = u' / (2 sqrt(u))
        return mob.prod([mob.num(0.5), mob.diff(u, x), mob.recip(f)]);
      case 'erf':
        // d erf(u) = (2/sqrt(pi)) * exp(-u^2) * u'
        return mob.prod([mob.num(1.128379167095512), mob.diff(u, x), mob.exp(mob.neg(mob.mul(u, u)))]);
      case 'sin':
        return mob.mul(mob.cos(u), mob.diff(u, x));
      case 'cos':
        return mob.neg(mob.mul(mob.sin(u), mob.diff(u, x)));
    }
    if (f.args.length === 1) {
      // Other unary function g: chain rule d g(u) = g'(u) * u'. The derivative
      // g' is the op "g'", which compile() maps to a helper named "g_d".
      return mob.mul(mob.fun(f.op + "'", u), mob.diff(u, x));
    }
    // Two-argument functions: left as the opaque symbolic derivative "g'"(u, v);
    // no partial-derivative chain rule is applied.
    return mob.fun(f.op + "'", f.args[0], f.args[1]);
  },

  /**
   * Light algebraic simplification (the input is not mutated).
   * 'add'/'mul' nodes are flattened, their neg/recip arguments are separated
   * into a subtracted (divided) group, and all numeric constants are folded
   * into one. A product whose folded constant is 0 collapses to 0.
   * Any other non-leaf node keeps its op and has each argument simplified.
   * @param {object} x - expression
   * @returns {object} simplified expression
   */
  reduce: function(x) {
    var reduceArg = function(a) { return mob.reduce(a); };
    if (x.op === 'add') {
      var sum = [];      // terms added
      var neg_sum = [];  // terms subtracted
      var num = 0;       // folded numeric constant
      // q: pending terms with positive sign; qi: pending terms with negative sign.
      var q = x.args.map(reduceArg), qi = [];
      while (q.length || qi.length) {
        var isNeg = !q.length;
        var arg = isNeg ? qi.shift() : q.shift();
        if (arg.op === 'num') {
          if (isNeg) num -= arg.value;
          else num += arg.value;
        } else if (arg.op === 'neg') {
          // A negated term flips sign.
          if (isNeg) q.push(arg.args[0]);
          else qi.push(arg.args[0]);
        } else if (arg.op === 'add') {
          // Flatten nested sums, keeping the current sign.
          if (isNeg) qi = qi.concat(arg.args);
          else q = q.concat(arg.args);
        } else {
          if (isNeg) neg_sum.push(arg);
          else sum.push(arg);
        }
      }
      var const_term = mob.num(num);
      var pos_term = sum.length ? ((sum.length === 1) ? sum[0] : mob.sum(sum)) : false;
      var neg_term = neg_sum.length ? ((neg_sum.length === 1) ? neg_sum[0] : mob.sum(neg_sum)) : false;
      var result = false;
      if (pos_term && neg_term) result = mob.sub(pos_term, neg_term);
      else if (pos_term) result = pos_term;
      else if (neg_term) result = mob.neg(neg_term);
      if (!result) result = const_term;
      else if (num !== 0) result = mob.add(const_term, result);
      return result;
    }
    if (x.op === 'mul') {
      var prod = [];      // factors multiplied
      var inv_prod = [];  // factors divided by
      var num = 1;        // folded numeric constant
      // q: pending factors in the numerator; qi: pending factors in the denominator.
      var q = x.args.map(reduceArg), qi = [];
      while (q.length || qi.length) {
        var isNeg = !q.length;
        var arg = isNeg ? qi.shift() : q.shift();
        if (arg.op === 'num') {
          if (isNeg) num /= arg.value;
          else num *= arg.value;
        } else if (arg.op === 'recip') {
          // A reciprocal factor moves to the other side of the fraction.
          if (isNeg) q.push(arg.args[0]);
          else qi.push(arg.args[0]);
        } else if (arg.op === 'mul') {
          // Flatten nested products on the same side.
          if (isNeg) qi = qi.concat(arg.args);
          else q = q.concat(arg.args);
        } else {
          if (isNeg) inv_prod.push(arg);
          else prod.push(arg);
        }
      }
      // 0 * anything simplifies to 0 (standard algebraic rule; this discards
      // any NaN/Infinity the other factors might produce numerically).
      if (num === 0) return mob.num(0);
      var const_term = mob.num(num);
      var pos_term = prod.length ? ((prod.length === 1) ? prod[0] : mob.prod(prod)) : false;
      var inv_term = inv_prod.length ? ((inv_prod.length === 1) ? inv_prod[0] : mob.prod(inv_prod)) : false;
      var result = false;
      if (pos_term && inv_term) result = mob.div(pos_term, inv_term);
      else if (pos_term) result = pos_term;
      else if (inv_term) result = mob.recip(inv_term);
      if (!result) result = const_term;
      else if (num !== 1) result = mob.mul(const_term, result);
      return result;
    }
    if (x.op === 'num' || x.op === 'sym') return x;
    // Any other operator: keep it and simplify its arguments.
    return {op: x.op, args: x.args.map(reduceArg)};
  },

  /**
   * Substitute symbols and/or operators.
   * For a symbol whose name is a key of `syms`: a string value renames it,
   * any other truthy value (an expression) replaces it.
   * For an operator whose name is a key of `syms`: a string value renames the
   * op, otherwise the value is called with the substituted arguments and its
   * return value replaces the node.
   * @param {object} x - expression
   * @param {object} syms - substitution table
   * @returns {object} new expression
   */
  subst: function(x, syms) {
    if (x.op === 'num') return x;
    if (x.op === 'sym') {
      var s = syms[x.symbol];
      if (typeof s === 'string') return mob.sym(s);
      if (s) return s;
      return x;
    }
    var r = {op: x.op, args: []};
    for (var i = 0; i < x.args.length; i++) {
      r.args.push(mob.subst(x.args[i], syms));
    }
    var rule = syms[r.op];
    if (rule) {
      if (typeof rule === 'string') r.op = rule;
      else r = rule.apply(r.op, r.args);
    }
    return r;
  },

  /**
   * Render an expression as a human-readable string, e.g. "(x+-(y))".
   * @param {object} x - expression
   * @returns {string}
   */
  str: function(x) {
    if (x.op === 'num') return "" + x.value;
    if (x.op === 'sym') return "" + x.symbol;
    var oper = ',';
    if (x.op === 'add') oper = '+';
    if (x.op === 'mul') oper = '*';
    var s = "(";
    for (var i = 0; i < x.args.length; i++) {
      if (s !== "(") s += oper;
      s += mob.str(x.args[i]);
    }
    s += ")";
    if (x.op === 'neg') s = "-" + s;
    else if (x.op === 'recip') s = "1/" + s;
    else if (oper === ',') s = x.op + s;
    return s;
  },

  /**
   * Compile an expression to JavaScript.
   * With `args` an array of parameter names (omitted means []), returns a JS
   * function of those parameters. With `args === false`, returns the JS
   * source string for the expression (used for recursion).
   * Function ops resolve to mob.numer.<name>, then Math.<name>, else to a bare
   * identifier; a derivative op "g'" becomes the name "g_d".
   * NOTE(security): the source is assembled by string concatenation and run
   * through a direct eval (so `mob` and other names resolve in this scope).
   * Symbol names, ops and `args` are inserted verbatim — never compile
   * expressions built from untrusted input.
   */
  compile: function(x, args) {
    if (args === undefined) args = [];
    if (args) {
      var mobjs_last_compiled_function_ = false;
      eval("mobjs_last_compiled_function_ = function(" + args.join(",") + "){return " + mob.compile(x, false) + ";}");
      return mobjs_last_compiled_function_;
    }
    if (x.op === 'num') return "" + x.value;
    if (x.op === 'sym') return "" + x.symbol;
    var oper = ',';
    if (x.op === 'add') oper = '+';
    if (x.op === 'mul') oper = '*';
    var s = "(";
    for (var i = 0; i < x.args.length; i++) {
      if (s !== "(") s += oper;
      s += mob.compile(x.args[i], false);
    }
    s += ")";
    if (x.op === 'neg') s = "-" + s;
    else if (x.op === 'recip') s = "1/" + s;
    else if (oper === ',') {
      var fn = x.op.split("'").join("_d");
      if (mob.numer[fn]) {
        s = "mob.numer." + fn + s;
      } else if (Math[fn]) {
        s = "Math." + fn + s;
      } else {
        s = fn + s;
      }
    }
    return s;
  },

  // Numeric implementations used by compiled code (looked up as
  // mob.numer.<name>); "<name>_d" is the derivative of "<name>".
  numer: {
    exp: Math.exp,
    log: Math.log,
    tanh: Math.tanh,
    // d/dx tanh(x) = sech(x)^2
    tanh_d: function(x) { var sech = 1.0 / Math.cosh(x); return sech * sech; },
    sqrt: Math.sqrt,
    sin: Math.sin, cos: Math.cos,
    sinh: Math.sinh, cosh: Math.cosh,
    sinh_d: Math.cosh, cosh_d: Math.sinh,
    asin: Math.asin, acos: Math.acos,
    asin_d: function(x) { return 1.0 / Math.sqrt(1.0 - x * x); },
    acos_d: function(x) { return -1.0 / Math.sqrt(1.0 - x * x); },
    // erf via Abramowitz & Stegun 7.1.26 (|error| <= 1.5e-7). The rational
    // approximation is valid only for x >= 0, so evaluate at |x| and use
    // the odd symmetry erf(-x) = -erf(x).
    erf: function(x) {
      var s = (x < 0) ? -1 : 1;
      var a = Math.abs(x);
      var m = [0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];
      var y = 1.0 / (1.0 + m[0] * a);
      return s * (1.0 - Math.exp(-a * a) * y * (m[1] + y * (m[2] + y * (m[3] + y * (m[4] + y * m[5])))));
    },
    // d/dx erf(x) = (2/sqrt(pi)) * exp(-x^2)
    erf_d: function(x) { return 1.128379167095512 * Math.exp(-x * x); },
    // GELU(x) = x * Phi(x), Phi(x) = (1 + erf(x/sqrt(2))) / 2
    gelu: function(x) { return x * (0.5 + 0.5 * mob.numer.erf(x * 0.7071067811865475)); },
    // d/dx GELU(x) = Phi(x) + x * phi(x), phi(x) = exp(-x^2/2) / sqrt(2*pi)
    gelu_d: function(x) {
      var a = (1.0 + mob.numer.erf(x * 0.7071067811865475)) / 2.0;
      var b = Math.exp(-0.5 * x * x) * x * 0.3989422804014327;
      return a + b;
    },
    relu: function(x) { return (x < 0) ? 0 : x; },
    relu_d: function(x) { return (x < 0) ? 0 : 1; },
    // Logistic sigmoid 1 / (1 + e^-x): tends to 1 as x -> +Infinity.
    sigmoid: function(x) { return 1.0 / (1.0 + Math.exp(-x)); },
    // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) = 1 / (e^x + 2 + e^-x)
    sigmoid_d: function(x) { var ex = Math.exp(x); return 1.0 / (2.0 + ex + 1.0 / ex); }
  }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment