Utility functions for softplus transform between positive reals and unconstrained reals.
''' Define invertible, differentiable transform for arrays of positive vals.

Utility funcs for transforming positive vals to and from the unconstrained real line.

This transform is sometimes called the "softplus" transformation.
See: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)

To run automated tests to verify correctness of this script, do:
(Add -v flag for verbose output)
```
$ python -m doctest trans_util_positive.py
```

Common parameter
----------------
p : 1D array, size F, of positive real values

Differentiable parameter
------------------------
x : 1D array, size F, of unconstrained real values

Transform from common to diffable
---------------------------------
x = log(exp(p) - 1)
Function `to_diffable_vec` does this in a numerically stable way.

Transform from diffable to common
---------------------------------
p = log(exp(x) + 1)
Function `to_common_arr` does this in a numerically stable way.
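
Numerical stability
-------------------
Both stable forms below rely on the same pair of identities, applied when
the argument is large:
    log(1 + e^a) = a + log1p(e^-a)
    log(e^a - 1) = a + log1p(-e^-a)
so e^a is never evaluated where it would overflow, and log1p preserves
precision when its argument is tiny.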

Examples
--------
>>> np.set_printoptions(suppress=False, precision=3, linewidth=80)

# Create 1D array of pos values ranging from 0.000001 to 1000000
>>> p_F = np.asarray([1e-6, 1e-3, 1e0, 1e3, 1e6])
>>> print(p_F)
[ 1.000e-06 1.000e-03 1.000e+00 1.000e+03 1.000e+06]

# Look at transformed values, which vary from -13 to 1000000
>>> x_F = to_diffable_vec(p_F)
>>> print(x_F)
[ -1.382e+01 -6.907e+00 5.413e-01 1.000e+03 1.000e+06]

# Show it's invertible
>>> print(to_common_arr(to_diffable_vec(p_F)))
[ 1.000e-06 1.000e-03 1.000e+00 1.000e+03 1.000e+06]
>>> assert np.allclose(p_F, to_common_arr(to_diffable_vec(p_F)))
'''
import autograd.numpy as np
from autograd.core import primitive
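
# Both transforms below are wrapped in @primitive because their vectorized
# branches fill a preallocated output array via boolean-mask assignment,
# which autograd cannot differentiate through; gradients are therefore
# registered by hand with `defgrad` (the autograd API current as of 2017).
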
@primitive
def to_common_arr(x):
    r""" Numerically stable transform from real line to positive reals

    Returns np.log(1.0 + np.exp(x))

    Autograd friendly and fully vectorized

    Args
    ----
    x : array of values in (-\infty, +\infty)

    Returns
    -------
    ans : array of values in (0, +\infty), same size as x
    """
    if not isinstance(x, float):
        mask1 = x > 0
        mask0 = np.logical_not(mask1)
        out = np.zeros_like(x)
        # For x <= 0, exp(x) <= 1, so log1p(exp(x)) cannot overflow.
        out[mask0] = np.log1p(np.exp(x[mask0]))
        # For x > 0, factor out x: log(1 + e^x) = x + log1p(e^-x).
        out[mask1] = x[mask1] + np.log1p(np.exp(-x[mask1]))
        return out
    if x > 0:
        return x + np.log1p(np.exp(-x))
    else:
        return np.log1p(np.exp(x))


def make_grad__to_common_arr(ans, x):
    x = np.asarray(x)

    def gradient_product(g):
        # d/dx log(1 + e^x) = e^x / (1 + e^x) = exp(x - ans), the sigmoid of x
        return np.full(x.shape, g) * np.exp(x - ans)

    return gradient_product


to_common_arr.defgrad(make_grad__to_common_arr)


@primitive
def to_diffable_vec(p):
    r""" Numerically stable transform from positive reals to real line

    Returns np.log(np.exp(p) - 1.0)

    Autograd friendly and fully vectorized

    Args
    ----
    p : array of values in (0, +\infty)

    Returns
    -------
    ans : array of values in (-\infty, +\infty), same size as p
    """
    if not isinstance(p, float):
        mask1 = p > 10.0
        mask0 = np.logical_not(mask1)
        out = np.zeros_like(p)
        # For p <= 10, expm1 is accurate even as p -> 0+.
        out[mask0] = np.log(np.expm1(p[mask0]))
        # For p > 10, factor out p: log(e^p - 1) = p + log1p(-e^-p).
        out[mask1] = p[mask1] + np.log1p(-np.exp(-p[mask1]))
        return out
    if p > 10:
        return p + np.log1p(-np.exp(-p))
    else:
        return np.log(np.expm1(p))


def make_grad__to_diffable_vec(ans, p):
    p = np.asarray(p)

    def gradient_product(g):
        # d/dp log(e^p - 1) = e^p / (e^p - 1) = exp(p - ans)
        return np.full(p.shape, g) * np.exp(p - ans)

    return gradient_product


to_diffable_vec.defgrad(make_grad__to_diffable_vec)
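

# A minimal sanity-check sketch (assumes autograd's top-level `grad`
# function, available in the 2017-era package imported above): round-trips
# the transform at extreme magnitudes and compares the hand-registered
# gradients against their closed-form derivatives.
if __name__ == '__main__':
    from autograd import grad

    p_F = np.asarray([1e-6, 1e-3, 1e0, 1e3, 1e6])
    assert np.allclose(p_F, to_common_arr(to_diffable_vec(p_F)))

    # d/dx log(1 + e^x) is the logistic sigmoid, which equals 0.5 at x = 0.
    assert np.allclose(grad(to_common_arr)(0.0), 0.5)

    # d/dp log(e^p - 1) = e^p / (e^p - 1).
    assert np.allclose(grad(to_diffable_vec)(1.0), np.exp(1.0) / np.expm1(1.0))

    print('All sanity checks passed.')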