Last active
April 7, 2020 16:50
-
-
Save umbra-scientia/baea386865ff410180ff30aaa5741c6b to your computer and use it in GitHub Desktop.
symbolic differential calculus
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * mob: a small symbolic-math toolkit.
 *
 * Expressions are plain objects ("nodes") of the form
 *   {op: 'num', value: <number>, args: []}
 *   {op: 'sym', symbol: <string>, args: []}
 *   {op: <operator or function name>, args: [<node>, ...]}
 *
 * Provided operations: node constructors, structural equality (eq),
 * symbolic differentiation (diff), constant folding (reduce), substitution
 * (subst), pretty-printing (str), compilation to a JavaScript function
 * (compile), and numeric implementations of the supported functions (numer).
 *
 * A derivative of a function without a built-in rule is represented by a
 * node whose op is the function name followed by a prime (e.g. "tanh'");
 * compile() maps that to the "<name>_d" entry of mob.numer.
 */
var mob = {
  // ---- node constructors -------------------------------------------------
  num: function(v) { return {op: 'num', value: v, args: []}; },
  sym: function(s) { return {op: 'sym', symbol: s, args: []}; },
  add: function(x, y) { return {op: 'add', args: [x, y]}; },
  mul: function(x, y) { return {op: 'mul', args: [x, y]}; },
  // n-ary variants: v is an array of nodes.
  sum: function(v) { return {op: 'add', args: v}; },
  prod: function(v) { return {op: 'mul', args: v}; },
  neg: function(x) { return {op: 'neg', args: [x]}; },
  recip: function(x) { return {op: 'recip', args: [x]}; },
  // Subtraction and division are expressed through neg / recip.
  sub: function(x, y) { return mob.add(x, mob.neg(y)); },
  div: function(x, y) { return mob.mul(x, mob.recip(y)); },
  sqrt: function(x) { return {op: 'sqrt', args: [x]}; },
  erf: function(x) { return {op: 'erf', args: [x]}; },
  exp: function(x) { return {op: 'exp', args: [x]}; },
  log: function(x) { return {op: 'log', args: [x]}; },
  sin: function(x) { return {op: 'sin', args: [x]}; },
  cos: function(x) { return {op: 'cos', args: [x]}; },
  gelu: function(x) { return {op: 'gelu', args: [x]}; },
  relu: function(x) { return {op: 'relu', args: [x]}; },
  tanh: function(x) { return {op: 'tanh', args: [x]}; },
  sigmoid: function(x) { return {op: 'sigmoid', args: [x]}; },
  // Generic function application with one or two arguments.
  fun: function(f, x, y) { return {op: f, args: y ? [x, y] : [x]}; },

  /**
   * Structural equality of two expression trees.
   * @param {object} x - expression node
   * @param {object} y - expression node
   * @returns {boolean} true when both trees have the same shape and leaves
   */
  eq: function(x, y) {
    if (x.op !== y.op) return false;
    if (x.op === 'num') return x.value === y.value;
    if (x.op === 'sym') return x.symbol === y.symbol;
    if (x.args.length !== y.args.length) return false;
    for (var i = 0; i < x.args.length; i++) {
      if (!mob.eq(x.args[i], y.args[i])) return false;
    }
    return true;
  },

  /**
   * Symbolic derivative of f with respect to x.
   *
   * x may be any expression; a sub-tree structurally equal to x
   * differentiates to 1. The result is not simplified; pass it through
   * reduce() to fold constants.
   *
   * @param {object} f - expression to differentiate
   * @param {object} x - expression to differentiate with respect to
   * @returns {object} expression node for df/dx
   */
  diff: function(f, x) {
    if (mob.eq(f, x)) return mob.num(1);
    if (f.op === 'sym' || f.op === 'num') return mob.num(0);

    var u = f.args[0];
    var terms, factors, i, j;
    switch (f.op) {
      case 'add':
        // Sum rule: differentiate term by term.
        terms = [];
        for (i = 0; i < f.args.length; i++) {
          terms.push(mob.diff(f.args[i], x));
        }
        return mob.sum(terms);
      case 'mul':
        // n-ary product rule: sum over i of (d arg_i) * product of the others.
        terms = [];
        for (i = 0; i < f.args.length; i++) {
          factors = [mob.diff(f.args[i], x)];
          for (j = 0; j < f.args.length; j++) {
            if (i !== j) factors.push(f.args[j]);
          }
          terms.push(mob.prod(factors));
        }
        return mob.sum(terms);
      case 'neg':
        return mob.neg(mob.diff(u, x));
      case 'recip':
        // d(1/u) = -u' / u^2. Without this rule every mob.div() expression
        // differentiated into an unusable "recip'" node.
        return mob.neg(mob.mul(mob.diff(u, x), mob.recip(mob.mul(u, u))));
      case 'exp':
        return mob.mul(f, mob.diff(u, x));
      case 'log':
        return mob.div(mob.diff(u, x), u);
      case 'sqrt':
        // d sqrt(u) = u' / (2 sqrt(u))
        return mob.prod([mob.num(0.5), mob.diff(u, x), mob.recip(f)]);
      case 'erf':
        // d erf(u) = (2/sqrt(pi)) * u' * exp(-u^2)
        return mob.prod([mob.num(1.128379167095512), mob.diff(u, x), mob.exp(mob.neg(mob.mul(u, u)))]);
      case 'sin':
        return mob.mul(mob.cos(u), mob.diff(u, x));
      case 'cos':
        return mob.neg(mob.mul(mob.sin(u), mob.diff(u, x)));
    }

    // Function without a built-in rule: emit the primed function.
    if (f.args.length === 1) {
      // Chain rule: (g(u))' = g'(u) * u'.
      return mob.mul(mob.fun(f.op + '\'', u), mob.diff(u, x));
    }
    // With any other argument count the partial derivatives cannot be
    // expressed in this notation, so the inner derivative is NOT applied.
    return mob.fun(f.op + '\'', f.args[0], f.args[1]);
  },

  /**
   * Simplify an expression: flatten nested sums/products, fold numeric
   * constants, and gather negated (resp. reciprocal) terms into a single
   * subtraction (resp. division). Arguments of every other node are reduced
   * recursively.
   *
   * @param {object} x - expression node
   * @returns {object} simplified expression node
   */
  reduce: function(x) {
    var i;
    if (x.op === 'add') {
      var sum = [];      // terms with a plus sign
      var neg_sum = [];  // terms with a minus sign
      var num = 0;       // folded numeric constant
      // q holds pending positive terms, qi pending negated terms.
      var q = [], qi = [];
      for (i = 0; i < x.args.length; i++) q.push(mob.reduce(x.args[i]));
      while (q.length || qi.length) {
        var isNeg = !q.length;
        var arg = isNeg ? qi.shift() : q.shift();
        if (arg.op === 'num') {
          if (isNeg) num -= arg.value;
          else num += arg.value;
        } else if (arg.op === 'neg') {
          // A negation flips the term to the opposite queue.
          if (isNeg) q.push(arg.args[0]);
          else qi.push(arg.args[0]);
        } else if (arg.op === 'add') {
          // Flatten nested sums into the current queue.
          if (isNeg) qi = qi.concat(arg.args);
          else q = q.concat(arg.args);
        } else {
          if (isNeg) neg_sum.push(arg);
          else sum.push(arg);
        }
      }
      var const_term = mob.num(num);
      var pos_term = sum.length ? ((sum.length === 1) ? sum[0] : mob.sum(sum)) : false;
      var neg_term = neg_sum.length ? ((neg_sum.length === 1) ? neg_sum[0] : mob.sum(neg_sum)) : false;
      var result = false;
      if (pos_term && neg_term) result = mob.sub(pos_term, neg_term);
      else if (pos_term) result = pos_term;
      else if (neg_term) result = mob.neg(neg_term);
      if (!result) result = const_term;
      else if (num !== 0) result = mob.add(const_term, result);
      return result;
    }
    if (x.op === 'mul') {
      var prod = [];      // factors in the numerator
      var inv_prod = [];  // factors in the denominator
      var coef = 1;       // folded numeric coefficient
      var p = [], pi = [];
      for (i = 0; i < x.args.length; i++) p.push(mob.reduce(x.args[i]));
      while (p.length || pi.length) {
        var isInv = !p.length;
        var factor = isInv ? pi.shift() : p.shift();
        if (factor.op === 'num') {
          if (isInv) coef /= factor.value;
          else coef *= factor.value;
        } else if (factor.op === 'recip') {
          // A reciprocal flips the factor to the opposite queue.
          if (isInv) p.push(factor.args[0]);
          else pi.push(factor.args[0]);
        } else if (factor.op === 'mul') {
          // Flatten nested products into the current queue.
          if (isInv) pi = pi.concat(factor.args);
          else p = p.concat(factor.args);
        } else {
          if (isInv) inv_prod.push(factor);
          else prod.push(factor);
        }
      }
      var coef_term = mob.num(coef);
      var num_term = prod.length ? ((prod.length === 1) ? prod[0] : mob.prod(prod)) : false;
      var inv_term = inv_prod.length ? ((inv_prod.length === 1) ? inv_prod[0] : mob.prod(inv_prod)) : false;
      var product = false;
      if (num_term && inv_term) product = mob.div(num_term, inv_term);
      else if (num_term) product = num_term;
      else if (inv_term) product = mob.recip(inv_term);
      if (!product) product = coef_term;
      else if (coef !== 1) product = mob.mul(coef_term, product);
      return product;
    }
    // Leaves are returned unchanged.
    if (!x.args || !x.args.length) return x;
    // Any other operator/function: simplify its arguments so that e.g.
    // exp(1+2) becomes exp(3) and the operand of neg() gets folded too.
    var reduced = [];
    for (i = 0; i < x.args.length; i++) reduced.push(mob.reduce(x.args[i]));
    return {op: x.op, args: reduced};
  },

  /**
   * Substitute symbols and/or operators in an expression.
   *
   * syms maps a symbol name to a replacement (a string renames the symbol,
   * a number becomes a constant node, a node replaces the symbol), and maps
   * an operator name either to a new operator name (string) or to a function
   * that receives the already-substituted arguments and returns a node.
   *
   * @param {object} x - expression node
   * @param {object} syms - substitution table
   * @returns {object} new expression; x itself is not modified
   */
  subst: function(x, syms) {
    if (x.op === 'num') return x;
    if (x.op === 'sym') {
      var s = syms[x.symbol];
      if (typeof s === 'string') return mob.sym(s);
      // Wrap plain numbers (including 0) so the result is always a node.
      if (typeof s === 'number') return mob.num(s);
      return s ? s : x;
    }
    var r = {op: x.op, args: []};
    for (var i = 0; i < x.args.length; i++) {
      r.args.push(mob.subst(x.args[i], syms));
    }
    var repl = syms[r.op];
    if (repl) {
      if (typeof repl === 'string') {
        r.op = repl;
      } else {
        // The operator name is passed as `this`.
        r = repl.apply(r.op, r.args);
      }
    }
    return r;
  },

  /**
   * Render an expression as a human-readable string.
   * @param {object} x - expression node
   * @returns {string} fully parenthesised text form
   */
  str: function(x) {
    if (x.op === 'num') return '' + x.value;
    if (x.op === 'sym') return '' + x.symbol;
    // add/mul are infix; everything else is printed as a call.
    var oper = ',';
    if (x.op === 'add') oper = '+';
    if (x.op === 'mul') oper = '*';
    var parts = [];
    for (var i = 0; i < x.args.length; i++) {
      parts.push(mob.str(x.args[i]));
    }
    var s = '(' + parts.join(oper) + ')';
    if (x.op === 'neg') return '-' + s;
    if (x.op === 'recip') return '1/' + s;
    if (oper === ',') return x.op + s;
    return s;
  },

  /**
   * Compile an expression to JavaScript.
   *
   * With args omitted or an array of parameter names, returns a callable
   * function of those parameters. With args === false, returns the
   * JavaScript source text of the expression instead.
   *
   * SECURITY: the generated source is run through eval(). Symbol and
   * operator names are inserted verbatim, so never compile expressions or
   * parameter names that come from untrusted input.
   *
   * @param {object} x - expression node
   * @param {(string[]|false)} [args] - parameter names, or false for source
   * @returns {(Function|string)} compiled function or source text
   */
  compile: function(x, args) {
    if (args === undefined) args = [];
    if (args) {
      // Direct eval (not new Function) so the generated code resolves `mob`
      // through the enclosing scope even when mob is not a global.
      return eval('(function(' + args.join(',') + '){return ' + mob.compile(x, false) + ';})');
    }
    if (x.op === 'num') return '' + x.value;
    if (x.op === 'sym') return '' + x.symbol;
    var oper = ',';
    if (x.op === 'add') oper = '+';
    if (x.op === 'mul') oper = '*';
    var parts = [];
    for (var i = 0; i < x.args.length; i++) {
      parts.push(mob.compile(x.args[i], false));
    }
    var s = '(' + parts.join(oper) + ')';
    if (x.op === 'neg') return '-' + s;
    if (x.op === 'recip') return '1/' + s;
    if (oper === ',') {
      // A prime in the operator name selects the derivative: tanh' -> tanh_d.
      var fn = x.op.split('\'').join('_d');
      if (mob.numer[fn]) return 'mob.numer.' + fn + s;
      if (Math[fn]) return 'Math.' + fn + s;
      // Unknown name: assumed to be resolvable where the code is evaluated.
      return fn + s;
    }
    return s;
  },

  // ---- numeric implementations used by compiled code ---------------------
  // "<name>_d" is the first derivative of "<name>".
  numer: {
    exp: Math.exp,
    log: Math.log,
    tanh: Math.tanh,
    // sech^2(x)
    tanh_d: function(x) { x = 1.0 / Math.cosh(x); return x * x; },
    sqrt: Math.sqrt,
    sin: Math.sin, cos: Math.cos,
    sinh: Math.sinh, cosh: Math.cosh,
    asin: Math.asin, acos: Math.acos,
    // Polynomial approximation of erf. The formula is only valid for
    // non-negative arguments, so it is evaluated at |x| and the sign is
    // restored afterwards (erf is odd).
    erf: function(x) {
      var s = (x < 0) ? -1 : 1;
      var ax = Math.abs(x);
      var m = [0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];
      var y = 1.0 / (1.0 + m[0] * ax);
      return s * (1.0 - Math.exp(-ax * ax) * y * (m[1] + y * (m[2] + y * (m[3] + y * (m[4] + y * m[5])))));
    },
    // (2/sqrt(pi)) * exp(-x^2)
    erf_d: function(x) { return 1.128379167095512 * Math.exp(-x * x); },
    // x * Phi(x), with Phi the standard normal CDF.
    gelu: function(x) { return x * (0.5 + 0.5 * mob.numer.erf(x * 0.7071067811865475)); },
    // Phi(x) + x * phi(x), with phi the standard normal density.
    gelu_d: function(x) {
      var a = (1.0 + mob.numer.erf(x * 0.7071067811865475)) / 2.0;
      var b = Math.exp(-0.5 * x * x) * x * 0.3989422804014327;
      return a + b;
    },
    relu: function(x) { return (x < 0) ? 0 : x; },
    relu_d: function(x) { return (x < 0) ? 0 : 1; },
    // Standard logistic function 1 / (1 + e^-x).
    sigmoid: function(x) { return 1.0 / (1.0 + Math.exp(-x)); },
    // sigmoid'(x) = 1 / (e^x + 2 + e^-x)
    sigmoid_d: function(x) { var ex = Math.exp(x); return 1.0 / (2.0 + ex + 1.0 / ex); }
  }
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment