Created June 11, 2019 04:39
A neural network that learns transitive inference, as in the paper "Transitive inference in Polistes paper wasps"
from __future__ import print_function
import sys
# If you don't know jax, check it out
# https://github.com/google/jax
import jax.numpy as np
import jax.scipy as scipy
from jax import jit, random, value_and_grad
# We're training the simplest neural network to infer an ordering on a
# set given limited examples.
# The set is {A, B, C, D, E} and the ordering is
# A < B < C < D < E
# We train by giving examples consisting of distinct pairs (x, y)
# encoded as a pair of 5D 1-hot vectors
# and we expect result 1 if x < y and 0 if x > y.
# We train with the same examples as in
# http://sci-hub.tw/downloads/2019-05-09/7a/[email protected]
# ie. (A, B), (B, C), ..., (D, E)
# as well as the flipped pairs (B, A), ..., (E, D).
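# A small illustration (hypothetical helper, not used by the code below):
# a pair (x, y) is encoded by concatenating two 5D 1-hot vectors, with x in
# the first five slots and y in the last five.
def encode_pair(x, y):
    # x and y are indices 0..4 standing for the letters A..E
    v = [0.0] * 10
    v[x] = 1.0
    v[5 + y] = 1.0
    return np.array(v, dtype=np.float32)
# e.g. encode_pair(0, 2) encodes (A, C) as [1,0,0,0,0, 0,0,1,0,0]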
def model(w, bias, inputs):
    predictions = scipy.special.expit(np.matmul(inputs, w) + bias)
    return predictions
# Compute loss for entire training set at every iteration
def loss(w, bias, inputs, correct):
    predictions = model(w, bias, inputs)
    r = np.sum(-correct*np.log(predictions)-(1.-correct)*np.log(1.-predictions))
    return r
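# Note: this is the summed binary cross-entropy,
#   loss = -sum_i [ y_i*log(p_i) + (1 - y_i)*log(1 - p_i) ],
# i.e. the negative log-likelihood of the labels under the sigmoid model,
# so the whole setup amounts to logistic regression on the 10D encoding.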
key = random.PRNGKey(666)
keys = random.split(key, 2)
weights = random.normal(keys[0], shape=(10,))
bias = random.normal(keys[1])
# This is the entire training set
inputs = np.array([
    [1, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 1, 0]
], dtype=np.float32)
correct = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=np.float32)
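# For reference: rows 1-4 of `inputs` are (A, B), (B, C), (C, D), (D, E) with
# target 1, and rows 5-8 are the reversed pairs (B, A), (C, B), (D, C), (E, D)
# with target 0. Only adjacent letters ever appear during training.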
# loss, (dw, db) = value_and_grad(loss, argnums=[0, 1])(weights, bias, inputs, correct)
# print(loss, dw, db)
# sys.exit(1)
# Small networks allow high training rates
alpha = 1.0
f = jit(value_and_grad(loss, argnums=[0, 1]))
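# value_and_grad with argnums=[0, 1] makes f return (loss_value, (dw, db)),
# and jit compiles it with XLA the first time it is called.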
# We're going to use 25 epochs with our tiny training set
for i in range(25):
    print("i=", i)
    l, (dw, db) = f(weights, bias, inputs, correct)
    print("loss =", l)
    weights -= alpha*dw
    bias -= alpha*db
# The original training set is learnt nicely
print(model(weights, bias, inputs))
# And see how it seems to have assumed transitivity in the way it has very
# high confidence that A < E and B < D, even though those pairs never
# appeared in the training set.
print(model(weights, bias, np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=np.float32)))  # (A, E): expect close to 1
print(model(weights, bias, np.array([0, 0, 0, 0, 1, 1, 0, 0, 0, 0], dtype=np.float32)))  # (E, A): expect close to 0
print(model(weights, bias, np.array([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], dtype=np.float32)))  # (B, D): expect close to 1
print(model(weights, bias, np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0], dtype=np.float32)))  # (D, B): expect close to 0
# It's not a surprise really given how neural nets work
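# Why this is expected: the logit for a pair (x, y) is w[x] + w[5 + y] + bias,
# a sum of one learned score per letter, so any consistent set of per-letter
# scores orders every pair at once.
# The loop below is an added sketch (not part of the original gist): it prints
# the prediction for every ordered pair of distinct letters, including the
# non-adjacent pairs never seen in training.
letters = "ABCDE"
for x in range(5):
    for y in range(5):
        if x != y:
            v = [0.0] * 10
            v[x] = 1.0
            v[5 + y] = 1.0
            p = model(weights, bias, np.array(v, dtype=np.float32))
            print("P(%s < %s) = %.3f" % (letters[x], letters[y], float(p)))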
At home I'm using Python 2 but at work we try to write everything to work in both Python 2 & 3 so I'm using imports from future at home now
ah, explains why you're not using xrange
Did you see my tweet the other day? :)
https://twitter.com/sigfpe/status/1137122070548996096?s=20
hah
And the explicit mentions of float32 are my Haskell showing :)
I do the same when using numpy, it's only right