Skip to content

Instantly share code, notes, and snippets.

@deepanshu-yadav
Created June 7, 2022 10:39
Show Gist options
  • Save deepanshu-yadav/c611d7fd782966064d2cc9a3c2ad3d72 to your computer and use it in GitHub Desktop.
Save deepanshu-yadav/c611d7fd782966064d2cc9a3c2ad3d72 to your computer and use it in GitHub Desktop.
# Here is a class that makes it easy to load data in a batch.
class CustomGenerator():
    """Batch loader that assembles rows stored in ``.npy`` files.

    Every entry of ``data_desc`` is indexed positionally as
    ``(file, label, offset)``: ``file`` is handed to ``np.load``, ``offset``
    selects one row of the loaded array, and ``label`` is converted with
    ``int`` and only returned by :meth:`get_prediction_data`.

    Args:
        data_desc: Sliceable sequence of ``(file, label, offset)`` entries.
        batch_size: Number of entries combined into one batch.
        scaler: Object exposing ``transform(array)``; it is applied to every
            assembled batch before the cast to ``float32``.
    """

    def __init__(self, data_desc, batch_size, scaler):
        self.data_desc = data_desc
        self.batch_size = batch_size
        self.scaler = scaler
        self.len = len(self)  # number of full batches, kept as an attribute

    def __len__(self):
        """Return the number of complete batches (a trailing remainder is dropped)."""
        return len(self.data_desc) // self.batch_size

    def _load_batch(self, idx):
        """Load the rows and integer labels described by batch ``idx``.

        Returns:
            A ``(rows, labels)`` pair of equally long lists, in the order of
            the descriptor entries.

        Raises:
            IndexError: If the slice for ``idx`` holds fewer than
                ``batch_size`` entries (index out of range).
        """
        start = idx * self.batch_size
        batch = self.data_desc[start:start + self.batch_size]
        # An incomplete slice means idx is out of range. Raising IndexError
        # keeps the legacy __getitem__ iteration protocol terminating.
        if len(batch) == 0 or len(batch) < self.batch_size:
            raise IndexError(f"batch index {idx} is out of range")

        rows = []
        labels = []
        current_file = None
        loaded = None
        for entry in batch:
            file_name = entry[0]
            # Reload only when the file changes between consecutive entries.
            if loaded is None or file_name != current_file:
                loaded = np.load(file_name)
                current_file = file_name
            labels.append(int(entry[1]))
            rows.append(loaded[int(entry[2]), :])
        return rows, labels

    def _scaled_features(self, rows):
        """Stack ``rows`` into one 2-D batch, scale it and cast to ``float32``."""
        # A single vstack avoids re-copying the growing batch on every row and
        # yields a 2-D array even when the batch holds just one row.
        data_x = np.vstack(rows)
        data_x = self.scaler.transform(data_x)
        return data_x.astype('float32')

    def __getitem__(self, idx):
        """Gives a batch of training or validation data.

        Returns:
            ``(data_x, data_x)``: the scaled ``float32`` batch, returned as
            both elements of the tuple (the same array object twice).

        Raises:
            IndexError: If ``idx`` does not address a complete batch.
        """
        rows, _ = self._load_batch(idx)
        data_x = self._scaled_features(rows)
        return data_x, data_x

    def getitem(self, index):
        """Plain-method alias for ``self[index]``."""
        return self.__getitem__(index)

    def get_prediction_data(self, idx):
        """Used for testing the model.

        Returns:
            ``(data_x, data_y)``: the scaled ``float32`` batch and the
            matching labels as a ``float32`` column of shape
            ``(batch_size, 1)``.

        Raises:
            IndexError: If ``idx`` does not address a complete batch.
        """
        rows, labels = self._load_batch(idx)
        data_x = self._scaled_features(rows)
        data_y = np.array(labels, dtype='float32').reshape(-1, 1)
        return data_x, data_y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment