Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Created September 20, 2022 17:21
Show Gist options
  • Save ariG23498/89fd02dbb7974f41e51787703bfb7408 to your computer and use it in GitHub Desktop.
Save ariG23498/89fd02dbb7974f41e51787703bfb7408 to your computer and use it in GitHub Desktop.
Jax numpy input pipeline
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/89fd02dbb7974f41e51787703bfb7408/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lIYdn1woOS1n",
"outputId": "14595a2e-d77b-4831-f8a9-49a355d1cda0"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.8.2\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"print(tf.__version__)"
]
},
{
"cell_type": "code",
"source": [
"# DATA\n",
"BUFFER_SIZE = 1024\n",
"BATCH_SIZE = 256\n",
"AUTO = tf.data.AUTOTUNE\n",
"INPUT_SHAPE = (32, 32, 3)\n",
"NUM_CLASSES = 10\n",
"\n",
"\n",
"# AUGMENTATION\n",
"IMAGE_SIZE = 48 # We will resize input images to this size.\n",
"PATCH_SIZE = 6 # Size of the patches to be extracted from the input images.\n",
"NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2\n",
"MASK_PROPORTION = 0.75 # We have found 75% masking to give us the best results."
],
"metadata": {
"id": "uTFJYKkt5nhu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
"(x_train, y_train), (x_val, y_val) = (\n",
" (x_train[:40000], y_train[:40000]),\n",
" (x_train[40000:], y_train[40000:]),\n",
")\n",
"print(f\"Training samples: {len(x_train)}\")\n",
"print(f\"Validation samples: {len(x_val)}\")\n",
"print(f\"Testing samples: {len(x_test)}\")\n",
"\n",
"train_ds = tf.data.Dataset.from_tensor_slices(x_train)\n",
"train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"val_ds = tf.data.Dataset.from_tensor_slices(x_val)\n",
"val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"test_ds = tf.data.Dataset.from_tensor_slices(x_test)\n",
"test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EX8tqut-5Ulf",
"outputId": "75f4d377-97f8-40ec-9fec-363954ad4bee"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training samples: 40000\n",
"Validation samples: 10000\n",
"Testing samples: 10000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"augmentations = keras.Sequential(\n",
" [\n",
" layers.Rescaling(1 / 255.0),\n",
" layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),\n",
" layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),\n",
" layers.RandomFlip(\"horizontal\"),\n",
" ],\n",
" name=\"train_data_augmentation\",\n",
")"
],
"metadata": {
"id": "FVxK6lgJ5Yrw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def map_fn(images):\n",
" return augmentations(images)"
],
"metadata": {
"id": "l5Jvif2Q6NNL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"aug_train_ds = train_ds.map(map_fn)"
],
"metadata": {
"id": "xcMYR3nF6Fjb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"image_batch = next(iter(aug_train_ds))\n",
"image_batch.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VNO8IHdP6TAn",
"outputId": "45ca3a5d-c953-4abb-c95b-1115a45b3370"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([256, 48, 48, 3])"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"type(image_batch)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "l8UFzGca6mQU",
"outputId": "7e04089d-6a4f-4f37-c15a-4931a770ce18"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensorflow.python.framework.ops.EagerTensor"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [
"aug_train_ds_numpy = aug_train_ds.as_numpy_iterator()"
],
"metadata": {
"id": "hSDzHuWd6YGO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"image_batch = next(iter(aug_train_ds_numpy))\n",
"image_batch.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KX5HW8186cSm",
"outputId": "f65abc4f-901e-4617-f024-34680e7a317d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(256, 48, 48, 3)"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"type(image_batch)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eJRf-1R06q9m",
"outputId": "b147b3c0-7c60-4f8b-caca-c5484005d91b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"metadata": {},
"execution_count": 23
}
]
}
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment