import torch
from mpmath import mp
mp.dps = 100

# Take some list of values that shrink to be really small in log space:
torch_lps = torch.log_softmax(-torch.arange(20.0), dim=0)
mpmath_lps = -torch.arange(20.0)
Z = sum(mp.exp(mp.mpf(mpmath_lps[i].item())) for i in range(len(mpmath_lps)))
for i in range(len(mpmath_lps)):
    mpmath_lps[i] = float(mp.mpf(mpmath_lps[i].item()) - mp.log(Z))
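# Both sequences represent the same normalized distribution in log space:
#   lp_i = x_i - log(sum_j exp(x_j))   with x_i = -i.
# torch_lps gets the log-partition in float32 via log_softmax; mpmath_lps gets
# it from a 100-digit Z, so (up to the final float() cast) it is the more
# faithful reference.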
for (lps, name, conclusion) in [
    (torch_lps, "PyTorch's logsumexp/logsoftmax",
     "PyTorch LSE does end up at 0 as it should... but too early!\nmpmath's more exact calculations show that we overshoot!"),
    (mpmath_lps, "mpmath's 100-digit wide computation",
     "PyTorch LSE again fails to include the smallest, but now thinks that log Z < 0!\nmpmath shows that, no, that's not true, we still overshoot, but less."),
]:
    print("\n\nWhen normalizing with", name, "\n")
    # Now take cumulative sums in log space using logsumexp:
    sum_naive = torch.tensor(0.0).log()
    sum_LSE = torch.tensor(0.0).log()
    log_remainder = torch.tensor(1.0).log()
    mp_remainder = mp.mpf(1.0)
    print("add this -> log sums: sum_naive sum_LSE_all sum_log1p mpmath!")
    print("----------------------------------------------------------------------------------------")
    for i in range(len(lps)):
        # Version 0 is naive log sum exp
        sum_naive = torch.log(sum_naive.exp() + lps[i].exp())
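        # (In float32, once the running sum is near 1, terms far below machine
        # epsilon (~1.2e-7 relative to the sum) are lost to rounding, so the
        # naive sum stops moving before all the mass is included.)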
        # Version 1 reuses past computation results and implements binary LSE by log1p
        if i == 0:
            sum_log1p = lps[i]
        else:
            assert sum_log1p > lps[i]
            sum_log1p = sum_log1p + (lps[i] - sum_log1p).exp().log1p()
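        # This is the identity log(exp(a) + exp(b)) = a + log1p(exp(b - a)),
        # valid for a >= b (hence the assert): the exp never overflows and
        # log1p stays accurate when exp(b - a) is tiny.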
        # Version 2 uses torch's LSE for binary LSE
        sum_LSE = torch.logsumexp(torch.stack([sum_LSE, lps[i]]), dim=0)
        # Version 3 recomputes from scratch
        sum_scratch = torch.logsumexp(lps[:i+1], dim=0)
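        # Version 3 carries no running value across iterations, so it does not
        # accumulate round-off, but it costs O(i) per step (O(n^2) over the
        # loop), while Versions 1, 2, and 4 are O(1) per step.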
        # Version 4 uses the remainder for incremental computation instead!
        log_remainder = lps[i] + (log_remainder - lps[i]).expm1().log()
        # log_remainder = log_remainder + (-torch.exp(lps[i] - log_remainder)).log1p()
        sum_rem = torch.log(-log_remainder.expm1())
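        # Here log_remainder tracks r = 1 - (sum of probabilities added so far):
        #   log(r - p_i) = lp_i + log(exp(log r - lp_i) - 1)   (the expm1 line),
        # and the cumulative log-sum is log(1 - r) = log(-expm1(log r)).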
        # Version 5 is arbitrary precision floating math
        mp_remainder = mp_remainder - mp.exp(mp.mpf(lps[i].item()))
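        # The same remainder trick, but with the subtraction carried out at
        # mp.dps = 100 digits; the print below reports log1p(-mp_remainder),
        # i.e. the log of the cumulative sum.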
        # Compare!
        print(f"... + {lps[i].exp().item():12.10f} {sum_naive.item():13.10f} {sum_scratch.item():13.10f} {sum_log1p.item():13.10f} {float(mp.log1p(-mp_remainder)):13.10f}".replace('0', '\033[2m0\033[0m'))
    print("\n" + conclusion)