Created
June 7, 2022 10:36
-
-
Save deepanshu-yadav/dfeaf0a8fffc46978bab35af45a46f5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob
import os

import numpy as np
from sklearn.preprocessing import MinMaxScaler
# Training hyper-parameters. They are held fixed so that results from
# different runs can be compared meaningfully.
NO_OF_EPOCHS = 3
BATCH_SIZE = 32

# Scaler that is fitted on the training data only.
min_max_scaler_train = MinMaxScaler()

# Every entry found directly inside the training / validation directories.
validation_files = glob.glob(os.path.join(validation_dir, '*'))
training_files = glob.glob(os.path.join(train_dir, '*'))
# Build the training data descriptor: one [file path, label, row offset] entry
# for every row of every training file.
#
# Entries are accumulated in a plain Python list. Re-stacking a NumPy array on
# every iteration copies the whole descriptor once per file (quadratic in the
# number of rows) and needs an uninitialised np.empty placeholder row that has
# to be sliced off afterwards.
train_data_descriptor = []
for train_file_name in training_files:
    file_np = np.load(train_file_name)
    rows = file_np.shape[0]
    # All three fields are stored as strings: that is what stacking the string
    # path column together with the integer label / offset columns yields.
    # Every row gets label 0 (NOTE(review): presumably "normal" -- confirm).
    train_data_descriptor.extend(
        [train_file_name, '0', str(offset)] for offset in range(rows)
    )
    # Fit the min-max scaler incrementally, one file at a time; partial_fit is
    # used because the training data may be large.
    min_max_scaler_train.partial_fit(file_np)
# Build the validation data descriptor: one [file path, label, row offset]
# entry for every row of every validation file. No scaler is fitted here.
#
# Entries are accumulated in a plain Python list. Re-stacking a NumPy array on
# every iteration copies the whole descriptor once per file (quadratic in the
# number of rows) and needs an uninitialised np.empty placeholder row that has
# to be sliced off afterwards.
validation_data_descriptor = []
for validation_file_name in validation_files:
    file_np = np.load(validation_file_name)
    rows = file_np.shape[0]
    # All three fields are stored as strings: that is what stacking the string
    # path column together with the integer label / offset columns yields.
    # Every row gets label 0 (NOTE(review): presumably "normal" -- confirm).
    validation_data_descriptor.extend(
        [validation_file_name, '0', str(offset)] for offset in range(rows)
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment