Skip to content

Instantly share code, notes, and snippets.

@cedrickchee
Last active August 4, 2018 09:50
Show Gist options
  • Save cedrickchee/9bf93cb2a17ad8e5f779518a549bf627 to your computer and use it in GitHub Desktop.
Save cedrickchee/9bf93cb2a17ad8e5f779518a549bf627 to your computer and use it in GitHub Desktop.
Keras implementation of DeepMind's Neural Arithmetic Logic Units (NALU). Paper: https://arxiv.org/abs/1808.00508
import numpy as np
import keras.backend as K
from keras.layers import *
from keras.models import *
import tensorflow as tf
class Nalu(Layer):
    """Neural Arithmetic Logic Unit (NALU) layer.

    The layer mixes two paths that share one constrained weight matrix
    ``W = tanh(W_hat) * sigmoid(M_hat)``:

    * an additive path       ``a = inputs . W``
    * a multiplicative path  ``m = exp(log(|inputs| + 1e-7) . W)``

    A learned gate ``g = sigmoid(inputs . G)`` blends them:
    ``output = g * a + (1 - g) * m``.

    Args:
        units: Size of the last axis of the output.
        krnl_init: Initializer (name or instance) used for the three weight
            matrices ``W_hat``, ``M_hat`` and ``G``.
        **kwargs: Forwarded to ``Layer.__init__``. The abbreviated spellings
            ``inp_shp`` / ``inp_dim`` are accepted as aliases for
            ``input_shape`` / ``input_dim``.
    """

    def __init__(self, units, krnl_init="glorot_uniform", **kwargs):
        # Pull the abbreviated spellings out first so they never reach the
        # base class, which only understands ``input_shape``.
        legacy_shape = kwargs.pop("inp_shp", None)
        legacy_dim = kwargs.pop("inp_dim", None)
        if "input_shape" not in kwargs:
            if legacy_shape is not None:
                kwargs["input_shape"] = legacy_shape
            elif "input_dim" in kwargs:
                kwargs["input_shape"] = (kwargs.pop("input_dim"),)
            elif legacy_dim is not None:
                kwargs["input_shape"] = (legacy_dim,)
        super(Nalu, self).__init__(**kwargs)
        self.units = units
        # ``input_spec`` is the attribute Keras inspects; ``inp_spec`` is kept
        # so existing code reading the old attribute name still works.
        self.input_spec = self.inp_spec = InputSpec(min_ndim=2)
        self.krnl_init = initializers.get(krnl_init)

    def get_config(self):
        """Return the constructor arguments as a serializable dict."""
        conf = {
            "units": self.units,
            "krnl_init": initializers.serialize(self.krnl_init),
        }
        # Must name this class: ``super(Dense, self)`` fails because Nalu
        # does not derive from Dense.
        base_conf = super(Nalu, self).get_config()
        return dict(list(base_conf.items()) + list(conf.items()))

    def build(self, inp_shp):
        """Create ``W_hat``, ``M_hat`` and ``G``, each ``(inp_shp[-1], units)``."""
        assert len(inp_shp) >= 2
        inp_dim = inp_shp[-1]
        weight_shape = (inp_dim, self.units)
        self.W_hat = self.add_weight(
            shape=weight_shape, initializer=self.krnl_init, name="W_hat"
        )
        self.M_hat = self.add_weight(
            shape=weight_shape, initializer=self.krnl_init, name="M_hat"
        )
        self.G = self.add_weight(
            shape=weight_shape, initializer=self.krnl_init, name="G"
        )
        # Pin the size of the last axis now that it is known.
        self.input_spec = self.inp_spec = InputSpec(
            min_ndim=2, axes={-1: inp_dim}
        )
        self.built = True

    def call(self, inputs):
        """Apply the gated additive/multiplicative transform to ``inputs``."""
        W = K.tanh(self.W_hat) * K.sigmoid(self.M_hat)
        # Multiplicative path, computed in log space; the small constant
        # keeps log() finite when an input is exactly zero.
        m = K.exp(K.dot(K.log(K.abs(inputs) + 1e-7), W))
        g = K.sigmoid(K.dot(inputs, self.G))
        # Additive path. Uses ``inputs`` (not a module-level ``x``) so the
        # layer works on whatever tensor it is actually called with.
        a = K.dot(inputs, W)
        return g * a + (1 - g) * m

    def compute_out_shp(self, inp_shp):
        """Return ``inp_shp`` with its last entry replaced by ``units``."""
        assert inp_shp and len(inp_shp) >= 2
        assert inp_shp[-1]
        out_shp = list(inp_shp)
        out_shp[-1] = self.units
        return tuple(out_shp)

    def compute_output_shape(self, input_shape):
        """Keras shape-inference hook; delegates to ``compute_out_shp``."""
        return self.compute_out_shp(input_shape)
if __name__ == "__main__":
    # Smoke test: wire a single-unit Nalu into a model and fit it on
    # uniform random data to check that it builds and trains end to end.
    # ``x`` and ``y`` stay bound at module scope, as in the original script.
    x = Input((10,))
    y = Nalu(1)(x)

    model = Model(inputs=x, outputs=y)
    model.compile(optimizer="adam", loss="mse")

    features = np.random.rand(128, 10)
    targets = np.random.rand(128, 1)
    model.fit(x=features, y=targets, batch_size=128, epochs=100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment