-- Seinfeld2000 RNN generator server
-- Samples text from a trained char-rnn checkpoint and serves it over HTTP.
local async = require 'async'
local fiber = require 'async.fiber'
local torch = require 'torch'
local nn = require 'nn'
local nngraph = require 'nngraph'
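-- OneHot and misc appear to come from char-rnn's util/ directory
-- (github.com/karpathy/char-rnn), which this script assumes is on the path.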
local OneHot = require 'util.OneHot'
local misc = require 'util.misc'
-- Load the trained checkpoint and put the network into evaluation mode
-- (disables dropout for sampling).
local model = "model.t7"
local checkpoint = torch.load(model)
local protos = checkpoint.protos
protos.rnn:evaluate()

-- vocab maps character -> index; invert it so sampled indices can be
-- decoded back into characters.
local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end
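-- Hypothetical illustration (actual contents depend on the training corpus):
-- vocab might look like { ['K'] = 12, ['r'] = 30, ... }, in which case
-- ivocab[12] == 'K' and ivocab[30] == 'r'.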
-- Build the initial recurrent state: a zero cell (c) and hidden (h) tensor
-- for every layer.
local current_state = {}
for L = 1,checkpoint.opt.num_layers do
  local h_init = torch.zeros(1, checkpoint.opt.rnn_size):double()
  table.insert(current_state, h_init:clone()) -- cell state c
  table.insert(current_state, h_init:clone()) -- hidden state h
end
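-- For example, with num_layers = 2 the table is laid out as
-- { c1, h1, c2, h2 }, each a 1 x rnn_size DoubleTensor.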
local respond = function(req, res)
  -- req.body is assumed to be parsed into a table by the async HTTP server.
  local author = req.body["author"] or "Uncle Leo" -- read but currently unused
  local prompt = req.body["prompt"] or "Kram berst into the apartment"
  local resp = " " .. prompt .. "\n\n"

  local prev_char = ''
  local prediction -- log-probabilities over the vocabulary, updated each step
  local state_size = #current_state
  local temperature = 1

  -- Prime the network on the prompt one character at a time. Note that
  -- current_state is shared across requests, so each request continues
  -- from wherever the previous one left off.
  local seed_text = prompt
  for c in seed_text:gmatch'.' do
    if vocab[c] then -- skip characters the model's vocabulary doesn't contain
      prev_char = torch.Tensor{vocab[c]}
      local lst = protos.rnn:forward{prev_char, unpack(current_state)}
      current_state = {}
      for i=1,state_size do table.insert(current_state, lst[i]) end
      prediction = lst[#lst]
    end
  end
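  -- Per char-rnn's graph layout, the forward pass returns
  -- { c1, h1, ..., cN, hN, log_probs }: the updated per-layer states followed
  -- by the log-softmax output, so the first state_size entries become the new
  -- current_state and the last entry is the prediction.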
  -- Generate 1500 characters, sampling from the softmax at each step and
  -- feeding the sample back in as the next input.
  for i=1, 1500 do
    prediction:div(temperature) -- scale log-probabilities by temperature
    local probs = torch.exp(prediction):squeeze()
    probs:div(torch.sum(probs)) -- renormalize so probs sum to one
    prev_char = torch.multinomial(probs:float(), 1):resize(1):float()

    -- forward the rnn for the next character
    local lst = protos.rnn:forward{prev_char, unpack(current_state)}
    current_state = {}
    for j=1,state_size do table.insert(current_state, lst[j]) end
    prediction = lst[#lst] -- last element holds the log probabilities

    resp = resp .. ivocab[prev_char[1]]
  end
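  -- Worked example of the temperature scaling above (illustrative numbers):
  -- with log-probs {-0.5, -1.5} and temperature 1, exp gives {0.607, 0.223}
  -- and normalizing gives {0.731, 0.269}. With temperature 0.5 the scaled
  -- log-probs become {-1.0, -3.0}, which normalize to {0.881, 0.119}: lower
  -- temperature sharpens the distribution, higher flattens it toward uniform.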
  -- The generated body is plain text, so serve it as such.
  res(resp, {['Content-Type']='text/plain'})
end
async.http.listen('http://0.0.0.0:8080/', function(req,res)
  fiber(function()
    fiber.wait(respond, { req, res }, 8)
  end)
end)
print('server listening on port 8080')
async.go()
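-- Example request (assumes the async HTTP server decodes a form-encoded
-- body into the req.body table, as the handler above expects):
--
--   curl --data-urlencode 'prompt=Kram berst into the apartment' http://localhost:8080/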