A tensor2tensor text multi-labeling problem: an IMDB sentiment problem rewritten against a new Text2MultiLabelProblem base class (first file below), together with the base class itself (second file).
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""IMDB Sentiment Classification Problem."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

import tensorflow as tf
@registry.register_problem
class SentimentIMDB(text_problems.Text2MultiLabelProblem):
  """IMDB sentiment classification."""
  URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

  @property
  def is_generate_per_split(self):
    return True

  @property
  def dataset_splits(self):
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 10,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 1,
    }]

  @property
  def approx_vocab_size(self):
    return 2**13  # 8k vocab suffices for this small dataset.

  @property
  def num_classes(self):
    return 2

  def class_labels(self, data_dir):
    del data_dir
    return ["neg", "pos"]
  def doc_generator(self, imdb_dir, dataset, include_label=False):
    # Each directory maps to a space-separated string of class ids. The
    # original single-label problem used booleans here, but a bool has no
    # .split() method, so strings are required for the multi-label yield.
    dirs = [(os.path.join(imdb_dir, dataset, "pos"), "1"),
            (os.path.join(imdb_dir, dataset, "neg"), "0")]
    for d, labels in dirs:
      for filename in os.listdir(d):
        with tf.gfile.Open(os.path.join(d, filename)) as imdb_f:
          doc = imdb_f.read().strip()
          if include_label:
            yield doc, labels.split(" ")  # split the label string on spaces
          else:
            yield doc
  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Generate examples."""
    # Download and extract.
    compressed_filename = os.path.basename(self.URL)
    download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
                                                   self.URL)
    imdb_dir = os.path.join(tmp_dir, "aclImdb")
    if not tf.gfile.Exists(imdb_dir):
      with tarfile.open(download_path, "r:gz") as tar:
        tar.extractall(tmp_dir)

    # Generate examples. The key must be "labels" (plural): that is the key
    # Text2MultiLabelProblem.generate_encoded_samples reads.
    train = dataset_split == problem.DatasetSplit.TRAIN
    dataset = "train" if train else "test"
    for doc, labels in self.doc_generator(imdb_dir, dataset,
                                          include_label=True):
      yield {
          "inputs": doc,
          "labels": [int(label) for label in labels],
      }
@registry.register_problem
class SentimentIMDBCharacters(SentimentIMDB):
  """IMDB sentiment classification, character level."""

  @property
  def vocab_type(self):
    return text_problems.VocabType.CHARACTER

  def global_task_id(self):
    return problem.TaskID.EN_CHR_SENT
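With the problems registered above, data generation can be driven from Python or via t2t-datagen. Below is a minimal usage sketch, assuming tensor2tensor is installed, the module above has been imported (so its @registry.register_problem decorators have run), and the /tmp paths are hypothetical stand-ins:

# Minimal usage sketch (hypothetical paths). The registry derives the
# problem name "sentiment_imdb" from the class name SentimentIMDB.
from tensor2tensor.utils import registry

imdb_problem = registry.problem("sentiment_imdb")
imdb_problem.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")

The same name is what you would pass to t2t-datagen and t2t-trainer as --problem=sentiment_imdb.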
The Text2MultiLabelProblem base class below is intended to be added to tensor2tensor's data_generators/text_problems.py, where the first file imports it from; it relies on that module's existing imports (problem, text_encoder, modalities, registry, tf).
class Text2MultiLabelProblem(Text2TextProblem):
  """Base class for text multi-labeling problems."""

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Generate samples of text and label pairs.

    Each yielded dict will be a single example. The inputs should be raw text.
    The labels should be a list of ints, each in [0, self.num_classes).

    Args:
      data_dir: final data directory. Typically only used in this method to
        copy over user-supplied vocab files (for example, if vocab_type ==
        VocabType.TOKEN).
      tmp_dir: temporary directory that you can use for downloading and
        scratch.
      dataset_split: problem.DatasetSplit, which data split to generate samples
        for (for example, training and evaluation).

    Yields:
      {"inputs": text, "labels": [int]}
    """
    raise NotImplementedError()
  # START: Additional subclass interface
  @property
  def num_classes(self):
    """The number of classes."""
    raise NotImplementedError()

  def class_labels(self, data_dir):
    """String representation of the classes."""
    del data_dir
    return ["ID_%d" % i for i in range(self.num_classes)]
  # END: Additional subclass interface

  def generate_text_for_vocab(self, data_dir, tmp_dir):
    for i, sample in enumerate(
        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
      yield sample["inputs"]
      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
        break
  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
    encoder = self.get_or_create_vocab(data_dir, tmp_dir)
    for sample in generator:
      inputs = encoder.encode(sample["inputs"])
      inputs.append(text_encoder.EOS_ID)
      labels = sample["labels"]
      yield {"inputs": inputs, "targets": labels}
  def feature_encoders(self, data_dir):
    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
    return {
        "inputs": encoder,
        # ClassLabelEncoder takes the full list of class labels; building a
        # list of one-label encoders (as before) would not be a valid encoder.
        "targets": text_encoder.ClassLabelEncoder(self.class_labels(data_dir)),
    }
  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    p.modality = {"inputs": modalities.ModalityType.SYMBOL,
                  "targets": modalities.ModalityType.MULTI_LABEL}
    p.vocab_size = {"inputs": self._encoders["inputs"].vocab_size,
                    "targets": self.num_classes}
  def example_reading_spec(self):
    data_fields = {
        "inputs": tf.VarLenFeature(tf.int64),
        # Each example carries a variable number of labels, so use
        # VarLenFeature rather than the scalar FixedLenFeature that the
        # single-label Text2ClassProblem uses.
        "targets": tf.VarLenFeature(tf.int64),
    }
    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)
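To make the subclass contract concrete, here is a minimal sketch of a toy problem (hypothetical class name and data, for illustration only, assuming registry is imported as in the first file). It shows the {"inputs": ..., "labels": ...} dicts that generate_encoded_samples and feature_encoders above consume:

# Hypothetical toy subclass of Text2MultiLabelProblem, illustration only.
@registry.register_problem
class ToyMultiLabel(Text2MultiLabelProblem):
  """Tags a sentence with one or more of three made-up classes."""

  @property
  def num_classes(self):
    return 3

  @property
  def is_generate_per_split(self):
    return False  # one generator; T2T carves out the eval split itself

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir, tmp_dir, dataset_split
    # Raw text plus a variable-length list of class ids in [0, num_classes).
    yield {"inputs": "cheap and fast", "labels": [0, 2]}
    yield {"inputs": "reliable", "labels": [1]}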