Created
February 8, 2016 22:20
-
-
Save ceberly/d21580bc53931c28b782 to your computer and use it in GitHub Desktop.
Seinfeld2000 RNN generator server
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
local async = require 'async' | |
local fiber = require 'async.fiber' | |
local torch = require 'torch' | |
local nn = require 'nn' | |
local nngraph = require 'nngraph' | |
local OneHot = require 'util.OneHot' | |
local misc = require 'util.misc' | |
--- Load the trained char-rnn checkpoint once at startup and build the
-- objects every request needs: the network, the char<->index maps and
-- the all-zero initial hidden state.
local model = "model.t7"
local checkpoint = torch.load(model)
local protos = checkpoint.protos
-- put the network in inference mode (nn modules such as Dropout behave
-- differently when not training)
protos.rnn:evaluate()

-- vocab maps character -> index; ivocab is the inverse (index -> character)
local vocab = checkpoint.vocab
local ivocab = {}
for c, i in pairs(vocab) do ivocab[i] = c end

-- Initial hidden state: one zero tensor per state slot the network expects.
-- LSTM layers carry two tensors (cell c and hidden h); GRU and vanilla RNN
-- layers carry only h, so pushing two per layer for them gives the network
-- the wrong number of inputs. When the checkpoint does not record its model
-- type we keep the previous behaviour and assume LSTM.
local current_state = {}
local model_type = checkpoint.opt.model or 'lstm'
for _ = 1, checkpoint.opt.num_layers do
  local h_init = torch.zeros(1, checkpoint.opt.rnn_size):double()
  current_state[#current_state + 1] = h_init:clone()
  if model_type == 'lstm' then
    current_state[#current_state + 1] = h_init:clone()
  end
end
--- Replacement table for characters that are special in HTML.
local HTML_ESCAPES = {
  ['&'] = '&amp;', ['<'] = '&lt;', ['>'] = '&gt;',
  ['"'] = '&quot;', ["'"] = '&#39;',
}

--- Escape HTML-significant characters so user text is shown literally.
-- @tparam string s
-- @treturn string escaped copy of s (exactly one return value)
local function html_escape(s)
  -- parentheses drop gsub's second return value (the substitution count)
  return (s:gsub('[&<>"\']', HTML_ESCAPES))
end

--- Number of characters sampled from the model for every request.
local SAMPLE_LENGTH = 1500

--- Feed one character index through the RNN and advance the hidden state.
-- @param state table of hidden-state tensors; entries are replaced in place
-- @param char_tensor 1-element tensor holding a vocab index
-- @return the network's last output (the scores the next character is
--   sampled from)
local function step(state, char_tensor)
  local lst = protos.rnn:forward{char_tensor, unpack(state)}
  for i = 1, #state do state[i] = lst[i] end
  return lst[#lst]
end

--- HTTP handler: prime the RNN with the request's prompt, sample
-- SAMPLE_LENGTH characters and send prompt + sample back.
-- NOTE(review): assumes the async HTTP layer exposes request fields as a
-- table in req.body; a missing or non-table body falls back to defaults.
-- @param req incoming request
-- @param res response callback, called as res(body, headers)
local respond = function(req, res)
  local body = type(req.body) == 'table' and req.body or {}
  local prompt = tostring(body["prompt"] or "Kram berst into the apartment")
  local temperature = 1

  -- Work on a private copy of the zero initial state so one request's
  -- text never seeds the next, and the shared initial tensors are never
  -- overwritten.
  local state = {}
  for i, t in ipairs(current_state) do state[i] = t:clone() end

  -- All-zero log-scores = uniform distribution; used for the first sample
  -- when the prompt contains no character the model knows (e.g. "").
  local prediction = torch.zeros(1, #ivocab)
  for c in prompt:gmatch'.' do
    local idx = vocab[c]
    -- characters outside the training vocabulary cannot be fed in; skip
    if idx then prediction = step(state, torch.Tensor{idx}) end
  end

  local out = { " " .. prompt .. "\n\n" }
  for _ = 1, SAMPLE_LENGTH do
    -- torch.div allocates a new tensor instead of mutating the network's
    -- output buffer in place
    local probs = torch.exp(torch.div(prediction, temperature)):squeeze()
    probs:div(torch.sum(probs)) -- renormalize so probs sum to one
    local next_char = torch.multinomial(probs:float(), 1):resize(1):float()
    prediction = step(state, next_char)
    out[#out + 1] = ivocab[next_char[1]]
  end

  -- The prompt is user-controlled and the response is served as HTML,
  -- so escape it (and the model output) to prevent script injection.
  res(html_escape(table.concat(out)), {['Content-Type']='text/html'})
end
--- TCP port the HTTP server binds to. It is used both for listening and for
-- the startup message, so the two cannot disagree (the message previously
-- said 8082 while the server listened on 8080).
local PORT = 8080

async.http.listen('http://0.0.0.0:' .. PORT .. '/', function(req, res)
  -- handle each request inside its own fiber
  -- NOTE(review): confirm the meaning of fiber.wait's third argument (8)
  -- against the async library's documentation.
  fiber(function()
    fiber.wait(respond, { req, res }, 8)
  end)
end)
print('server listening to port ' .. PORT)
-- start the event loop (does not return while the server runs)
async.go()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment