using dataset
import json
import logging
import os
import shutil

import torch
import torch.utils.data as data

import config

############################ CONSTANTS ########################################

SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]'
START_TOKEN = '[START]'
STOP_TOKEN = '[STOP]'

PAD_ID = 0
UNK_ID = 1
START_ID = 2
STOP_ID = 3

# custom dataset for text summarization, extended from https://github.com/yunjey/seq2seq-dataloader
class SummaryDataset(data.Dataset):
    def __init__(self, datas, max_length=400, pointer_generator=False, debug=False):
        """
        :param datas: list of Data instances
        :param max_length: maximum length of an article (longer articles are truncated)
        :param pointer_generator: whether to use the pointer-generator (extended-vocab) targets
        :param debug: debugging mode True / False (keeps only the first 100 examples)
        """
        self.articles_ids = [d.article_ids[:max_length] for d in datas]
        self.article_ids_extended = [d.article_ids_extended_vocab[:max_length] for d in datas]
        self.abstracts_ids = [d.abstract_ids for d in datas]
        self.abstract_ids_extended = [d.abstract_ids_extended_vocab for d in datas]
        self.articles_oovs = [d.articles_oov_lst for d in datas]
        self.num_total_seqs = len(datas)
        self.pointer_generator = pointer_generator
        assert len(self.articles_ids) == len(self.abstracts_ids), \
            "the number of articles and abstracts should be the same"
        if debug:
            # keep all parallel lists in sync when truncating for debugging
            self.articles_ids = self.articles_ids[:100]
            self.article_ids_extended = self.article_ids_extended[:100]
            self.abstracts_ids = self.abstracts_ids[:100]
            self.abstract_ids_extended = self.abstract_ids_extended[:100]
            self.articles_oovs = self.articles_oovs[:100]
            self.num_total_seqs = len(self.articles_ids)
    def __getitem__(self, index):
        article = torch.Tensor(self.articles_ids[index])
        article_extended = torch.Tensor(self.article_ids_extended[index])
        # decoder input starts with START_ID; the target ends with STOP_ID
        abstract_input = torch.Tensor([START_ID] + self.abstracts_ids[index])
        if self.pointer_generator:
            # targets use the extended vocabulary so the model can copy OOV words
            abstract_target = torch.Tensor(self.abstract_ids_extended[index] + [STOP_ID])
        else:
            abstract_target = torch.Tensor(self.abstracts_ids[index] + [STOP_ID])
        return article, article_extended, abstract_input, abstract_target

    def __len__(self):
        return self.num_total_seqs

def collate_fn(batch):
    """
    create mini-batch tensors from a list of tuples (src_seq, trg_seq)
    we should build a custom collate_fn rather than using the default one,
    because merging variable-length sequences (including padding) is not supported by default
    sequences are padded to the maximum length of the mini-batch (dynamic padding)
    :param batch: list of tuples (art_seq, art_extended, abs_input_seq, abs_trg_seq),
        art_seq: torch tensor of shape (?); variable length
        art_extended: torch tensor of shape (?); variable length
        abs_input_seq: torch tensor of shape (?); variable length
        abs_trg_seq: torch tensor of shape (?); variable length
    :return:
        art_seqs: torch tensor of shape (batch_size, padded_length)
        art_extended_seqs: torch tensor of shape (batch_size, padded_length)
        art_lengths: list of length (batch_size); valid length for each padded source sequence
        abs_input_seqs: torch tensor of shape (batch_size, padded_length)
        abs_trg_seqs: torch tensor of shape (batch_size, padded_length)
        abs_lengths: list of length (batch_size); valid length for each padded target sequence
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # sort by source length in descending order (needed for packed RNN sequences)
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    art_seqs, art_extended_seqs, abs_input_seqs, abs_trg_seqs = zip(*batch)
    art_seqs, art_lengths = merge(art_seqs)
    # the extended-vocab article ids must be padded as well
    art_extended_seqs, _ = merge(art_extended_seqs)
    abs_input_seqs, abs_lengths = merge(abs_input_seqs)
    abs_trg_seqs, _ = merge(abs_trg_seqs)
    return art_seqs, art_extended_seqs, art_lengths, abs_input_seqs, abs_trg_seqs, abs_lengths

def get_loader(datas, batch_size=32, debug=False, shuffle=True):
    dataset = SummaryDataset(datas, debug=debug,
                             max_length=config.max_length,
                             pointer_generator=config.pointer_generator)
    dataloader = data.DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 num_workers=8,
                                 shuffle=shuffle,
                                 collate_fn=collate_fn)
    return dataloader

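# A minimal usage sketch for get_loader. It assumes a `config` module exposing
# `max_length` and `pointer_generator`, and a list of Data instances already
# preprocessed with process_data (defined below); the paths and batch size here
# are illustrative, not part of the original gist.
#
#     train_data = CnnDailymailProcessor().get_train_data("data")
#     word2id, id2word = extract_vocab(os.path.join("data", "vocab"), max_size=50000)
#     process_data(train_data, word2id)
#     train_loader = get_loader(train_data, batch_size=16)
#     for art, art_ext, art_lens, abs_in, abs_trg, abs_lens in train_loader:
#         pass  # feed the padded batches to the seq2seq / pointer-generator model
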
def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
    # exponential moving average of the loss; the first call just returns the loss
    if running_avg_loss == 0:
        return loss
    else:
        return running_avg_loss * decay + (1 - decay) * loss

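# Illustrative only: how the running average behaves over a few steps,
# assuming step losses of 4.0, 2.0 and 1.0 and the default decay of 0.99.
#
#     avg = 0
#     for step_loss in [4.0, 2.0, 1.0]:
#         avg = cal_running_avg_loss(step_loss, avg)
#     # avg == 4.0 after step 1, then 4.0*0.99 + 0.01*2.0 = 3.98,
#     # then 3.98*0.99 + 0.01*1.0 = 3.9502
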
def get_logger(log_dir):
    # if the directory exists, remove it and make a new one
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir)

    logger = logging.getLogger("seq2seq_logger")
    logger.setLevel(logging.DEBUG)
    # log both to a file and to the console
    fh = logging.FileHandler(os.path.join(log_dir, "result.log"))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

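# Small usage sketch; "logs" is an illustrative directory name. Note that the
# directory is wiped on every call, so point it at a run-specific folder.
#
#     logger = get_logger("logs")
#     logger.info("start training")
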
def outputids2words(id_list, idx2word, article_oovs=None):
    words = []
    for idx in id_list:
        try:
            # in-vocabulary id
            word = idx2word[idx]
        except KeyError:
            # id lies outside the fixed vocab: look it up in the
            # article-specific OOV list used by the pointer-generator
            if article_oovs is not None:
                article_oov_idx = idx - len(idx2word)
                try:
                    word = article_oovs[article_oov_idx]
                except IndexError:
                    print("there is no such word in the extended vocab")
                    word = UNKNOWN_TOKEN
            else:
                print("there is no such word in the predefined vocab")
                word = UNKNOWN_TOKEN
        words.append(word)
    return words

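# Illustrative decoding example. With a 4-token vocab {0: '[PAD]', 1: '[UNK]',
# 2: '[START]', 3: '[STOP]'} extended by the article OOV list ['obama'],
# id 4 (= vocab_size + 0) maps back to the copied word 'obama'.
#
#     idx2word = {PAD_ID: PAD_TOKEN, UNK_ID: UNKNOWN_TOKEN,
#                 START_ID: START_TOKEN, STOP_ID: STOP_TOKEN}
#     outputids2words([2, 4, 3], idx2word, article_oovs=['obama'])
#     # -> ['[START]', 'obama', '[STOP]']
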
def extract_vocab(vocab_file, max_size=None):
    """ Vocab extractor
    This function opens and reads a vocab file, then extracts the vocabulary
    in order to build 2 dictionaries:
        * word2id : Given a word, return the corresponding ID
        * id2word : Given an ID, return the corresponding word
    The vocab file is expected to contain "<word> <frequency>" on each line,
    sorted by frequency in descending order.
    Args:
        vocab_file (str): Location of the vocab file to extract.
        max_size (int, optional): Maximum allowed size for the vocabulary.
            Defaults to None (no maximum size).
    Returns:
        dict: Dictionary word => ID
        dict: Dictionary ID => word
    """
    word2id = dict()
    id2word = dict()
    cnt = 0

    # Assign IDs 0, 1, 2, 3 to the special tokens
    for t in [PAD_TOKEN, UNKNOWN_TOKEN, START_TOKEN, STOP_TOKEN]:
        word2id[t] = cnt
        id2word[cnt] = t
        cnt += 1

    # Read the vocab file line by line until reaching its end or max_size
    with open(vocab_file, 'r') as f:
        for i, line in enumerate(f):
            pieces = line.split()
            if len(pieces) != 2:
                print("<Warning> Incorrect line in vocab : {} at {}".format(line, i + 1))
                continue
            word, frequency = pieces
            assert word not in [PAD_TOKEN, UNKNOWN_TOKEN, START_TOKEN,
                                STOP_TOKEN, SENTENCE_START, SENTENCE_END], \
                "Error : {} found in vocab file".format(word)
            assert word not in word2id, "Error : {} duplicate found in vocab " \
                                        "file".format(word)
            word2id[word] = cnt
            id2word[cnt] = word
            cnt += 1
            if max_size is not None and cnt >= max_size:
                break
    return word2id, id2word

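# Sketch of the expected vocab file format and a call to extract_vocab; the
# file name, words and counts below are made up for illustration. Each line is
# "<word> <frequency>", most frequent first:
#
#     the 1034127
#     , 985283
#     . 941806
#
#     word2id, id2word = extract_vocab(os.path.join("data", "vocab"), max_size=50000)
#     assert word2id[PAD_TOKEN] == PAD_ID and id2word[UNK_ID] == UNKNOWN_TOKEN
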
def process_data(data_list, word2id):
    """ Process a list of Data to create sequences of IDs
    This function treats a whole list of data. For each data, the text of the
    abstract and the article is processed into a sequence of IDs given a vocab.
    The data list is updated in place with the new sequences (of IDs).
    Args:
        data_list (list of Data): List of data to process.
        word2id (dict): Dictionary (from vocab) to use to process text data.
    """
    for d in data_list:
        # for the article
        words = d.article.split()
        article_ids, article_ids_extended_vocab, oov_lst = article2ids(words, word2id)
        d.article_ids = article_ids
        d.article_ids_extended_vocab = article_ids_extended_vocab
        d.articles_oov_lst = oov_lst

        # for the abstract: strip the sentence markers, then map words to IDs
        abstract = d.abstract.replace(SENTENCE_START, '').replace(SENTENCE_END, '')
        abstract_words = abstract.split()
        d.clean_abstract = ' '.join(abstract_words)
        abstract_ids, abstract_ids_extended_vocab = abstract2ids(abstract_words,
                                                                 word2id,
                                                                 oov_lst)
        d.abstract_ids = abstract_ids
        d.abstract_ids_extended_vocab = abstract_ids_extended_vocab

def article2ids(words, word2idx):
    """
    :param words: list of words
    :param word2idx: dictionary which maps a word to its index
    :return:
        ids: list of indices; an OOV word is represented by the UNK id
        ids_extended_vocab: list of indices; an OOV word is represented by
            vocab_size + its index in oov_lst
        oov_lst: list of OOV words from the article
    """
    ids = []
    oov_lst = []
    ids_extended_vocab = []
    for word in words:
        try:
            idx = word2idx[word]
            ids.append(idx)
            ids_extended_vocab.append(idx)
        except KeyError:
            # article OOV: record it and give it a temporary id beyond the vocab
            if word not in oov_lst:
                oov_lst.append(word)
            idx_extended = len(word2idx) + oov_lst.index(word)
            ids.append(word2idx[UNKNOWN_TOKEN])
            ids_extended_vocab.append(idx_extended)
    return ids, ids_extended_vocab, oov_lst

def abstract2ids(words, word2idx, article_oovs):
    """
    Same as article2ids, but for the abstract: an OOV word that also appears in
    the article OOV list is mapped to its extended-vocab id so it can be copied;
    otherwise it stays UNK.
    """
    ids = []
    ids_extended_vocab = []
    unk_id = word2idx[UNKNOWN_TOKEN]
    vocab_size = len(word2idx)
    for word in words:
        if word not in word2idx:
            # OOV word
            ids.append(unk_id)
            if word in article_oovs:
                idx = vocab_size + article_oovs.index(word)
                ids_extended_vocab.append(idx)
            else:
                ids_extended_vocab.append(unk_id)
        else:
            ids.append(word2idx[word])
            ids_extended_vocab.append(word2idx[word])
    return ids, ids_extended_vocab

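# Illustrative example of how the extended vocabulary works; the vocab and the
# words here are made up. Assume word2id maps the 4 special tokens to 0-3 and
# 'the' to 4, so vocab_size == 5.
#
#     ids, ids_ext, oovs = article2ids(['the', 'zyzzyva'], word2id)
#     # ids     -> [4, 1]   ('zyzzyva' is out of vocab, so UNK_ID)
#     # ids_ext -> [4, 5]   (5 == vocab_size + position of 'zyzzyva' in oovs)
#     # oovs    -> ['zyzzyva']
#     abstract2ids(['zyzzyva'], word2id, oovs)
#     # -> ([1], [5])       (the abstract can copy the article OOV)
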
def seq2seq(seq, translator):
    """ Process a sequence into another sequence using a translator.
    Translates a sequence into another sequence, e.g. a sequence of words into
    a sequence of IDs. A STOP token is appended to the result; note that the
    UNK/STOP fallback below assumes a word => ID translator such as word2id.
    Args:
        seq (list): Sequence to translate.
        translator (dict): Dictionary to use to translate the given sequence
            into another sequence.
    Returns:
        list: Translated sequence.
    """
    translated_seq = []
    for elem in seq:
        try:
            translated_seq.append(translator[elem])
        except KeyError:
            translated_seq.append(translator[UNKNOWN_TOKEN])
    translated_seq.append(translator[STOP_TOKEN])
    return translated_seq

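# Word -> ID translation sketch with the same toy word2id as above
# (special tokens 0-3, 'the' == 4); the input words are made up.
#
#     seq2seq(['the', 'zyzzyva'], word2id)
#     # -> [4, 1, 3]   (UNK for the OOV word, then STOP_ID appended)
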
class Data(object):
    """ Class representing one data example.
    Attributes:
        article (str): Article to summarize.
        abstract (str, optional): Corresponding abstract. Defaults to `None`.
    """
    def __init__(self, article, abstract=None):
        self.article = article
        self.abstract = abstract
        self.clean_abstract = None
        self.article_ids = None
        self.abstract_ids = None
        self.articles_oov_lst = None
        self.article_ids_extended_vocab = None
        self.abstract_ids_extended_vocab = None

class CnnDailymailProcessor(object):
    """ Class for extracting data.
    Extracts data from the CNN / Daily Mail dataset.
    """
    def get_train_data(self, data_dir):
        """ Gets a list of `Data` for the train subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "train.jsonl"))

    def get_val_data(self, data_dir):
        """ Gets a list of `Data` for the validation subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "val.jsonl"))

    def get_test_data(self, data_dir):
        """ Gets a list of `Data` for the test subset.
        Args:
            data_dir (str): Location of the folder where the files are located.
        Returns:
            list of Data: Extracted Data.
        """
        return self._read_jsonl_to_data(os.path.join(data_dir, "test.jsonl"))

    def _read_jsonl_to_data(self, input_file):
        """ Method to read a JSONL file and transform it into a list of Data.
        This method reads the content of a JSONL file and outputs a list of
        corresponding `Data`.
        Args:
            input_file (str): Location of the file to read.
        Returns:
            list of `Data`: List of Data extracted from the JSONL file.
        """
        data = []
        with open(input_file, "r") as f:
            for line in f:
                raw_data = json.loads(line)
                data.append(Data(article=raw_data['article'],
                                 abstract=raw_data['abstract']))
        return data

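# End-to-end sketch under a __main__ guard so it does not run on import. It
# assumes a "data" directory containing train.jsonl and a "vocab" file, plus a
# `config` module with `max_length` and `pointer_generator`; all paths, the
# vocab size and the batch size are illustrative, not part of the original gist.
if __name__ == "__main__":
    processor = CnnDailymailProcessor()
    train_data = processor.get_train_data("data")

    word2id, id2word = extract_vocab(os.path.join("data", "vocab"), max_size=50000)
    process_data(train_data, word2id)

    train_loader = get_loader(train_data, batch_size=8, debug=True)
    logger = get_logger("logs")
    for art, art_ext, art_lens, abs_in, abs_trg, abs_lens in train_loader:
        # a real model forward/backward pass would go here; log the shapes instead
        logger.info("article batch: {}, target batch: {}".format(art.size(), abs_trg.size()))
        break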