Created August 4, 2024 12:46
jax_flax_optax_multisteps.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMqG9rimDxInY/wQxBADHEG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/NobuoTsukamoto/094c4953947845e270cd0dcc093c50bd/jax_flax_optax_multisteps.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"## Gradient accumulation with `optax.MultiSteps`\n",
"\n",
"Goal: implement gradient accumulation using JAX, Flax, and Optax.  \n",
"References:\n",
"- https://optax.readthedocs.io/en/latest/_collections/examples/gradient_accumulation.html\n",
"- https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html\n",
"- https://github.com/google/flax/tree/main/examples/imagenet"
],
"metadata": {
"id": "FHzpv9IFZGUk"
}
},
{
"cell_type": "markdown",
"source": [
"Based on https://github.com/google/flax/tree/main/examples/imagenet, I built a minimal MNIST sample and run it below.  \n",
"The sample is in the `mnist_gradient_accumulation` folder of https://github.com/NobuoTsukamoto/jax_examples.git"
],
"metadata": {
"id": "_LC0BSVdaWKV"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "AzApj1rrHmuO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "84edc4e5-db79-409e-cfdb-055060a54e7e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'jax_examples'...\n",
"remote: Enumerating objects: 928, done.\u001b[K\n",
"remote: Counting objects: 100% (315/315), done.\u001b[K\n",
"remote: Compressing objects: 100% (139/139), done.\u001b[K\n",
"remote: Total 928 (delta 241), reused 232 (delta 173), pack-reused 613\u001b[K\n",
"Receiving objects: 100% (928/928), 4.31 MiB | 22.65 MiB/s, done.\n",
"Resolving deltas: 100% (672/672), done.\n",
"/content/jax_examples/mnist_gradient_accumulation\n"
]
}
],
"source": [
"!git clone https://github.com/NobuoTsukamoto/jax_examples.git\n",
"%cd jax_examples/mnist_gradient_accumulation"
]
},
{
"cell_type": "markdown",
"source": [
"Install the required packages"
],
"metadata": {
"id": "6zJh3vxwZxK7"
}
},
{
"cell_type": "code",
"source": [
"!pip install -qq clu"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PFbclpJhRC1O",
"outputId": "269552c2-6093-4e94-f844-8baa4fdc9f43"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/77.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.8/101.8 kB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for ml-collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Without gradient accumulation\n",
"\n",
"- Batch size : 32\n",
"- Gradient accumulation steps : 1"
],
"metadata": {
"id": "pJb0Xb1oZ0c1"
}
},
{
"cell_type": "code",
"source": [
"!python main.py \\\n",
" --config config/default.py \\\n",
" --config.batch_size=32 \\\n",
" --config.gradient_accumulation_steps=1 \\\n",
" --config.num_train_steps=10 \\\n",
" --config.log_every_steps=1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-3eA9wdqHqiT",
"outputId": "0f2fd614-099b-494c-96d1-1bf74e64a7a8"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-08-04 12:44:58.079996: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-04 12:44:58.127668: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-04 12:44:58.147567: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-08-04 12:45:01.900626: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"I0804 12:45:06.662463 135447274422272 xla_bridge.py:863] Unable to initialize backend 'cuda': jaxlib/cuda/versions_helpers.cc:98: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.\n",
"I0804 12:45:06.669809 135447274422272 xla_bridge.py:863] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: \"rocm\". Available platform names are: CUDA\n",
"I0804 12:45:06.671079 135447274422272 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
"I0804 12:45:06.672401 135447274422272 main.py:32] JAX process: 0 / 1\n",
"I0804 12:45:06.672538 135447274422272 main.py:33] JAX local devices: [CpuDevice(id=0)]\n",
"I0804 12:45:06.672963 135447274422272 local.py:45] Setting task status: process_index: 0, process_count: 1\n",
"2024-08-04 12:45:07.967792: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with \"NOT_FOUND: Could not locate the credentials file.\". Retrieving token from GCE failed with \"NOT_FOUND: Error executing an HTTP request: HTTP response code 404\".\n",
"I0804 12:45:08.123267 135447274422272 dataset_info.py:805] Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: mnist/3.0.1\n",
"I0804 12:45:08.307303 135447274422272 dataset_info.py:617] Load dataset info from /tmp/tmpxzuwcccstfds\n",
"I0804 12:45:08.311540 135447274422272 dataset_info.py:709] For 'mnist/3.0.1': fields info.[citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.\n",
"I0804 12:45:08.312084 135447274422272 dataset_builder.py:644] Generating dataset mnist (/workdir/tensorflow_datasets/mnist/3.0.1)\n",
"\u001b[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /workdir/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n",
"I0804 12:45:08.469943 135447274422272 dataset_builder.py:693] Dataset mnist is hosted on GCS. It will automatically be downloaded to your\n",
"local data directory. If you'd instead prefer to read directly from our public\n",
"GCS bucket (recommended if you're running on GCP), you can instead pass\n",
"`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.\n",
"\n",
"Dl Completed...: 100% 5/5 [00:00<00:00, 8.87 file/s]\n",
"I0804 12:45:09.081179 135447274422272 dataset_info.py:617] Load dataset info from /workdir/tensorflow_datasets/mnist/incomplete.DRFETD_3.0.1/\n",
"I0804 12:45:09.084620 135447274422272 dataset_info.py:709] For 'mnist/3.0.1': fields info.[citation, splits, supervised_keys, module_name, file_format] differ on disk and in the code. Keeping the one from code.\n",
"\u001b[1mDataset mnist downloaded and prepared to /workdir/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001b[0m\n",
"I0804 12:45:09.087499 135447274422272 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:09.224036 135447274422272 logging_logger.py:49] Constructing tf.data.Dataset mnist for split train[0:60000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:09.406544 135447274422272 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:09.491970 135447274422272 logging_logger.py:49] Constructing tf.data.Dataset mnist for split test[0:10000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:09.596029 135447274422272 train.py:221] Steps per epech : 1876, Step per eval : 313, Num steps : 10\n",
"I0804 12:45:10.905066 135447274422272 train.py:159] Batch size: 32, Gradient accumulation steps : 1\n",
"I0804 12:45:11.085379 135447274422272 train.py:249] Initial compilation, this might take some minutes...\n",
"I0804 12:45:12.282366 135447274422272 train.py:257] Initial compilation completed.\n",
"I0804 12:45:12.283156 135447274422272 train.py:273] train steps: 1, loss: 2.3088, accuracy: 6.25\n",
"I0804 12:45:12.332860 135447274422272 train.py:273] train steps: 2, loss: 2.2388, accuracy: 25.00\n",
"I0804 12:45:12.385134 135447274422272 train.py:273] train steps: 3, loss: 2.1062, accuracy: 37.50\n",
"I0804 12:45:12.433540 135447274422272 train.py:273] train steps: 4, loss: 2.1983, accuracy: 21.88\n",
"I0804 12:45:12.483796 135447274422272 train.py:273] train steps: 5, loss: 1.9627, accuracy: 37.50\n",
"I0804 12:45:12.534787 135447274422272 train.py:273] train steps: 6, loss: 1.7650, accuracy: 59.38\n",
"I0804 12:45:12.587219 135447274422272 train.py:273] train steps: 7, loss: 1.8484, accuracy: 37.50\n",
"I0804 12:45:12.652689 135447274422272 train.py:273] train steps: 8, loss: 2.4807, accuracy: 12.50\n",
"I0804 12:45:12.702723 135447274422272 train.py:273] train steps: 9, loss: 2.3682, accuracy: 34.38\n",
"I0804 12:45:12.750524 135447274422272 train.py:273] train steps: 10, loss: 1.9025, accuracy: 59.38\n",
"2024-08-04 12:45:12.848000: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
"2024-08-04 12:45:12.848195: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
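"### With gradient accumulation\n",
"- Batch size : 16\n",
"- Gradient accumulation steps : 2\n",
"\n",
"Is this equivalent to a batch size of 32? With `optax.MultiSteps`, the gradients of every 2 consecutive micro-batches of 16 are averaged and applied as a single update, so the 20 logged steps should correspond to 10 parameter updates with an effective batch size of 32."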
],
"metadata": {
"id": "ERTDkiLHaGfz"
}
},
{
"cell_type": "code",
"source": [
"!python main.py \\\n",
" --config config/default.py \\\n",
" --config.batch_size=16 \\\n",
" --config.gradient_accumulation_steps=2 \\\n",
" --config.num_train_steps=20 \\\n",
" --config.log_every_steps=1"
],
"metadata": {
"id": "S0VdXFaWHxb4",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0161fc6b-36c9-42ba-d33f-4f71c5846230"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-08-04 12:45:16.368119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-04 12:45:16.418832: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-04 12:45:16.430451: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-08-04 12:45:18.690226: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"I0804 12:45:21.961287 137594897182720 xla_bridge.py:863] Unable to initialize backend 'cuda': jaxlib/cuda/versions_helpers.cc:98: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.\n",
"I0804 12:45:21.961785 137594897182720 xla_bridge.py:863] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: \"rocm\". Available platform names are: CUDA\n",
"I0804 12:45:21.962785 137594897182720 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
"I0804 12:45:21.964201 137594897182720 main.py:32] JAX process: 0 / 1\n",
"I0804 12:45:21.964344 137594897182720 main.py:33] JAX local devices: [CpuDevice(id=0)]\n",
"I0804 12:45:21.964854 137594897182720 local.py:45] Setting task status: process_index: 0, process_count: 1\n",
"I0804 12:45:22.614171 137594897182720 dataset_info.py:617] Load dataset info from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:22.618038 137594897182720 dataset_info.py:709] For 'mnist/3.0.1': fields info.[citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.\n",
"I0804 12:45:22.618571 137594897182720 dataset_builder.py:579] Reusing dataset mnist (/workdir/tensorflow_datasets/mnist/3.0.1)\n",
"I0804 12:45:22.620046 137594897182720 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:22.745243 137594897182720 logging_logger.py:49] Constructing tf.data.Dataset mnist for split train[0:60000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:22.970062 137594897182720 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:23.065295 137594897182720 logging_logger.py:49] Constructing tf.data.Dataset mnist for split test[0:10000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:23.162284 137594897182720 train.py:221] Steps per epech : 3750, Step per eval : 625, Num steps : 20\n",
"I0804 12:45:25.057252 137594897182720 train.py:159] Batch size: 16, Gradient accumulation steps : 2\n",
"I0804 12:45:25.384444 137594897182720 train.py:249] Initial compilation, this might take some minutes...\n",
"I0804 12:45:27.330970 137594897182720 train.py:257] Initial compilation completed.\n",
"I0804 12:45:27.331912 137594897182720 train.py:273] train steps: 1, loss: 2.3106, accuracy: 6.25\n",
"I0804 12:45:27.387209 137594897182720 train.py:273] train steps: 2, loss: 2.3070, accuracy: 6.25\n",
"I0804 12:45:27.436834 137594897182720 train.py:273] train steps: 3, loss: 2.2221, accuracy: 31.25\n",
"I0804 12:45:27.489237 137594897182720 train.py:273] train steps: 4, loss: 2.2554, accuracy: 18.75\n",
"I0804 12:45:27.540129 137594897182720 train.py:273] train steps: 5, loss: 2.1352, accuracy: 43.75\n",
"I0804 12:45:27.602766 137594897182720 train.py:273] train steps: 6, loss: 2.0772, accuracy: 31.25\n",
"I0804 12:45:27.653262 137594897182720 train.py:273] train steps: 7, loss: 2.0538, accuracy: 31.25\n",
"I0804 12:45:27.703630 137594897182720 train.py:273] train steps: 8, loss: 2.3428, accuracy: 12.50\n",
"I0804 12:45:27.746881 137594897182720 train.py:273] train steps: 9, loss: 1.9367, accuracy: 56.25\n",
"I0804 12:45:27.801099 137594897182720 train.py:273] train steps: 10, loss: 1.9887, accuracy: 18.75\n",
"I0804 12:45:27.850675 137594897182720 train.py:273] train steps: 11, loss: 1.7135, accuracy: 75.00\n",
"I0804 12:45:27.902988 137594897182720 train.py:273] train steps: 12, loss: 1.8164, accuracy: 43.75\n",
"I0804 12:45:27.953178 137594897182720 train.py:273] train steps: 13, loss: 1.8727, accuracy: 31.25\n",
"I0804 12:45:28.002290 137594897182720 train.py:273] train steps: 14, loss: 1.8241, accuracy: 43.75\n",
"I0804 12:45:28.048745 137594897182720 train.py:273] train steps: 15, loss: 2.5509, accuracy: 18.75\n",
"I0804 12:45:28.103076 137594897182720 train.py:273] train steps: 16, loss: 2.4105, accuracy: 6.25\n",
"I0804 12:45:28.150916 137594897182720 train.py:273] train steps: 17, loss: 2.2726, accuracy: 31.25\n",
"I0804 12:45:28.203733 137594897182720 train.py:273] train steps: 18, loss: 2.4638, accuracy: 37.50\n",
"I0804 12:45:28.252958 137594897182720 train.py:273] train steps: 19, loss: 1.9791, accuracy: 50.00\n",
"I0804 12:45:28.291021 137594897182720 train.py:273] train steps: 20, loss: 1.8259, accuracy: 68.75\n",
"2024-08-04 12:45:28.451456: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
"2024-08-04 12:45:28.451668: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"For comparison, the same batch size without gradient accumulation:\n",
"- Batch size : 16\n",
"- Gradient accumulation steps : 1"
],
"metadata": {
"id": "uTrgDBGxaPP8"
}
},
{
"cell_type": "code",
"source": [
"!python main.py \\\n",
" --config config/default.py \\\n",
" --config.batch_size=16 \\\n",
" --config.gradient_accumulation_steps=1 \\\n",
" --config.num_train_steps=20 \\\n",
" --config.log_every_steps=1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AUv9AsSVaM7j",
"outputId": "7cca25b1-8dbb-4c25-c351-f9196869dfdf"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-08-04 12:45:31.804240: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-04 12:45:31.847392: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-04 12:45:31.860748: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-08-04 12:45:33.957826: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"I0804 12:45:36.968580 138849378660352 xla_bridge.py:863] Unable to initialize backend 'cuda': jaxlib/cuda/versions_helpers.cc:98: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.\n",
"I0804 12:45:36.969110 138849378660352 xla_bridge.py:863] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: \"rocm\". Available platform names are: CUDA\n",
"I0804 12:45:36.970186 138849378660352 xla_bridge.py:863] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
"I0804 12:45:36.971597 138849378660352 main.py:32] JAX process: 0 / 1\n",
"I0804 12:45:36.971764 138849378660352 main.py:33] JAX local devices: [CpuDevice(id=0)]\n",
"I0804 12:45:36.972244 138849378660352 local.py:45] Setting task status: process_index: 0, process_count: 1\n",
"I0804 12:45:37.590896 138849378660352 dataset_info.py:617] Load dataset info from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:37.594401 138849378660352 dataset_info.py:709] For 'mnist/3.0.1': fields info.[citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.\n",
"I0804 12:45:37.594839 138849378660352 dataset_builder.py:579] Reusing dataset mnist (/workdir/tensorflow_datasets/mnist/3.0.1)\n",
"I0804 12:45:37.596096 138849378660352 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:37.714124 138849378660352 logging_logger.py:49] Constructing tf.data.Dataset mnist for split train[0:60000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:37.841467 138849378660352 reader.py:261] Creating a tf.data.Dataset reading 1 files located in folders: /workdir/tensorflow_datasets/mnist/3.0.1.\n",
"I0804 12:45:37.917054 138849378660352 logging_logger.py:49] Constructing tf.data.Dataset mnist for split test[0:10000], from /workdir/tensorflow_datasets/mnist/3.0.1\n",
"I0804 12:45:37.996205 138849378660352 train.py:221] Steps per epech : 3750, Step per eval : 625, Num steps : 20\n",
"I0804 12:45:39.822134 138849378660352 train.py:159] Batch size: 16, Gradient accumulation steps : 1\n",
"I0804 12:45:40.125809 138849378660352 train.py:249] Initial compilation, this might take some minutes...\n",
"I0804 12:45:41.664546 138849378660352 train.py:257] Initial compilation completed.\n",
"I0804 12:45:41.665370 138849378660352 train.py:273] train steps: 1, loss: 2.3106, accuracy: 6.25\n",
"I0804 12:45:41.702202 138849378660352 train.py:273] train steps: 2, loss: 2.2911, accuracy: 18.75\n",
"I0804 12:45:41.739797 138849378660352 train.py:273] train steps: 3, loss: 2.1582, accuracy: 18.75\n",
"I0804 12:45:41.774683 138849378660352 train.py:273] train steps: 4, loss: 2.2414, accuracy: 31.25\n",
"I0804 12:45:41.811264 138849378660352 train.py:273] train steps: 5, loss: 2.0050, accuracy: 25.00\n",
"I0804 12:45:41.846961 138849378660352 train.py:273] train steps: 6, loss: 1.9098, accuracy: 37.50\n",
"I0804 12:45:41.887347 138849378660352 train.py:273] train steps: 7, loss: 1.8347, accuracy: 37.50\n",
"I0804 12:45:41.926635 138849378660352 train.py:273] train steps: 8, loss: 1.9739, accuracy: 18.75\n",
"I0804 12:45:41.962526 138849378660352 train.py:273] train steps: 9, loss: 1.5706, accuracy: 56.25\n",
"I0804 12:45:41.998437 138849378660352 train.py:273] train steps: 10, loss: 1.1444, accuracy: 62.50\n",
"I0804 12:45:42.032355 138849378660352 train.py:273] train steps: 11, loss: 0.7948, accuracy: 81.25\n",
"I0804 12:45:42.068146 138849378660352 train.py:273] train steps: 12, loss: 2.4738, accuracy: 50.00\n",
"I0804 12:45:42.109282 138849378660352 train.py:273] train steps: 13, loss: 8.3565, accuracy: 12.50\n",
"I0804 12:45:42.144370 138849378660352 train.py:273] train steps: 14, loss: 2.3201, accuracy: 12.50\n",
"I0804 12:45:42.181330 138849378660352 train.py:273] train steps: 15, loss: 2.3386, accuracy: 6.25\n",
"I0804 12:45:42.217348 138849378660352 train.py:273] train steps: 16, loss: 2.2562, accuracy: 18.75\n",
"I0804 12:45:42.255419 138849378660352 train.py:273] train steps: 17, loss: 2.3640, accuracy: 12.50\n",
"I0804 12:45:42.288444 138849378660352 train.py:273] train steps: 18, loss: 2.2612, accuracy: 12.50\n",
"I0804 12:45:42.328277 138849378660352 train.py:273] train steps: 19, loss: 2.1875, accuracy: 43.75\n",
"I0804 12:45:42.362622 138849378660352 train.py:273] train steps: 20, loss: 2.3880, accuracy: 18.75\n",
"2024-08-04 12:45:42.492202: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
"2024-08-04 12:45:42.492362: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
}
]
}
]
}