Last active: March 24, 2017 05:24
Save jeffdonahue/84eab3b74db5309da871a5abb78bdde6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import theano | |
import theano.tensor as T | |
# Precisions under comparison. The float64 results are later used as the
# "gold standard" against which the float32 results are scored.
low = 'float32'
high = 'float64'

# 1-D tensor types (a single non-broadcastable axis) for each precision.
dtype_low = T.TensorType(low, (False,))
dtype_high = T.TensorType(high, (False,))

# Symbolic inputs for each precision: 'p' is passed through a sigmoid
# (i.e. treated as logits) and 't' is the cross-entropy target.
pred32 = dtype_low('p')
target32 = dtype_low('t')
pred64 = dtype_high('p')
target64 = dtype_high('t')
def get_funcs(pred, target):
    """Compile the two sigmoid cross-entropy formulations being compared.

    Args:
        pred: symbolic input that is fed through a sigmoid (logits).
        target: symbolic cross-entropy target, same type as ``pred``.

    Returns:
        A pair ``(f1, f2)`` of compiled Theano functions. Each takes
        ``(pred, target)`` and returns ``[loss, d(sum loss)/d pred,
        d(sum loss)/d target]``:

        * ``f1`` uses ``binary_crossentropy(sigmoid(pred), target)``;
        * ``f2`` uses the fused ``sigmoid_binary_crossentropy(pred, target)``
          (presumably the more numerically careful form -- that is what the
          script measures).
    """
    def compile_one(loss):
        # Gradients of the *summed* loss w.r.t. both inputs; passing a list as
        # ``wrt`` makes theano.grad return a list, which is appended below.
        grads = theano.grad(loss.sum(), [pred, target])
        return theano.function([pred, target], [loss] + grads)

    naive_loss = T.nnet.binary_crossentropy(T.nnet.sigmoid(pred), target)
    fused_loss = T.nnet.sigmoid_binary_crossentropy(pred, target)
    return compile_one(naive_loss), compile_one(fused_loss)
# Compiled (naive BCE, fused SBCE) evaluators at each precision.
f1_32, f2_32 = get_funcs(pred32, target32)
f1_64, f2_64 = get_funcs(pred64, target64)

# Seed NumPy's global RNG so the sampled data -- and therefore the reported
# errors -- are reproducible from run to run.
np.random.seed(seed=0)
def random_data(n=1000*1000):
    """Draw ``n`` standard-normal samples and cast them to the ``low`` dtype."""
    samples = np.random.randn(n)
    return samples.astype(low)
# Predictions (logits): standard-normal samples at low precision.
px32 = random_data()
# Soft targets: sigmoid(-z) == 1 / (1 + exp(z)) for an independent normal draw z.
lx32 = 1 / (1 + np.exp(random_data()))

# The very same values promoted to high precision for the reference run.
px64 = px32.astype(high)
lx64 = lx32.astype(high)

# Each result is [loss, grad wrt pred, grad wrt target].
result_f1_32 = f1_32(px32, lx32)
result_f2_32 = f2_32(px32, lx32)
result_f1_64 = f1_64(px64, lx64)
result_f2_64 = f2_64(px64, lx64)
def mse(a, b):
    """Return the mean of the element-wise squared difference of ``a`` and ``b``.

    Both arguments are expected to be NumPy arrays (``.mean()`` is called on
    the difference).
    """
    diff = a - b
    return (diff * diff).mean()
# Labels for the three outputs of every compiled function, in output order.
names = 'L', 'grad wrt pred', 'grad wrt target'

# Score both float32 formulations against each float64 formulation in turn.
# Single-argument print(...) calls behave the same under Python 2 and 3
# (the original print statements are a SyntaxError on Python 3).
for gold_name, gold_result in [('BCE', result_f1_64), ('SBCE', result_f2_64)]:
    print('Using {} as gold standard'.format(gold_name))
    for name, gold, bce, sbce in zip(names, gold_result,
                                     result_f1_32, result_f2_32):
        err_bce = mse(gold, bce)
        err_sbce = mse(gold, sbce)
        # Relative error reduction of SBCE over BCE. It is undefined when the
        # float32 BCE error is exactly zero; report nan instead of dividing
        # by zero.
        improvement = ((err_bce - err_sbce) / err_bce
                       if err_bce else float('nan'))
        print('\t{}: BCE error = {}; SBCE error = {}; improvement = {}%'.format(
            name, err_bce, err_sbce, 100 * improvement))
Author jeffdonahue commented on Mar 23, 2017.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.