A tensor2tensor text multi-labeling problem
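This gist has two parts: a pair of registered IMDB problems modeled on tensor2tensor's imdb.py, and a proposed Text2MultiLabelProblem base class (at the bottom of the file) that adapts Text2TextProblem to targets that are lists of class ids rather than a single label. The base class is meant to be added to tensor2tensor/data_generators/text_problems.py, as noted in the comments below; a short smoke-test sketch at the end of the file shows one way to exercise the generator.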
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""IMDB Sentiment Classification Problem."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
import tensorflow as tf
@registry.register_problem
class SentimentIMDB(text_problems.Text2MultiLabelProblem):
"""IMDB sentiment classification."""
URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
@property
def is_generate_per_split(self):
return True
@property
def dataset_splits(self):
return [{
"split": problem.DatasetSplit.TRAIN,
"shards": 10,
}, {
"split": problem.DatasetSplit.EVAL,
"shards": 1,
}]
@property
def approx_vocab_size(self):
return 2**13 # 8k vocab suffices for this small dataset.
@property
def num_classes(self):
return 2
def class_labels(self, data_dir):
del data_dir
return ["neg", "pos"]
  def doc_generator(self, imdb_dir, dataset, include_label=False):
    # Each directory maps to a space-separated string of label ids: "pos" is
    # class 1 and "neg" is class 0, matching class_labels() above. A document
    # may carry several labels, hence the split below.
    dirs = [(os.path.join(imdb_dir, dataset, "pos"), "1"),
            (os.path.join(imdb_dir, dataset, "neg"), "0")]
    for d, labels in dirs:
      for filename in os.listdir(d):
        with tf.gfile.Open(os.path.join(d, filename)) as imdb_f:
          doc = imdb_f.read().strip()
          if include_label:
            yield doc, labels.split(" ")  # Split on the space character.
          else:
            yield doc
def generate_samples(self, data_dir, tmp_dir, dataset_split):
"""Generate examples."""
# Download and extract
compressed_filename = os.path.basename(self.URL)
download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
self.URL)
imdb_dir = os.path.join(tmp_dir, "aclImdb")
if not tf.gfile.Exists(imdb_dir):
with tarfile.open(download_path, "r:gz") as tar:
tar.extractall(tmp_dir)
# Generate examples
train = dataset_split == problem.DatasetSplit.TRAIN
dataset = "train" if train else "test"
    for doc, labels in self.doc_generator(imdb_dir, dataset,
                                          include_label=True):
      # The base class expects the key "labels" (not "label"), e.g. a
      # positive review yields {"inputs": "<review text>", "labels": [1]}.
      yield {
          "inputs": doc,
          "labels": [int(label) for label in labels],
      }

@registry.register_problem
class SentimentIMDBCharacters(SentimentIMDB):
"""IMDB sentiment classification, character level."""
@property
def vocab_type(self):
return text_problems.VocabType.CHARACTER
def global_task_id(self):
return problem.TaskID.EN_CHR_SENT
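

# The base class below is a proposed extension of tensor2tensor's
# Text2TextProblem to multi-label targets. It is written as if it lived in
# tensor2tensor/data_generators/text_problems.py (where Text2TextProblem,
# text_encoder, and modalities are already in scope); adding it there makes
# the text_problems.Text2MultiLabelProblem reference above resolve.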
class Text2MultiLabelProblem(Text2TextProblem):
"""Base class for text multi-labeling problems."""
def generate_samples(self, data_dir, tmp_dir, dataset_split):
"""Generate samples of text and label pairs.
Each yielded dict will be a single example. The inputs should be raw text.
The label should be an int array with each element in [0, self.num_classes).
Args:
data_dir: final data directory. Typically only used in this method to copy
over user-supplied vocab files (for example, if vocab_type ==
VocabType.TOKEN).
tmp_dir: temporary directory that you can use for downloading and scratch.
dataset_split: problem.DatasetSplit, which data split to generate samples
for (for example, training and evaluation).
Yields:
{"inputs": text, "labels": int[]}
"""
raise NotImplementedError()
# START: Additional subclass interface
@property
def num_classes(self):
"""The number of classes."""
raise NotImplementedError()
def class_labels(self, data_dir):
"""String representation of the classes."""
del data_dir
return ["ID_%d" % i for i in range(self.num_classes)]
# END: Additional subclass interface
def generate_text_for_vocab(self, data_dir, tmp_dir):
for i, sample in enumerate(
self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
yield sample["inputs"]
if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
break
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)
for sample in generator:
inputs = encoder.encode(sample["inputs"])
inputs.append(text_encoder.EOS_ID)
labels = sample["labels"]
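      # For illustration (hypothetical ids): a yielded encoded sample looks
      # like {"inputs": [17, 5, 942, 1], "targets": [0, 1]}, i.e. subword ids
      # ending in EOS_ID, plus the raw list of label ids.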
yield {"inputs": inputs, "targets": labels}
  def feature_encoders(self, data_dir):
    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
    return {
        "inputs": encoder,
        # ClassLabelEncoder expects the full list of class names.
        "targets": text_encoder.ClassLabelEncoder(self.class_labels(data_dir)),
    }
def hparams(self, defaults, unused_model_hparams):
p = defaults
p.modality = {"inputs": modalities.ModalityType.SYMBOL,
"targets": modalities.ModalityType.MULTI_LABEL}
p.vocab_size = {"inputs": self._encoders["inputs"].vocab_size,
"targets": self.num_classes}
  def example_reading_spec(self):
    data_fields = {
        "inputs": tf.VarLenFeature(tf.int64),
        # Multi-label targets have variable length, so they are read with
        # VarLenFeature rather than a scalar FixedLenFeature.
        "targets": tf.VarLenFeature(tf.int64),
    }
    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)
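

# A minimal smoke test, not part of the problem definitions above. It assumes
# the Text2MultiLabelProblem class has been added to text_problems.py as
# described, so that this module imports cleanly, and uses hypothetical /tmp
# paths. Generating samples downloads and extracts the IMDB archive into
# tmp_dir on first use.
if __name__ == "__main__":
  imdb_problem = SentimentIMDB()
  samples = imdb_problem.generate_samples(
      data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp",
      dataset_split=problem.DatasetSplit.TRAIN)
  # Print the labels and the first 60 characters of a few reviews.
  for _, sample in zip(range(3), samples):
    print(sample["labels"], sample["inputs"][:60])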