Skip to content

Instantly share code, notes, and snippets.

@marcj
Last active May 9, 2023 18:28
Show Gist options
  • Save marcj/e102b22b8d1818926d9186bf09c6c35b to your computer and use it in GitHub Desktop.
Save marcj/e102b22b8d1818926d9186bf09c6c35b to your computer and use it in GitHub Desktop.
Mini calculator LSTM
"""
This script tests how well an LSTM can solve mini calculator tasks such as 1+2, 3+4 or 5*5.
The challenge is that the input is a raw string that may contain a varying number of spaces
around the numbers and the operator, e.g.:
"1+2"
" 2 * 4"
" 3 + 3"
To keep the problem simple, only single-digit operands are allowed (so 1+2, but not 11+2),
and only two operators are supported: + and *.
"""
import random
import torch
import torch.nn as nn
import torch.optim as optim
from summary import summary
# Token vocabulary: the ten digits plus the three symbols ' ', '+' and '*'.
vocab_size = len("0123456789") + len(" +*")
# Width of each learned token embedding.
embed_size = 3
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Exclusive upper bound for the random operands, i.e. single digits 0-9.
max_num = 10
# Character -> token id. Digits map to their own value; the three symbols take
# the remaining ids 10-12, so the vocabulary has 13 entries (matches vocab_size).
_TOKEN_IDS = {**{str(d): d for d in range(10)}, " ": 10, "+": 11, "*": 12}


def tokenize(text, target_device=None):
    """Convert an expression string into a 1-D tensor of token ids.

    Vocabulary: '0'-'9' -> 0-9, ' ' -> 10, '+' -> 11, '*' -> 12.

    Args:
        text: expression such as " 2 * 4".
        target_device: device for the returned tensor; defaults to the
            module-level ``device``.

    Returns:
        A ``torch.long`` tensor with one id per character (empty for "").

    Raises:
        ValueError: if ``text`` contains a character outside the vocabulary.
            Such characters used to be dropped silently, which changed the
            meaning of the expression (e.g. "1 - 2" became "1  2").
    """
    try:
        ids = [_TOKEN_IDS[ch] for ch in text]
    except KeyError as err:
        raise ValueError(f"unsupported character {err.args[0]!r} in {text!r}") from err
    out_device = device if target_device is None else target_device
    # Explicit dtype: torch.tensor([]) would otherwise default to float32,
    # which nn.Embedding rejects as an index tensor.
    return torch.tensor(ids, dtype=torch.long, device=out_device)
samples = []
# Generate 500 (token tensor, target) pairs. Each expression is
# "<a><op><b>" with 0-2 random spaces at the start, before the operator,
# after the operator and at the end. Spaces after the operator are needed so
# that training covers formats like " 2 * 4" and "1 + 1"; previously they
# were never generated, although those are exactly the formats evaluated.
# All randomness comes from `random`, so a single random.seed() reproduces it.
for _ in range(500):
    a = random.randrange(max_num)
    b = random.randrange(max_num)
    spaces_start = " " * random.randint(0, 2)
    spaces_before_op = " " * random.randint(0, 2)
    spaces_after_op = " " * random.randint(0, 2)
    spaces_end = " " * random.randint(0, 2)
    op = random.choice(["+", "*"])
    result = a + b if op == "+" else a * b
    string = f"{spaces_start}{a}{spaces_before_op}{op}{spaces_after_op}{b}{spaces_end}"
    # Target is a float32 tensor of shape (1,), matching the network's output.
    samples.append([tokenize(string), torch.tensor([float(result)], device=device)])
class SimpleMathNet(nn.Module):
    """Embedding -> single-layer LSTM -> two-layer MLP regressor.

    The final LSTM hidden state summarises the token sequence, and the MLP
    maps it to a single scalar prediction.
    """

    def __init__(self, num_tokens=None, embedding_dim=None, hidden_size=10, fc_size=5):
        """Build the layers.

        Args:
            num_tokens: vocabulary size; defaults to the module-level ``vocab_size``.
            embedding_dim: embedding width; defaults to the module-level ``embed_size``.
            hidden_size: LSTM hidden-state width (default 10).
            fc_size: width of the hidden MLP layer (default 5).
        """
        super().__init__()
        num_tokens = vocab_size if num_tokens is None else num_tokens
        embedding_dim = embed_size if embedding_dim is None else embedding_dim
        self.word_embeddings = nn.Embedding(num_tokens, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size)
        self.fc1 = nn.Linear(hidden_size, fc_size)
        self.fc2 = nn.Linear(fc_size, 1)

    def forward(self, s):
        """Predict the value of tokenised expression(s).

        Args:
            s: token ids, either unbatched ``(seq_len,)`` or batched
               ``(seq_len, batch)`` (nn.LSTM default, batch_first=False).

        Returns:
            Shape ``(1,)`` for unbatched input, ``(batch, 1)`` for batched input.
        """
        x = self.word_embeddings(s)
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers, H) for unbatched input and (num_layers, batch, H)
        # for batched input. Taking the last layer works for both, whereas
        # view(-1) would merge the batch into the feature axis and break fc1.
        y = torch.relu(self.fc1(h_n[-1]))
        return self.fc2(y)
def test(string):
    """Print the module-level ``net``'s prediction for one expression string.

    Runs under ``torch.no_grad()`` because this is inference only. The output
    line shows the string, the scalar prediction, and the token tensor with
    the raw output tensor in parentheses.
    """
    tokens = tokenize(string)  # renamed from `input`, which shadowed the builtin
    with torch.no_grad():
        y = net(tokens)
    # Closing ')' added: the original format string left the parenthesis open.
    print(f"{string} = {y.item()} ({tokens} = {y})")
# Build the model and its Adam optimizer, print a layer summary using the
# first training sample as example input, then show one prediction of the
# still-untrained network as a smoke test.
net = SimpleMathNet()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
summary(net, [samples[0][0]])
test(' 1+2')
def train(epochs=100, model=None, opt=None, data=None):
    """Fit a model on (input, target) pairs with one optimizer step per sample.

    Args:
        epochs: number of passes over ``data``.
        model: module to train; defaults to the module-level ``net``.
        opt: optimizer over ``model``'s parameters; defaults to the
            module-level ``optimizer``.
        data: sequence of ``(x, y)`` pairs; defaults to the module-level
            ``samples``.

    Returns:
        A list with the mean MSE loss of each epoch (each value is also printed).

    Raises:
        ValueError: if ``data`` is empty, since the mean loss would divide by zero.
    """
    model = net if model is None else model
    opt = optimizer if opt is None else opt
    data = samples if data is None else data
    if not data:
        raise ValueError("train() needs at least one sample")
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        for x, y in data:
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            opt.step()
            running_loss += loss.item()
        mean_loss = running_loss / len(data)
        print(f"Epoch {epoch}: Loss: ", mean_loss)
        history.append(mean_loss)
    return history
# Train first, then evaluate, so the printed predictions reflect the fitted
# model. Before training, the network's outputs carry no information (the
# single untrained smoke test is done right after building the model).
train()
for expression in ('1 + 1', '1 + 2', '2 + 4', '2 * 4', ' 1 + 6 ', ' 2 * 4 '):
    test(expression)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment