Last active
March 4, 2019 12:54
-
-
Save wallneradam/9263a90b55756d882b500515764271c7 to your computer and use it in GitHub Desktop.
Keras Multi Step Timeseries Generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from tensorflow.python.keras import utils | |
class MultiStepTimeseriesGenerator(utils.Sequence):
    """Utility class for generating batches of multi-step temporal data.

    This class takes in a sequence of data-points gathered at
    equal intervals, along with time series parameters such as
    stride, length of history, etc., to produce batches for
    training/validation. Unlike a single-step generator, every sample
    is paired with a *sequence* of `target_length` target timesteps.

    For an anchor row `i`, the input sample is
    `data[i - length:i:sampling_rate]` and the target is
    `targets[i:i + target_length]`.

    # Arguments
        data: Indexable generator (such as list or Numpy array)
            containing consecutive data points (timesteps).
            The data should be at least 2D, and axis 0 is expected
            to be the time dimension.
        targets: Targets corresponding to timesteps in `data`.
            It should have same length as `data`.
        length: Length of the history window (in number of timesteps)
            each output sequence is drawn from.
        target_length: Length of target sequence (in number of timesteps).
        sampling_rate: Period between successive individual timesteps
            within sequences. For rate `r`, timesteps
            `data[i - length]`, `data[i - length + r]`, ... up to (but
            excluding) `data[i]` are used to create a sample sequence.
        stride: Period between successive output sequences.
            For stride `s`, consecutive output samples are anchored
            at `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
            (ignored when `shuffle` is true).
        start_index: Data points earlier than `start_index` will not be used
            in the output sequences. This is useful to reserve part of the
            data for test or validation.
        end_index: Last timestep (inclusive) at which a target sequence
            may start; the target of that sample extends up to
            `end_index + target_length - 1`. Defaults to
            `len(data) - target_length`, the largest valid value.
            This is useful to reserve part of the data for test
            or validation.
        shuffle: Whether to shuffle output samples,
            or instead draw them in chronological order. When true,
            anchor rows of each batch are drawn uniformly at random
            (with replacement) between the first usable row and `end_index`.
        reverse: Boolean: if `true`, timesteps in each output sample will be
            in reverse chronological order.
        reverse_target: Boolean: if `true`, timesteps in each target sample
            will be in reverse chronological order.
        batch_size: Number of timeseries samples in each batch
            (except maybe the last one).

    # Returns
        A [Sequence](/utils/#sequence) instance.

    # Raises
        ValueError: if `data` and `targets` differ in length, if
            `end_index` would make the last target sequence run past the
            end of `targets`, or if `start_index + length > end_index`.
    """

    def __init__(self, data, targets, length=1, target_length=1,
                 sampling_rate=1,
                 stride=1,
                 start_index=0,
                 end_index=None,
                 shuffle=False,
                 reverse=False, reverse_target=False,
                 batch_size=128):
        # Explicit two-argument form keeps the class usable on Python 2
        # as well; newer Keras Sequence bases expect to be initialised.
        super(MultiStepTimeseriesGenerator, self).__init__()

        if len(data) != len(targets):
            raise ValueError('Data and targets have to be' +
                             ' of same length. '
                             'Data length is {}'.format(len(data)) +
                             ' while target length is {}'.format(len(targets)))

        # Largest anchor row whose full target window still fits in `targets`.
        last_target_start = len(data) - target_length
        if end_index is None:
            end_index = last_target_start
        elif end_index > last_target_start:
            # Without this check the final target windows would be silently
            # truncated, yielding ragged batches in `__getitem__`.
            raise ValueError('`end_index=%i` is too large: a target sequence '
                             'of `target_length=%i` starting there would run '
                             'past the end of the data. The largest allowed '
                             'value is %i.'
                             % (end_index, target_length, last_target_start))

        self.data = data
        self.targets = targets
        self.length = length
        self.target_length = target_length
        self.sampling_rate = sampling_rate
        self.stride = stride
        # Shift by `length` so the first history window starts exactly at
        # the caller's `start_index`.
        self.start_index = start_index + length
        self.end_index = end_index
        self.shuffle = shuffle
        self.reverse = reverse
        self.reverse_target = reverse_target
        self.batch_size = batch_size

        if self.start_index > self.end_index:
            raise ValueError('`start_index+length=%i > end_index=%i` '
                             'is disallowed, as no part of the sequence '
                             'would be left to be used as current step.'
                             % (self.start_index, self.end_index))

    def __len__(self):
        """Return the number of batches in the sequence."""
        # Each batch advances the anchor row by `batch_size * stride`;
        # this is a ceiling division over the inclusive anchor range.
        step = self.batch_size * self.stride
        return (self.end_index - self.start_index + step) // step

    def __getitem__(self, index):
        """Build batch number `index`.

        # Arguments
            index: Zero-based position of the batch. Ignored for choosing
                rows when `shuffle` is true.

        # Returns
            A tuple `(samples, targets)` of Numpy arrays whose first axis
            is the batch axis and whose second axis is time.
        """
        if self.shuffle:
            # Random anchors, drawn with replacement; `stride` is not applied.
            rows = np.random.randint(
                self.start_index, self.end_index + 1, size=self.batch_size)
        else:
            step = self.batch_size * self.stride
            first = self.start_index + step * index
            # Clip at `end_index` (inclusive) so the last batch may be short.
            rows = np.arange(first,
                             min(first + step, self.end_index + 1),
                             self.stride)

        # History window ends just before the anchor row; the target window
        # starts at the anchor row.
        samples = np.array([self.data[row - self.length:row:self.sampling_rate]
                            for row in rows])
        targets = np.array([self.targets[row:row + self.target_length]
                            for row in rows])

        if self.reverse:
            samples = samples[:, ::-1, ...]
        if self.reverse_target:
            targets = targets[:, ::-1, ...]
        return samples, targets
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment