@dpiponi, created June 11, 2019 04:39
A neural network that learns transitive inference, as in the paper "Transitive inference in Polistes paper wasps"
from __future__ import print_function
import sys
# If you don't know jax, check it out
# https://github.com/google/jax
import jax.numpy as np
import jax.scipy as scipy
from jax import jit, random, value_and_grad
# We're training the simplest neural network to infer an ordering on a
# set given limited examples.
# The set is {A, B, C, D, E} and the ordering is
# A < B < C < D < E
# We train by giving examples consisting of distinct pairs (x, y)
# encoded as a pair of 5D 1-hot vectors
# and we expect result 1 if x < y and 0 if x > y.
# We train with the same examples as in
# http://sci-hub.tw/downloads/2019-05-09/7a/[email protected]
# ie. (A, B), (B, C), ..., (D, E)
# as well as the flipped pairs (B, A), ..., (E, D).
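# The helper below is not in the original gist; it's an illustrative sketch of
# the encoding described above. `encode_pair` and `LETTERS` are names I've
# introduced: a pair of distinct letters becomes two concatenated 5D one-hot
# vectors.
LETTERS = "ABCDE"
def encode_pair(x, y):
    # e.g. encode_pair("A", "B") -> [1, 0, 0, 0, 0, 0, 1, 0, 0, 0]
    xs = [1.0 if c == x else 0.0 for c in LETTERS]
    ys = [1.0 if c == y else 0.0 for c in LETTERS]
    return np.array(xs + ys, dtype=np.float32)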
def model(w, bias, inputs):
    # A single logistic unit: sigmoid of a weighted sum of the 10 input bits.
    predictions = scipy.special.expit(np.matmul(inputs, w) + bias)
    return predictions
# Compute the loss over the entire training set at every iteration
def loss(w, bias, inputs, correct):
    predictions = model(w, bias, inputs)
    # Binary cross-entropy summed over all training pairs.
    r = np.sum(-correct*np.log(predictions) - (1. - correct)*np.log(1. - predictions))
    return r
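# An aside, not in the original gist: the same cross-entropy can be computed
# from the logits, which avoids log(0) when the sigmoid saturates. It uses the
# identity -c*log(sigmoid(z)) - (1-c)*log(1-sigmoid(z)) = softplus(z) - c*z,
# with softplus(z) = logaddexp(0, z). The name `stable_loss` is my own; it is
# a drop-in alternative, not what the gist trains with.
def stable_loss(w, bias, inputs, correct):
    logits = np.matmul(inputs, w) + bias
    return np.sum(np.logaddexp(0., logits) - correct*logits)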
key = random.PRNGKey(666)
keys = random.split(key, 2)
weights = random.normal(keys[0], shape=(10,))
bias = random.normal(keys[1])
# This is the entire training set
inputs = np.array([
    [1, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # (A, B)
    [0, 1, 0, 0, 0, 0, 0, 1, 0, 0],  # (B, C)
    [0, 0, 1, 0, 0, 0, 0, 0, 1, 0],  # (C, D)
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1],  # (D, E)
    [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],  # (B, A)
    [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],  # (C, B)
    [0, 0, 0, 1, 0, 0, 0, 1, 0, 0],  # (D, C)
    [0, 0, 0, 0, 1, 0, 0, 0, 1, 0]   # (E, D)
], dtype=np.float32)
# Target is 1 where x < y (the first four rows) and 0 where x > y (the flipped pairs).
correct = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=np.float32)
# loss, (dw, db) = value_and_grad(loss, argnums=[0, 1])(weights, bias, inputs, correct)
# print(loss, dw, db)
# sys.exit(1)
# Small networks allow high learning rates
alpha = 1.0
f = jit(value_and_grad(loss, argnums=[0, 1]))
# We're going to use 25 epochs with our tiny training set
for i in range(25):
    print("i=", i)
    l, (dw, db) = f(weights, bias, inputs, correct)
    print("loss =", l)
    # Plain full-batch gradient descent.
    weights -= alpha*dw
    bias -= alpha*db
# The original training set is learnt nicely
print(model(weights, bias, inputs))
# And see how it seems to have inferred transitivity: it gives very high
# confidence that A < E and B < D, and correspondingly low confidence to the
# flipped pairs E < A and D < B, even though none of these pairs appeared in training.
print(model(weights, bias, np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=np.float32)))  # (A, E)
print(model(weights, bias, np.array([0, 0, 0, 0, 1, 1, 0, 0, 0, 0], dtype=np.float32)))  # (E, A)
print(model(weights, bias, np.array([0, 1, 0, 0, 0, 0, 0, 0, 1, 0], dtype=np.float32)))  # (B, D)
print(model(weights, bias, np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0], dtype=np.float32)))  # (D, B)
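# Another illustrative sketch, not in the original gist: probe every ordered
# pair (x, y) with x != y at once and print P(x < y) for each. The variable
# names here are my own.
eye5 = np.eye(5, dtype=np.float32)
all_pairs = np.stack([np.concatenate([eye5[i], eye5[j]])
                      for i in range(5) for j in range(5) if i != j])
print(model(weights, bias, all_pairs))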
# It's not a surprise really given how neural nets work
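# A sketch of why, not in the original gist: with one-hot inputs the logit for
# a pair (x, y) is just bias + w[x] + w[5 + y], so
#   logit(x, y) - logit(y, x) = r[y] - r[x]   where r[i] = w[5 + i] - w[i].
# The network therefore answers every pair from a single linear rank r, which
# is transitive by construction. The name `rank` is my own.
rank = weights[5:] - weights[:5]
print("implied rank (expected to increase from A to E):", rank)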
@dpiponi (Author) commented Jun 11, 2019

Did you see my tweet the other day? :)
https://twitter.com/sigfpe/status/1137122070548996096?s=20

@cipherself commented

hah 🙂

@dpiponi (Author) commented Jun 12, 2019

And the explicit mentions of float32 are my Haskell showing :)

@cipherself commented

I do the same when using numpy, it's only right 😀
