Created
September 20, 2022 17:21
-
-
Save ariG23498/89fd02dbb7974f41e51787703bfb7408 to your computer and use it in GitHub Desktop.
JAX/NumPy input pipeline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/89fd02dbb7974f41e51787703bfb7408/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "lIYdn1woOS1n", | |
"outputId": "14595a2e-d77b-4831-f8a9-49a355d1cda0" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2.8.2\n" | |
] | |
} | |
], | |
"source": [ | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras import layers\n", | |
"\n", | |
"print(tf.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# REPRODUCIBILITY\n", | |
"# Seed Python's `random`, NumPy and TensorFlow in one call so that shuffling and\n", | |
"# the random crop/flip augmentations below are reproducible on a fresh kernel.\n", | |
"SEED = 42\n", | |
"keras.utils.set_random_seed(SEED)\n", | |
"\n", | |
"# DATA\n", | |
"BUFFER_SIZE = 1024  # Shuffle buffer size (in individual images, since shuffle precedes batch).\n", | |
"BATCH_SIZE = 256\n", | |
"AUTO = tf.data.AUTOTUNE  # Let tf.data tune parallelism / prefetch depth at runtime.\n", | |
"INPUT_SHAPE = (32, 32, 3)  # CIFAR-10 image shape: (height, width, channels).\n", | |
"NUM_CLASSES = 10\n", | |
"\n", | |
"\n", | |
"# AUGMENTATION\n", | |
"IMAGE_SIZE = 48  # Side length of the images after the resize + random-crop augmentation.\n", | |
"PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.\n", | |
"NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2\n", | |
"MASK_PROPORTION = 0.75  # We have found 75% masking to give us the best results." | |
], | |
"metadata": { | |
"id": "uTFJYKkt5nhu" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load CIFAR-10 and hold out every training image after the first 40,000 for validation.\n", | |
"(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n", | |
"x_train, x_val = x_train[:40000], x_train[40000:]\n", | |
"y_train, y_val = y_train[:40000], y_train[40000:]\n", | |
"print(f\"Training samples: {len(x_train)}\")\n", | |
"print(f\"Validation samples: {len(x_val)}\")\n", | |
"print(f\"Testing samples: {len(x_test)}\")\n", | |
"\n", | |
"# Only the images go into the datasets; the labels are unused in this notebook.\n", | |
"train_ds = (\n", | |
"    tf.data.Dataset.from_tensor_slices(x_train)\n", | |
"    .shuffle(BUFFER_SIZE)\n", | |
"    .batch(BATCH_SIZE)\n", | |
"    .prefetch(AUTO)\n", | |
")\n", | |
"val_ds = tf.data.Dataset.from_tensor_slices(x_val).batch(BATCH_SIZE).prefetch(AUTO)\n", | |
"test_ds = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE).prefetch(AUTO)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EX8tqut-5Ulf", | |
"outputId": "75f4d377-97f8-40ec-9fec-363954ad4bee" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Training samples: 40000\n", | |
"Validation samples: 10000\n", | |
"Testing samples: 10000\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Train-time augmentation: scale pixels by 1/255, upsample each image by\n", | |
"# RESIZE_PADDING pixels in height and width, take a random IMAGE_SIZE x IMAGE_SIZE\n", | |
"# crop, then apply a random horizontal flip.\n", | |
"RESIZE_PADDING = 20  # Extra pixels added to height/width before random cropping.\n", | |
"augmentations = keras.Sequential(\n", | |
"    [\n", | |
"        layers.Rescaling(1 / 255.0),\n", | |
"        # Use height (index 0) and width (index 1) separately so non-square inputs work.\n", | |
"        layers.Resizing(INPUT_SHAPE[0] + RESIZE_PADDING, INPUT_SHAPE[1] + RESIZE_PADDING),\n", | |
"        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),\n", | |
"        layers.RandomFlip(\"horizontal\"),\n", | |
"    ],\n", | |
"    name=\"train_data_augmentation\",\n", | |
")" | |
], | |
"metadata": { | |
"id": "FVxK6lgJ5Yrw" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def map_fn(images):\n", | |
"    \"\"\"Apply the train-time augmentation pipeline to a batch of images.\n", | |
"\n", | |
"    `training=True` is passed explicitly so RandomCrop/RandomFlip always run in\n", | |
"    training (random) mode instead of depending on Keras' implicit training-mode\n", | |
"    resolution, which a globally set learning phase could switch to inference.\n", | |
"    \"\"\"\n", | |
"    return augmentations(images, training=True)" | |
], | |
"metadata": { | |
"id": "l5Jvif2Q6NNL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Augment batches in parallel, then prefetch *after* the map so augmenting the next\n", | |
"# batch overlaps with consuming the current one. The prefetch already on `train_ds`\n", | |
"# only buffers un-augmented batches.\n", | |
"aug_train_ds = train_ds.map(map_fn, num_parallel_calls=AUTO).prefetch(AUTO)" | |
], | |
"metadata": { | |
"id": "xcMYR3nF6Fjb" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Peek at the first augmented batch to confirm its post-augmentation shape.\n", | |
"image_batch = next(iter(aug_train_ds.take(1)))\n", | |
"image_batch.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "VNO8IHdP6TAn", | |
"outputId": "45ca3a5d-c953-4abb-c95b-1115a45b3370" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"TensorShape([256, 48, 48, 3])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Elements drawn directly from a tf.data.Dataset are TensorFlow tensors.\n", | |
"type(image_batch)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "l8UFzGca6mQU", | |
"outputId": "7e04089d-6a4f-4f37-c15a-4931a770ce18" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"tensorflow.python.framework.ops.EagerTensor" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Wrap the augmented dataset in an iterator that yields NumPy arrays, a format\n", | |
"# that can be handed straight to JAX / NumPy code.\n", | |
"aug_train_ds_numpy = aug_train_ds.as_numpy_iterator()" | |
], | |
"metadata": { | |
"id": "hSDzHuWd6YGO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# `as_numpy_iterator()` already returns an iterator, so `next` can be called on it\n", | |
"# directly. Note: re-running this cell advances the iterator to the following batch.\n", | |
"image_batch = next(aug_train_ds_numpy)\n", | |
"image_batch.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "KX5HW8186cSm", | |
"outputId": "f65abc4f-901e-4617-f024-34680e7a317d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(256, 48, 48, 3)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Batches from the NumPy iterator are plain NumPy arrays rather than TF tensors.\n", | |
"type(image_batch)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "eJRf-1R06q9m", | |
"outputId": "b147b3c0-7c60-4f8b-caca-c5484005d91b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"numpy.ndarray" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 23 | |
} | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"name": "scratchpad", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment