We had a previous installation of CUDA 10.2 / cuDNN 8.0.1, so I followed the instructions on TensorFlow's website, adapted for the fact that we have Ubuntu 20.04:
$ uname -a
Linux slugpu 5.8.0-41-generic #46~20.04.1-Ubuntu SMP Mon Jan 18 17:52:23 UTC 2021 x86_64 x86_64 x86_64 GNU/Linux
Apart from the instructions, I also had to remove our older CUDA installation and the nvcc binary:
sudo rm /usr/bin/nvcc
sudo rm -rf /usr/lib/cuda
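Both pointed at the 10.2 versions. After everything was done, I also noticed that nvidia-smi reported CUDA 11.5 while nvcc --version reported 11.2, which is OK. nvidia-smi shows the highest CUDA version the installed driver supports, whereas nvcc shows the version of the installed toolkit, and a newer driver can run an older toolkit.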
$ nvidia-smi
Tue Dec 14 12:22:54 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:26:00.0 Off |                  N/A |
| 41%   27C    P8     1W / 260W |     15MiB / 11016MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      6613      G   /usr/lib/xorg/Xorg                  9MiB |
|    0   N/A  N/A      6945      G   /usr/bin/gnome-shell                4MiB |
+-----------------------------------------------------------------------------+
and
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Feb_14_21:12:58_PST_2021
Cuda compilation tools, release 11.2, V11.2.152
Build cuda_11.2.r11.2/compiler.29618528_0
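The version comparison above can be automated with a small sketch. It assumes the two tools keep printing "CUDA Version: X.Y" and "release X.Y" as in the output shown here. The function names are my own and not part of any NVIDIA tooling.

```python
import re

def parse_driver_cuda(smi_output):
    """Return (major, minor) from nvidia-smi's 'CUDA Version: X.Y', or None."""
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    return (int(m.group(1)), int(m.group(2))) if m else None

def parse_toolkit_cuda(nvcc_output):
    """Return (major, minor) from nvcc's 'release X.Y', or None."""
    m = re.search(r"release\s+(\d+)\.(\d+)", nvcc_output)
    return (int(m.group(1)), int(m.group(2))) if m else None

def driver_supports_toolkit(driver, toolkit):
    """The driver's reported CUDA version should be >= the toolkit's."""
    return driver >= toolkit

if __name__ == "__main__":
    # Lines copied from the outputs above.
    smi_line = "| NVIDIA-SMI 495.29.05 Driver Version: 495.29.05 CUDA Version: 11.5 |"
    nvcc_line = "Cuda compilation tools, release 11.2, V11.2.152"
    driver = parse_driver_cuda(smi_line)
    toolkit = parse_toolkit_cuda(nvcc_line)
    print(driver, toolkit, driver_supports_toolkit(driver, toolkit))
```

On a live machine the same functions could be fed the text returned by running nvidia-smi and nvcc --version through subprocess.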
Following the instructions in the JAX README, I installed the GPU option
$ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
and then checked whether the GPU was being used
$ python
Python 3.8.5 (default, Jul 28 2020, 12:59:40)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
gpu
so all good to go.
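The same check can also be written as a small script using JAX's public top-level API instead of jax.lib.xla_bridge. This is a sketch, assuming a JAX version that provides jax.default_backend(). The helper name jax_platform is hypothetical, and it returns None when JAX is not installed.

```python
def jax_platform():
    """Return JAX's default backend name (e.g. 'cpu' or 'gpu'), or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()

if __name__ == "__main__":
    platform = jax_platform()
    if platform is None:
        print("JAX is not installed")
    else:
        print("JAX default backend:", platform)
```

If this prints cpu on a machine with a working GPU, the usual suspect is a CPU-only jaxlib wheel having been installed instead of the CUDA one.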