Last active
June 3, 2023 05:20
-
-
Save rameshKrSah/c6dea6fada460be48499e2ff43995629 to your computer and use it in GitHub Desktop.
Create a TensorFlow dataset from X, Y arrays.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.model_selection import train_test_split | |
import numpy as np | |
import tensorflow as tf | |
def create_dataset(X, Y, batch_size, test_size=0.2, random_state=42,
                   shuffle_buffer_size=None):
    """Create batched train and test TF datasets from X and Y.

    The arrays are cast to float32, split into train and test partitions with
    a stratified, seeded shuffle split, and each partition is wrapped in a
    tf.data.Dataset that is batched and prefetched. Only the train dataset is
    shuffled, and it is reshuffled on every iteration (epoch).

    Prefetching overlaps the preprocessing and model execution of a training
    step: while the model is executing training step s, the input pipeline is
    reading the data for step s+1. AUTOTUNE lets tf.data tune the number of
    prefetched batches automatically.

    Keyword arguments:
    X -- array-like of features; converted to a float32 numpy array
    Y -- array-like of targets; converted to float32 and also used as the
         stratification key (presumably class labels -- stratified splitting
         is not meaningful for continuous targets; verify against callers)
    batch_size -- integer, number of samples per batch
    test_size -- fraction (or absolute count) of samples held out for the
                 test dataset, forwarded to train_test_split (default 0.2)
    random_state -- seed for the train/test split (default 42)
    shuffle_buffer_size -- integer size of the train shuffle buffer; None
                           (default) uses the whole train partition so the
                           per-epoch shuffle is uniform

    Returns:
    (train_dataset, test_dataset) -- tuple of tf.data.Dataset objects that
    yield (features, targets) batches
    """
    autotune = tf.data.experimental.AUTOTUNE

    # np.asarray accepts lists as well as ndarrays, unlike calling .astype
    # directly on the argument.
    features = np.asarray(X, dtype='float32')
    targets = np.asarray(Y, dtype='float32')

    x_tr, x_ts, y_tr, y_ts = train_test_split(
        features,
        targets,
        test_size=test_size,
        random_state=random_state,
        stratify=targets,
        shuffle=True,
    )

    # A buffer smaller than the dataset only shuffles locally; default to the
    # full train partition, which is already held in memory.
    if shuffle_buffer_size is None:
        shuffle_buffer_size = len(x_tr)

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_tr, y_tr))
        .shuffle(buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(autotune)
    )

    # The test partition keeps its split order so evaluation is repeatable.
    test_dataset = (
        tf.data.Dataset.from_tensor_slices((x_ts, y_ts))
        .batch(batch_size)
        .prefetch(autotune)
    )

    return train_dataset, test_dataset
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
made public