@seanie12
Created February 22, 2019 01:46
using dataset
import json
import logging
import os
import shutil
import torch
import torch.utils.data as data
import config
############################ CONSTANTS ########################################
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]'
START_TOKEN = '[START]'
STOP_TOKEN = '[STOP]'
PAD_ID = 0
UNK_ID = 1
START_ID = 2
STOP_ID = 3

# custom dataset for text summarization, extended from https://github.com/yunjey/seq2seq-dataloader
class SummaryDataset(data.Dataset):
    def __init__(self, datas, max_length=400, pointer_generator=False, debug=False):
        """
        :param datas: list of Data instances
        :param max_length: maximum length of an article
        :param pointer_generator: whether to use pointer-generator (extended-vocab) targets
        :param debug: debugging mode True / False
        """
        self.articles_ids = [d.article_ids[:max_length] for d in datas]
        self.article_ids_extended = [d.article_ids_extended_vocab[:max_length] for d in datas]
        self.abstracts_ids = [d.abstract_ids for d in datas]
        self.abstract_ids_extended = [d.abstract_ids_extended_vocab for d in datas]
        self.articles_oovs = [d.articles_oov_lst for d in datas]
        self.num_total_seqs = len(datas)
        self.pointer_generator = pointer_generator
        assert len(self.articles_ids) == len(self.abstracts_ids), \
            "the number of articles and abstracts should be the same"
        if debug:
            # keep all parallel lists in sync when truncating for debug runs
            self.articles_ids = self.articles_ids[:100]
            self.article_ids_extended = self.article_ids_extended[:100]
            self.abstracts_ids = self.abstracts_ids[:100]
            self.abstract_ids_extended = self.abstract_ids_extended[:100]
            self.articles_oovs = self.articles_oovs[:100]
            self.num_total_seqs = len(self.articles_ids)

    def __getitem__(self, index):
        article = torch.Tensor(self.articles_ids[index])
        article_extended = torch.Tensor(self.article_ids_extended[index])
        abstract_input = torch.Tensor([START_ID] + self.abstracts_ids[index])
        if self.pointer_generator:
            abstract_target = torch.Tensor(self.abstract_ids_extended[index] + [STOP_ID])
        else:
            abstract_target = torch.Tensor(self.abstracts_ids[index] + [STOP_ID])
        return article, article_extended, abstract_input, abstract_target

    def __len__(self):
        return self.num_total_seqs

def collate_fn(data):
    """
    Create mini-batch tensors from a list of 4-tuples.
    We need a custom collate_fn rather than the default one because merging
    variable-length sequences (including padding) is not supported by default.
    Sequences are padded to the maximum length within the mini-batch (dynamic padding).
    :param data: list of tuples (art_seq, art_extended, abs_input_seq, abs_trg_seq),
        art_seq: torch tensor of shape (?); variable length
        art_extended: torch tensor of shape (?); variable length
        abs_input_seq: torch tensor of shape (?); variable length
        abs_trg_seq: torch tensor of shape (?); variable length
    :return:
        art_seqs: torch tensor of shape (batch_size, padded_length)
        art_extended_seq: tuple of torch tensors; article ids in the extended vocab (left unpadded)
        art_lengths: list of length (batch_size); valid length of each padded source sequence
        abs_input_seqs: torch tensor of shape (batch_size, padded_length)
        abs_trg_seqs: torch tensor of shape (batch_size, padded_length)
        abs_lengths: list of length (batch_size); valid length of each padded target sequence
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # sort by source (article) length in descending order
    data.sort(key=lambda x: len(x[0]), reverse=True)
    art_seqs, art_extended_seq, abs_input_seqs, abs_trg_seqs = zip(*data)
    art_seqs, art_lengths = merge(art_seqs)
    abs_input_seqs, abs_lengths = merge(abs_input_seqs)
    abs_trg_seqs, _ = merge(abs_trg_seqs)
    return art_seqs, art_extended_seq, art_lengths, abs_input_seqs, abs_trg_seqs, abs_lengths
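
# Illustrative helper (not part of the original gist) showing what collate_fn
# produces for a toy batch: the article of length 1 is zero-padded (PAD_ID = 0)
# up to the batch maximum of 3, and the batch is sorted by article length.
def _example_collate():
    batch = [
        (torch.Tensor([7]), torch.Tensor([7]),
         torch.Tensor([START_ID, 7, 8]), torch.Tensor([7, 8, STOP_ID])),
        (torch.Tensor([4, 5, 6]), torch.Tensor([4, 5, 6]),
         torch.Tensor([START_ID, 4]), torch.Tensor([4, STOP_ID])),
    ]
    art_seqs, art_extended, art_lengths, abs_in, abs_trg, abs_lengths = collate_fn(batch)
    # art_seqs    -> tensor([[4, 5, 6], [7, 0, 0]])
    # art_lengths -> [3, 1]
    return art_seqs, art_lengths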

def get_loader(datas, batch_size=32, debug=False, shuffle=True):
    dataset = SummaryDataset(datas, debug=debug,
                             max_length=config.max_length,
                             pointer_generator=config.pointer_generator)
    dataloader = data.DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 num_workers=8,
                                 shuffle=shuffle,
                                 collate_fn=collate_fn)
    return dataloader

def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
    # exponential moving average of the loss; the first call just returns the loss
    if running_avg_loss == 0:
        return loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
        return running_avg_loss
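
# A quick worked example (added note, not in the original gist): with the
# default decay of 0.99 the running average moves only 1% of the way toward
# each new batch loss, which smooths noisy per-batch values, e.g.
#   cal_running_avg_loss(5.0, 0.0) -> 5.0   (first call just initializes)
#   cal_running_avg_loss(3.0, 5.0) -> 4.98  (= 5.0 * 0.99 + 3.0 * 0.01)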

def get_logger(log_dir):
    # if the directory exists, remove it and make a new one
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.mkdir(log_dir)
    logger = logging.getLogger("seq2seq_logger")
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(log_dir, "result.log"))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

def outputids2words(id_list, idx2word, article_oovs=None):
    words = []
    for idx in id_list:
        try:
            word = idx2word[idx]
        except KeyError:
            # the id is outside the fixed vocab, so look it up in the article OOV list
            if article_oovs is None:
                raise ValueError("there is no such a word in the predefined vocab")
            article_oov_idx = idx - len(idx2word)
            try:
                word = article_oovs[article_oov_idx]
            except IndexError:
                raise ValueError("there's no such a word in the extended vocab")
        words.append(word)
    return words

def extract_vocab(vocab_file, max_size=None):
    """Vocab extractor.
    This function opens and reads a vocab file, then uses its contents to
    build two dictionaries:
        * word2id: given a word, return the corresponding ID
        * id2word: given an ID, return the corresponding word
    The vocab file is expected to contain "<word> <frequency>" on each line,
    sorted by descending frequency.
    Args:
        vocab_file (str): Location of the vocab file to extract.
        max_size (int, optional): Maximum allowed size for the vocabulary.
            Defaults to None (no maximum size).
    Returns:
        dict: Dictionary word => ID
        dict: Dictionary ID => word
    """
    word2id = dict()
    id2word = dict()
    cnt = 0
    # Assign IDs 0, 1, 2, 3 to the special tokens
    for t in [PAD_TOKEN, UNKNOWN_TOKEN, START_TOKEN, STOP_TOKEN]:
        word2id[t] = cnt
        id2word[cnt] = t
        cnt += 1
    # Read the vocab file line by line until reaching its end or max_size
    with open(vocab_file, 'r') as f:
        for i, line in enumerate(f):
            pieces = line.split()
            if len(pieces) != 2:
                print("<Warning> Incorrect line in vocab : {} at {}".format(line, i + 1))
                continue
            word, frequency = pieces
            assert word not in [PAD_TOKEN, UNKNOWN_TOKEN, START_TOKEN,
                                STOP_TOKEN, SENTENCE_START, SENTENCE_END], \
                "Error : {} found in vocab file".format(word)
            assert word not in word2id, "Error : {} duplicate found in vocab " \
                                        "file".format(word)
            word2id[word] = cnt
            id2word[cnt] = word
            cnt += 1
            if max_size is not None and cnt >= max_size:
                break
    return word2id, id2word
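
# Usage sketch (illustrative; the file name and max_size are hypothetical).
# extract_vocab expects one "<word> <frequency>" pair per line, most frequent
# word first, for example:
#   the 1045
#   , 999
#   said 786
def _example_extract_vocab(vocab_file="vocab.txt"):
    word2id, id2word = extract_vocab(vocab_file, max_size=50000)
    # the four special tokens always occupy ids 0-3
    assert word2id[PAD_TOKEN] == PAD_ID and id2word[UNK_ID] == UNKNOWN_TOKEN
    return word2id, id2word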

def process_data(data_list, word2id):
    """Process a list of Data to create sequences of IDs.
    This function processes a whole list of data. For each Data instance, the
    text of the article and the abstract is converted into a sequence of IDs
    given a vocab. The Data instances are updated in place with the new ID sequences.
    Args:
        data_list (list of Data): List of data to process.
        word2id (dict): Dictionary (from vocab) used to convert the text data.
    """
    for data in data_list:
        # for article
        words = data.article.split()
        article_ids, article_ids_extended_vocab, oov_lst = article2ids(words, word2id)
        data.article_ids = article_ids
        data.article_ids_extended_vocab = article_ids_extended_vocab
        data.articles_oov_lst = oov_lst
        # for abstract
        abstract = data.abstract.replace(SENTENCE_START, '').replace(SENTENCE_END, '')
        abstract_words = abstract.split()
        data.clean_abstract = ' '.join(abstract_words)
        abstract_ids, abstract_ids_extended_vocab = abstract2ids(abstract_words,
                                                                 word2id,
                                                                 oov_lst)
        data.abstract_ids = abstract_ids
        data.abstract_ids_extended_vocab = abstract_ids_extended_vocab

def article2ids(words, word2idx):
    """
    :param words: list of words
    :param word2idx: dictionary which maps a word to its index
    :return:
        ids: list of indices; an OOV word is represented by the UNK id
        ids_extended_vocab: list of indices; an OOV word is represented by
            vocab_size + its index in oov_lst
        oov_lst: list of OOV words from the article
    """
    ids = []
    oov_lst = []
    ids_extended_vocab = []
    for word in words:
        try:
            idx = word2idx[word]
            ids.append(idx)
            ids_extended_vocab.append(idx)
        except KeyError:
            # out-of-vocabulary word: record it and point into the extended vocab
            if word not in oov_lst:
                oov_lst.append(word)
            idx_extended = len(word2idx) + oov_lst.index(word)
            ids.append(word2idx[UNKNOWN_TOKEN])
            ids_extended_vocab.append(idx_extended)
    return ids, ids_extended_vocab, oov_lst

def abstract2ids(words, word2idx, article_oovs):
    ids = []
    ids_extended_vocab = []
    unk_id = word2idx[UNKNOWN_TOKEN]
    vocab_size = len(word2idx)
    for word in words:
        if word not in word2idx:
            # OOV word: always UNK in the fixed vocab
            ids.append(unk_id)
            if word in article_oovs:
                # the word can be copied from the article, so give it an extended id
                idx = vocab_size + article_oovs.index(word)
                ids_extended_vocab.append(idx)
            else:
                # OOV word that does not appear in the article either
                ids_extended_vocab.append(unk_id)
        else:
            ids.append(word2idx[word])
            ids_extended_vocab.append(word2idx[word])
    return ids, ids_extended_vocab
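
# Worked example (added for clarity; the sentences and numbers are made up):
# with a vocab of size 50004 that does not contain "pickle", article2ids maps
# "pickle" to UNK_ID in `ids` but to 50004 (= vocab_size + its position in
# oov_lst) in `ids_extended_vocab`. abstract2ids then reuses that same extended
# id, so a pointer-generator can learn to copy the word, and outputids2words
# can map it back to text using the article OOV list.
def _example_extended_vocab(word2id, id2word):
    art_ids, art_ids_ext, oov_lst = article2ids("he ate a pickle".split(), word2id)
    abs_ids, abs_ids_ext = abstract2ids("a pickle".split(), word2id, oov_lst)
    # decoding the extended ids with the article OOV list restores "pickle"
    return outputids2words(abs_ids_ext, id2word, article_oovs=oov_lst)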

def seq2seq(seq, translator):
    """Translate a sequence into another sequence using a translator.
    The translator dictionary can map a sequence of words to a sequence of IDs
    or the opposite; unknown elements are mapped to the UNK entry and a STOP
    token is appended at the end.
    Args:
        seq (list): Sequence to translate.
        translator (dict): Dictionary used to translate the given sequence
            into another sequence.
    Returns:
        list: Translated sequence.
    """
    translated_seq = []
    for elem in seq:
        try:
            translated_seq.append(translator[elem])
        except KeyError:
            translated_seq.append(translator[UNKNOWN_TOKEN])
    translated_seq.append(translator[STOP_TOKEN])
    return translated_seq

class Data(object):
    """Class representing one data sample.
    Attributes:
        article (str): Article to summarize.
        abstract (str, optional): Corresponding abstract. Defaults to `None`.
    """
    def __init__(self, article, abstract=None):
        self.article = article
        self.abstract = abstract
        self.clean_abstract = None
        self.article_ids = None
        self.abstract_ids = None
        self.articles_oov_lst = None
        self.article_ids_extended_vocab = None
        self.abstract_ids_extended_vocab = None

class CnnDailymailProcessor(object):
    """Class for extracting data.
    Extracts data from the CNN / Daily Mail dataset.
    """
    def get_train_data(self, data_dir):
        """Gets a list of `Data` for the train subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "train.jsonl"))

    def get_val_data(self, data_dir):
        """Gets a list of `Data` for the validation subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "val.jsonl"))

    def get_test_data(self, data_dir):
        """Gets a list of `Data` for the test subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "test.jsonl"))

    def _read_jsonl_to_data(self, input_file):
        """Reads a JSONL file and transforms it into a list of Data.
        This method reads the content of a JSONL file and outputs a list of
        corresponding `Data`.
        Args:
            input_file (str): Location of the file to read.
        Returns:
            list of `Data`: List of Data extracted from the JSONL file.
        """
        data = []
        with open(input_file, "r") as f:
            for line in f:
                raw_data = json.loads(line)
                data.append(Data(article=raw_data['article'],
                                 abstract=raw_data['abstract']))
        return data
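
# End-to-end sketch (illustrative; the directory and vocab paths are
# hypothetical, and `config` is assumed to provide max_length and
# pointer_generator as used in get_loader above).
def _example_pipeline(data_dir="data/cnn_dailymail", vocab_file="data/vocab"):
    processor = CnnDailymailProcessor()
    train_data = processor.get_train_data(data_dir)      # list of Data
    word2id, id2word = extract_vocab(vocab_file, max_size=50000)
    process_data(train_data, word2id)                    # fills the *_ids fields in place
    train_loader = get_loader(train_data, batch_size=32, shuffle=True)
    return train_loader, word2id, id2word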