Created
December 24, 2018 19:41
-
-
Save 8bit-pixies/6f301235ff6b45f14aa93ca5ccf000b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def linear_part(x, w):
    """Return the linear pre-activation ``x * w``.

    No bias term is added; shapes and broadcasting follow whatever the
    operands' own ``*`` operator does.
    """
    weighted = x * w
    return weighted
def non_linear(x, p=0.001):
    """Apply a leaky ReLU element-wise.

    Positive entries pass through unchanged; every other entry is multiplied
    by the slope ``p``.

    Parameters
    ----------
    x : array_like
        Input values; converted with ``np.asarray`` so scalars and sequences
        are accepted. The input is not modified.
    p : float, optional
        Slope applied to non-positive entries (default 0.001).

    Returns
    -------
    numpy.ndarray or numpy scalar
        Result with the same shape as ``x``.
    """
    x = np.asarray(x)
    # Per-element multiplier: 1 where x > 0, p elsewhere. Uses the caller's
    # ``p`` instead of a hard-coded 0.001 so the parameter actually works.
    multi = np.where(x > 0, 1.0, p)
    return x * multi
def non_linear_inv(x, p=0.001):
    """Invert a leaky ReLU whose slope on non-positive inputs is ``p``.

    Positive entries pass through unchanged; every other entry is divided by
    ``p``. With the default ``p`` this undoes ``non_linear(x)`` (up to
    floating-point rounding).

    Parameters
    ----------
    x : array_like
        Activated values to invert; converted with ``np.asarray``.
    p : float, optional
        Slope used by the forward activation (default 0.001). Must be > 0.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Result with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``p <= 0``: with ``p == 0`` the inverse is undefined, and with
        ``p < 0`` the forward map sends negatives to positives, so the sign
        of the output no longer tells which branch produced it.
    """
    if p <= 0:
        raise ValueError("p must be positive to invert a leaky ReLU")
    y = np.asarray(x)
    # Undo the forward slope: 1 where y > 0, 1/p elsewhere. Uses the caller's
    # ``p`` instead of a hard-coded 1/0.001.
    multi = np.where(y > 0, 1.0, 1.0 / p)
    return y * multi
# Demo: numerical checks of how the leaky ReLU ``non_linear`` interacts with
# the linear map ``linear_part``. A fixed seed makes the run reproducible.
rng = np.random.RandomState(0)
x = rng.normal(size=10) * 10
w = rng.normal(size=10) * 10

# Check 1: positive scaling commutes with the activation,
#   leaky(x * w) * 0.5 == leaky(x * (w * 0.5)).
out = linear_part(x, w)
out_1 = non_linear(out) * 0.5
out2 = linear_part(x, w * 0.5)
out_2 = non_linear(out2)
# Compare floats up to rounding (allclose) rather than bit-for-bit.
print("scaling commutes:", np.allclose(out_1, out_2))

w1 = rng.normal(size=10) * 10
out_3a = non_linear(linear_part(x, w))
out_3b = non_linear(linear_part(x, w1))
# Summing two activated outputs is NOT the activation of a single linear map
# of x: e.g. max(a, 0) + max(b, 0) != max(a + b, 0) whenever a and b have
# strictly opposite signs. out_3 is kept only to illustrate that.
out_3 = out_3a + out_3b
# Undoing each activation first recovers the pre-activations x*w and x*w1,
# whose sum is linear in x, so re-activating it should match out_4 below.
out_3c = non_linear(non_linear_inv(out_3a) + non_linear_inv(out_3b))

# Check 2: before the activation, weights combine linearly,
#   leaky(x * (w + w1)) == leaky(x * w + x * w1).
out_4 = non_linear(linear_part(x, w + w1))
out_5 = non_linear(linear_part(x, w) + linear_part(x, w1))
print("weights add before activation:", np.allclose(out_4, out_5))
print("inverting then re-activating matches:", np.allclose(out_3c, out_4))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment