Created
June 7, 2022 10:39
-
-
Save deepanshu-yadav/c611d7fd782966064d2cc9a3c2ad3d72 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Here is a class that makes it easy to load data in a batch.
class CustomGenerator():
    """Serve fixed-size batches of rows read from NumPy files.

    ``data_desc`` is an indexable sequence of records. For each record:

    * element 0 is a path passed to ``np.load``;
    * element 1 is a label, converted with ``int``;
    * element 2 is a row offset, converted with ``int``, into the loaded array.

    The selected rows are stacked into a 2-D batch and passed through
    ``scaler.transform``. ``scaler`` is presumably a fitted scikit-learn-style
    scaler (TODO confirm); it only needs a ``transform`` method that accepts
    that 2-D array. The result is cast to ``float32``.

    Only complete batches are served: ``len(self)`` is
    ``len(data_desc) // batch_size``, so a trailing partial batch is dropped.
    Indexing outside ``range(len(self))`` raises ``IndexError``, which also
    makes plain ``for batch in generator`` iteration stop after the last full
    batch.
    """

    def __init__(self, data_desc, batch_size, scaler):
        """Store the batch description, batch size and scaler.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.data_desc = data_desc
        self.batch_size = batch_size
        self.scaler = scaler
        self.len = self.__len__()  # an attribute for the length (full batches)

    def __len__(self):
        """Return the number of complete batches in ``data_desc``."""
        return len(self.data_desc) // self.batch_size

    def _load_batch(self, idx):
        """Load the unscaled rows and integer labels of batch ``idx``.

        Returns:
            ``(rows, labels)``. ``rows`` is the ``np.vstack`` of the selected
            rows; for 2-D source files its shape is
            ``(batch_size, n_columns)``, even when ``batch_size`` is 1.
            ``labels`` is a list of ints.

        Raises:
            IndexError: if ``idx`` is not in ``range(len(self))``.
        """
        n_batches = len(self)
        if not 0 <= idx < n_batches:
            raise IndexError(
                f"batch index {idx} out of range for {n_batches} batches"
            )
        start = idx * self.batch_size
        records = self.data_desc[start:start + self.batch_size]

        rows, labels = [], []
        loaded_path, loaded_array = None, None
        for record in records:
            path = record[0]
            # Reuse the last loaded array while consecutive records point at
            # the same path, so a run of records from one file costs a single
            # np.load.
            if loaded_array is None or path != loaded_path:
                loaded_array = np.load(path)
                loaded_path = path
            labels.append(int(record[1]))
            rows.append(loaded_array[int(record[2]), :])
        # A single vstack at the end avoids re-copying the growing batch on
        # every row. It also guarantees a 2-D result for a one-row batch,
        # which scaler.transform needs.
        return np.vstack(rows), labels

    def __getitem__(self, idx):
        """Give a batch of training or validation data.

        Returns:
            ``(x, x)``: the scaled ``float32`` batch serves as both the input
            and the target.

        Raises:
            IndexError: if ``idx`` is not in ``range(len(self))``.
        """
        rows, _ = self._load_batch(idx)
        data_x = self.scaler.transform(rows).astype('float32')
        return data_x, data_x

    def getitem(self, index):
        """Alias for ``self[index]``."""
        return self.__getitem__(index)

    def get_prediction_data(self, idx):
        """Used for testing the model.

        Returns:
            ``(x, y)``. ``x`` is the scaled ``float32`` batch. ``y`` holds the
            batch labels as a ``float32`` column of shape
            ``(batch_size, 1)``.

        Raises:
            IndexError: if ``idx`` is not in ``range(len(self))``.
        """
        rows, labels = self._load_batch(idx)
        data_x = self.scaler.transform(rows).astype('float32')
        data_y = np.asarray(labels, dtype='float32').reshape(-1, 1)
        return data_x, data_y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment