Last active
June 21, 2017 16:08
-
-
Save ihsgnef/ac01fbd8eaba01145c8a048a0e8a0678 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import numpy as np | |
from collections import defaultdict | |
import torch | |
from torch.autograd import Variable | |
from vocab import PAD_ID | |
class Iterator(object):
    """Bucketed mini-batch iterator over ``(question_ids, answer_id)`` pairs.

    Each question is right-padded with ``PAD_ID`` up to a multiple of
    ``bucket_size`` and grouped with other questions of the same padded
    length, so every batch holds equal-length questions that can be stacked
    into one tensor.
    """

    def __init__(self, dataset, batch_size, bucket_size=4,
                 shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.shuffle = shuffle  # only consulted by finalize()
        self.epoch = 0
        self.iteration = 0
        self.batch_index = 0
        self.is_end_epoch = False
        self.create_batches()

    def create_batches(self):
        """(Re)build ``self.batches`` from ``self.dataset``.

        Empty questions are skipped. Padding is applied to a copy of each
        question, so the caller's dataset is not modified.
        """
        self.batches = []
        buckets = defaultdict(list)
        for question, answer in self.dataset:
            if len(question) == 0:
                continue
            padded = list(question)
            # Number of PAD_IDs needed to reach the next multiple of bucket_size.
            padding = -len(padded) % self.bucket_size
            if padding:
                padded.extend([PAD_ID] * padding)
            buckets[len(padded)].append((padded, answer))
        for samples in buckets.values():
            for i in range(0, len(samples), self.batch_size):
                questions, answers = zip(*samples[i:i + self.batch_size])
                self.batches.append((questions, answers))

    @property
    def size(self):
        """Number of batches in one epoch."""
        return len(self.batches)

    def __len__(self):
        # Alias for ``size`` so ``len(iterator)`` works.
        return self.size

    def finalize(self, reset=False):
        """Shuffle the batch order (if ``shuffle``) and optionally reset all counters."""
        if self.shuffle:
            random.shuffle(self.batches)
        if reset:
            self.epoch = 0
            self.iteration = 0
            self.batch_index = 0
            self.is_end_epoch = False

    def next_batch(self, device=-1, train=True):
        """Return the next ``(questions, answers)`` batch, wrapping around at epoch end.

        ``questions`` is the transposed batch, shape (length, batch_size);
        ``answers`` has shape (batch_size,). A ``device`` other than -1 moves
        both tensors to that CUDA device. ``train=False`` builds the Variables
        with ``volatile=True`` (a legacy PyTorch inference flag).

        Raises:
            IndexError: if the iterator holds no batches. Counters are left
                untouched in that case.
        """
        if not self.batches:
            raise IndexError('Iterator has no batches (empty dataset?)')
        self.iteration += 1
        if self.batch_index == 0:
            self.epoch += 1
        self.is_end_epoch = (self.batch_index == self.size - 1)
        questions, answers = self.batches[self.batch_index]
        questions = torch.LongTensor(questions).t()  # length, batch_size
        answers = torch.LongTensor(answers)
        self.batch_index = (self.batch_index + 1) % self.size
        if device != -1:
            questions = questions.cuda(device)
            answers = answers.cuda(device)
        else:
            # .t() produces a transposed view; make the CPU tensors contiguous.
            questions = questions.contiguous()
            answers = answers.contiguous()
        questions = Variable(questions, volatile=not train)
        answers = Variable(answers, volatile=not train)
        return questions, answers

    @property
    def epoch_detail(self):
        """Return ``(batches_per_epoch, iterations_so_far, iterations / batches_per_epoch)``."""
        return self.size, self.iteration, self.iteration / self.size
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from collections import defaultdict | |
from vocab import Vocab | |
from question_database import QuestionDatabase | |
from iterator import Iterator | |
_NON_WORD = re.compile(r'\W+')


def _tokenize(question):
    """Return the word tokens of ``question``'s text, sentences joined in key order."""
    # Sort by sentence key explicitly so token order never depends on dict
    # iteration order (not guaranteed to be insertion order on older Pythons).
    text = ' '.join(question.text[key] for key in sorted(question.text)).strip()
    # do more careful preprocessing here
    return _NON_WORD.sub(' ', text).split()


def preprocess(all_questions):
    """Build vocabularies from the 'guesstrain' fold and encode the train/dev splits.

    Args:
        all_questions: mapping of question id -> Question, e.g. the result of
            ``QuestionDatabase().all_questions()``.

    Returns:
        ``(question_vocab, answer_vocab, dataset)``. ``dataset`` is a
        ``defaultdict(list)`` whose 'train' and 'dev' entries hold
        ``(word_ids, answer_id)`` pairs. Only 'guesstrain' questions feed the
        vocabularies, so words or answers seen only in 'guessdev' are encoded
        by ``Vocab.word2id``'s unknown-word fallback.
    """
    question_vocab = Vocab()
    answer_vocab = Vocab()
    for question in all_questions.values():
        if question.fold != 'guesstrain':
            continue
        for word in _tokenize(question):
            question_vocab.add(word)
        answer_vocab.add(question.page)
    question_vocab.finish()
    answer_vocab.finish()

    # Database fold name -> key used in the returned dataset.
    split_names = {'guesstrain': 'train', 'guessdev': 'dev'}
    dataset = defaultdict(list)
    for question in all_questions.values():
        split = split_names.get(question.fold)
        if split is None:
            continue
        words = question_vocab.sent2ids(_tokenize(question))
        answer = answer_vocab.word2id(question.page)
        dataset[split].append((words, answer))
    return question_vocab, answer_vocab, dataset
def main():
    """Load questions, build vocabularies and per-fold iterators, then walk one epoch of training batches."""
    batch_size = 64
    all_questions = QuestionDatabase().all_questions()  # dict
    question_vocab, answer_vocab, dataset = preprocess(all_questions)
    print(len(dataset['train']), len(dataset['dev']))
    print(len(question_vocab), len(answer_vocab))
    print(dataset['train'][0])

    iterators = {fold: Iterator(examples, batch_size)
                 for fold, examples in dataset.items()}

    question_batch, answer_batch = iterators['train'].next_batch()
    print(question_batch.size(), answer_batch.size())

    # Use the ``size`` property (batches per epoch); ``len()`` would need a
    # ``__len__`` method on Iterator.
    for _ in range(iterators['train'].size):
        question_batch, answer_batch = iterators['train'].next_batch()
        # training


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3 | |
class Question:
    """One question record: metadata columns plus its sentence texts."""

    def __init__(self, qnum, answer, category, naqt, protobowl,
                 tournaments, page, fold):
        # Metadata copied verbatim from the constructor arguments.
        self.qnum, self.answer, self.category = qnum, answer, category
        self.naqt, self.protobowl = naqt, protobowl
        self.tournaments = tournaments
        self.page, self.fold = page, fold
        # Sentence key -> sentence text, populated through add_text().
        self.text = {}

    def add_text(self, sent, text):
        """Store ``text`` under key ``sent``, replacing any previous entry for that key."""
        self.text.update({sent: text})
class QuestionDatabase:
    """Loads Question objects from an SQLite database file."""

    def __init__(self, location='2017_05_25.db'):
        self._conn = sqlite3.connect(location)

    def close(self):
        """Close the underlying SQLite connection."""
        self._conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def query(self, command, arguments):
        """Load questions selected by an SQL fragment, together with their text.

        Args:
            command: SQL appended after the fixed column list, starting with
                the FROM clause (e.g. ``'FROM questions where id = ?'``). It is
                concatenated into the statement, so it must come from trusted
                code; pass values through ``arguments`` instead.
            arguments: parameters bound to the ``?`` placeholders in ``command``.

        Returns:
            dict mapping question id -> Question. Each question's rows from the
            ``text`` table are added via ``add_text`` in ascending ``sent`` order.
        """
        questions = {}
        c = self._conn.cursor()
        try:
            command = 'select id, page, category, answer, ' + \
                      'tournament, naqt, protobowl, fold ' + command
            c.execute(command, arguments)
            for qnum, page, _, answer, tournaments, naqt, protobowl, fold in c:
                # NOTE(review): the category column is selected but None is
                # stored as the category; confirm this is intentional.
                questions[qnum] = Question(qnum, answer, None, naqt, protobowl,
                                           tournaments, page, fold)
            text_query = 'select sent, raw from text where question=? order by sent asc'
            for qnum in questions:
                c.execute(text_query, (qnum, ))
                for sentence, text in c:
                    questions[qnum].add_text(sentence, text)
        finally:
            c.close()
        return questions

    def all_questions(self):
        """Return every question whose page is non-empty."""
        # Single-quoted '' is the SQL string literal; "" is an identifier that
        # only falls back to a string when SQLite's DQS quirk is enabled.
        return self.query("FROM questions where page != ''", ())


if __name__ == '__main__':
    with QuestionDatabase() as db:
        qs = db.all_questions()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
PAD_ID = 0
UNK_ID = 1
PAD = "<PAD>"
UNK = "<UNK>"


class Vocab(object):
    """Frequency-ranked word <-> id mapping with reserved PAD and UNK entries.

    Call ``add`` once per token occurrence, then ``finish`` to freeze the
    mapping. ``__len__`` and the lookup methods only work after ``finish``.
    """

    def __init__(self):
        self.word_count = defaultdict(int)  # word -> number of add() calls

    def add(self, word):
        """Count one occurrence of ``word``."""
        self.word_count[word] += 1

    def finish(self, size=50000):
        """Freeze the vocabulary to at most ``size`` of the most frequent words.

        Ids ``PAD_ID`` and ``UNK_ID`` hold PAD and UNK. The kept words follow in
        descending count; ties keep their ``word_count`` iteration order
        because ``sort`` is stable. Counted tokens equal to PAD or UNK are not
        added a second time, so ``i2w`` has no duplicates and ``w2i[UNK]``
        stays ``UNK_ID``.
        """
        reserved = (PAD, UNK)
        counted = [(word, count) for word, count in self.word_count.items()
                   if word not in reserved]
        counted.sort(key=lambda item: item[1], reverse=True)
        self.i2w = [PAD, UNK]
        self.i2w += [word for word, _ in counted[:size]]
        self.w2i = {word: i for i, word in enumerate(self.i2w)}

    def __len__(self):
        return len(self.i2w)

    def word2id(self, word):
        """Return the id of ``word``, or ``UNK_ID`` if it is not in the vocabulary."""
        return self.w2i.get(word, UNK_ID)

    def id2word(self, i):
        """Return the word for id ``i`` (raises IndexError if out of range)."""
        return self.i2w[i]

    def sent2ids(self, sentence):
        """Map a sequence of words to a list of ids."""
        return [self.word2id(w) for w in sentence]

    def ids2sent(self, ids):
        """Map ids back to words, dropping ``PAD_ID`` entries."""
        return [self.id2word(x) for x in ids if x != PAD_ID]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment