TensorFlow Dataset transformation that groups sequential data into length buckets and truncates sequences to the bucket boundary instead of padding them.
import tensorflow as tf
import numpy as np
def buckettrunc_by_sequence_length(element_length_func,
                                   bucket_boundaries,
                                   bucket_batch_sizes,
                                   drop_remainder=False):
    """Groups dataset elements into length buckets and truncates each
    sequence to its bucket boundary instead of padding."""
    # Map function: assign each element to a length bucket.
    def element_to_bucket_id(*args):
        """Return the int64 id of the length bucket for this element."""
        seq_length = element_length_func(*args)
        # A sequence shorter than the first boundary cannot be truncated to
        # that boundary, so reject it here. The message must be a static
        # string because seq_length is a symbolic tensor while this function
        # is being traced. Note >= rather than >: a sequence exactly as long
        # as the first boundary is valid (see the length-3 example below).
        tf.debugging.assert_greater_equal(
            tf.cast(seq_length, tf.dtypes.int64),
            tf.constant(bucket_boundaries[0], dtype=tf.dtypes.int64),
            message="Sequence length must be at least the first bucket boundary.")
        boundaries = sorted(list(bucket_boundaries))             # e.g. [10, 15, 20]
        buckets_min = boundaries                                 # [10, 15, 20]
        buckets_max = boundaries[1:] + [np.iinfo(np.int32).max]  # [15, 20, int32 max]
        conditions_c = tf.math.logical_and(                      # for each bucket:
            tf.math.greater_equal(seq_length, buckets_min),      # min <= length
            tf.math.less(seq_length, buckets_max))               # and length < max
        bucket_id = tf.math.reduce_min(tf.where(conditions_c))
        return bucket_id
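    # Worked example (with the boundaries [3, 7] used in the pipeline below):
    # a sequence of length 5 satisfies 5 >= 3 and 5 < 7, so conditions_c is
    # [True, False] and bucket_id is 0; a sequence of length 7 satisfies
    # 7 >= 7 and 7 < int32 max, giving bucket_id 1.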
    # Reduce function: truncate every sequence in the window to the bucket
    # boundary, then batch.
    def batching_fn(bucket_id, grouped_dataset):
        batch_size = window_size_fn(bucket_id)
        boundaries = tf.constant(sorted(bucket_boundaries), dtype=tf.dtypes.int64)
        bucket_boundary = boundaries[bucket_id]
        begin = tf.constant(value=0, dtype=tf.dtypes.int64, name='seq_begin')
        # Truncate each sequence to the bucket boundary instead of padding.
        grouped_dataset = grouped_dataset.map(
            lambda seq: tf.slice(seq, begin=[begin], size=[bucket_boundary]))
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
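    # Worked example: for bucket_id 0 with boundaries [3, 7], every sequence
    # in the window is sliced to its first 3 elements, so a window holding
    # lengths 3 and 5 batches cleanly to shape (2, 3) with no padding.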
    # Batch size function: one batch size per bucket.
    batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.dtypes.int64)

    def window_size_fn(bucket_id):
        window_size = batch_sizes[bucket_id]
        return window_size
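    # Worked example: with bucket_batch_sizes=[2, 2] (as in the pipeline
    # below), both buckets collect windows of 2 elements, so group_by_window
    # emits a batch as soon as a bucket has gathered two sequences.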
    def _apply_fn(dataset):
        return dataset.apply(tf.data.experimental.group_by_window(
            key_func=element_to_bucket_id,
            reduce_func=batching_fn,
            window_size_func=window_size_fn))

    return _apply_fn
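# Note on the group_by_window contract used above: key_func maps each element
# to an int64 bucket id, window_size_func caps how many elements a bucket
# collects, and reduce_func turns each full (or final partial) window into
# a batch.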
def seq_len(seq):
    return tf.shape(seq)[0]
# Data generator: three sequences of lengths 3, 5, and 7.
def gen():
    for i in [np.array([1, 1, 1]),
              np.array([2, 2, 2, 2, 2]),
              np.array([3, 3, 3, 3, 3, 3, 3])]:
        yield i
# Data pipeline.
dataset = tf.data.Dataset.from_generator(gen, tf.int32, tf.TensorShape([None]))
dataset = dataset.apply(buckettrunc_by_sequence_length(
    element_length_func=seq_len,
    bucket_boundaries=[3, 7],
    bucket_batch_sizes=[2, 2],
    drop_remainder=False))
list(dataset.take(3).as_numpy_iterator())
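# Expected output, sketched from the logic above (exact batch order may
# vary): lengths 3 and 5 fall into bucket 0 and are truncated to 3 elements,
# while length 7 falls into bucket 1 and keeps all 7, so the list should be
# roughly [array([[1, 1, 1], [2, 2, 2]]), array([[3, 3, 3, 3, 3, 3, 3]])].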