- Ubuntu 20.04.3 LTS
- 60GB SSD Root Drive
Note: the exact machine specs shouldn't matter much, provided the machine has an NVIDIA GPU (T4, Volta, or newer).
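The GPU requirement can be checked programmatically. The following is a minimal sketch, assuming `nvidia-smi` is on the PATH and the driver is recent enough to support the `compute_cap` query field. The helper names and the 7.0 threshold (Volta; the T4 is a Turing card reporting 7.5) are illustrative choices, not part of the original setup.

```python
import shutil
import subprocess

MIN_COMPUTE_CAP = (7, 0)  # assumption: Volta (7.0) as the floor; a T4 reports 7.5


def parse_compute_cap(text):
    """Turn a string such as '7.5' into a comparable tuple (7, 5)."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)


def is_supported(text, minimum=MIN_COMPUTE_CAP):
    """True if the reported compute capability meets the minimum."""
    return parse_compute_cap(text) >= minimum


def query_compute_caps():
    """Ask nvidia-smi for each GPU's compute capability (empty list if unavailable)."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []  # e.g. an older driver that does not know the compute_cap field
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


if __name__ == "__main__":
    caps = query_compute_caps()
    if not caps:
        print("Could not query nvidia-smi; check the GPU manually.")
    for cap in caps:
        try:
            verdict = "OK" if is_supported(cap) else "too old"
        except ValueError:
            verdict = "unrecognised"
        print(f"compute capability {cap}: {verdict}")
```

If the query fails, running plain `nvidia-smi` and looking up the card model by hand works just as well.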
This updates the system and ensures we have GCC:
sudo apt update && sudo apt upgrade
sudo apt install build-essential
Note: PyTorch bundles its own CUDA/cuDNN, but a manual install is required to support JAX.
Note: it should be sufficient to grab the deb (local or network) installer and follow the corresponding steps on the download page. The full instructions list more detailed system requirements, in case of issues.
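Once the installer has finished, it can be useful to confirm which toolkit release ended up on the machine. The sketch below parses the output of `nvcc --version`. It assumes the toolkit lives under `/usr/local/cuda` (the same prefix the cuDNN steps use); the helper names are hypothetical.

```python
import re
import shutil
import subprocess

CUDA_BIN = "/usr/local/cuda/bin"  # assumption: default prefix used by the deb installer


def parse_nvcc_release(output):
    """Return (major, minor) from `nvcc --version` output, or None if not found."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def find_nvcc():
    """Look for nvcc on the PATH first, then in the default CUDA bin directory."""
    return shutil.which("nvcc") or shutil.which("nvcc", path=CUDA_BIN)


if __name__ == "__main__":
    nvcc = find_nvcc()
    if nvcc is None:
        print("nvcc not found; the CUDA toolkit may be missing or not on the PATH.")
    else:
        result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
        print("CUDA toolkit release:", parse_nvcc_release(result.stdout))
```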
This step requires you to sign up for a free NVIDIA developer account. You must be signed in to initiate the download, at which point you can copy the secure download URL. These steps use the x86_64 Linux tar archive:
wget -O cudnn.tar.xz "<Download Link Here>"
tar -xvf cudnn.tar.xz
Move the extracted files into the CUDA location (as per the tar file install instructions):
sudo cp cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include
sudo cp -P cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64
sudo chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn*
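To confirm the copy worked, the cuDNN version can be read back from the installed header. This is a rough sketch, assuming a cuDNN 8 or newer archive, where the version macros live in `cudnn_version.h`; the function name is made up for illustration.

```python
import re
from pathlib import Path

# Where the cp step above places the header (assumes cuDNN 8+, which ships cudnn_version.h).
HEADER = Path("/usr/local/cuda/include/cudnn_version.h")


def parse_cudnn_version(header_text):
    """Extract (major, minor, patch) from the #define lines, or None if any is absent."""
    found = {}
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        pattern = r"^\s*#define\s+" + name + r"\s+(\d+)"
        match = re.search(pattern, header_text, re.MULTILINE)
        if match is None:
            return None
        found[name] = int(match.group(1))
    return found["CUDNN_MAJOR"], found["CUDNN_MINOR"], found["CUDNN_PATCHLEVEL"]


if __name__ == "__main__":
    if HEADER.exists():
        print("cuDNN version:", parse_cudnn_version(HEADER.read_text()))
    else:
        print("No header at", HEADER, "- was cuDNN copied into the CUDA directory?")
```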
Grab the latest Miniconda installer:
wget https://repo.anaconda.com/miniconda/Miniconda3-py39_4.10.3-Linux-x86_64.sh
chmod +x Miniconda3-py39_4.10.3-Linux-x86_64.sh
./Miniconda3-py39_4.10.3-Linux-x86_64.sh
Refresh the shell for the changes to take effect:
source ~/.bashrc
Certain LMDB datasets for singing synthesis were serialised with an old version of PyArrow. This forces us to use Python 3.7 until these datasets are recreated. Likewise, we need to pin numpy to 1.16.1 in this environment to ensure backward compatibility.
conda create --name py37 python=3.7
conda activate py37
Useful pip dependencies:
pip install --upgrade pip
pip install black pytest poetry wheel numpy==1.16.1 cython
Jupyter Lab:
conda install -c conda-forge jupyterlab
PyTorch 1.7.1 (note: FFT operations have breaking changes in 1.8):
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
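Because this environment depends on exact pins, it may be worth verifying them after installation. The sketch below compares the installed versions against the pins from the commands above. It reads `__version__` directly so that it also runs on Python 3.7; the `EXPECTED` table and helper names are illustrative assumptions.

```python
import importlib

# Pins taken from the install commands above.
EXPECTED = {"numpy": "1.16.1", "torch": "1.7.1"}


def installed_version(module_name):
    """Return the module's __version__ without any local suffix, or None if missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    version = getattr(module, "__version__", None)
    # Some builds report e.g. "1.7.1+cu110"; drop the part after "+".
    return None if version is None else str(version).split("+")[0]


def check_pins(expected, lookup=installed_version):
    """Map each package name to a tuple (expected, found, matches)."""
    report = {}
    for name, wanted in expected.items():
        found = lookup(name)
        report[name] = (wanted, found, found == wanted)
    return report


if __name__ == "__main__":
    for name, (wanted, found, ok) in check_pins(EXPECTED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: expected {wanted}, found {found} -> {status}")
```

Run it inside the activated `py37` environment; a `MISMATCH` on numpy usually means a later install pulled in a newer version.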
We would like to migrate to this environment gradually. It uses the latest stable versions of everything, and also includes JAX with CUDA support, for experimentation.
conda deactivate
conda create --name py39 python=3.9
conda activate py39
Again, upgrade pip and install some useful starter dependencies:
pip install --upgrade pip
pip install black pytest poetry wheel numpy cython
conda install -c conda-forge jupyterlab
PyTorch stable (>= 1.10; note that it bundles its own CUDA):
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
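A quick way to check that this PyTorch build can see the GPU is sketched below. The function name is hypothetical, and the import is guarded so that the snippet reports a missing install instead of crashing.

```python
def describe_torch_cuda():
    """Summarise what PyTorch reports about CUDA; safe to call when torch is absent."""
    try:
        import torch
    except ImportError:
        return {"installed": False}
    info = {
        "installed": True,
        "torch": str(torch.__version__),
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device_name"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in describe_torch_cuda().items():
        print(f"{key}: {value}")
```

If `built_with_cuda` is set but `cuda_available` is false, the driver is the most likely culprit, since the CUDA runtime itself comes bundled with the package.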
Finally, install JAX with CUDA support:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
A quick test to see that JAX can communicate with the GPU. In a Python shell:
import jax
key = jax.random.PRNGKey(42)
print(jax.devices()) # should list a GPU device, not just CPU
_ = jax.random.split(key) # should just pass silently
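For a slightly stronger check, the sketch below runs a small jit-compiled matrix product and reports which backend executed it. The function name and the matrix size are arbitrary choices for illustration; on a working install the backend is expected to be the GPU one.

```python
def jax_smoke_test(size=256):
    """Run a small jitted matmul and report which backend executed it."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None  # JAX is not installed in this environment

    @jax.jit
    def matmul_sum(a, b):
        return jnp.sum(a @ b)

    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (size, size))
    b = jax.random.normal(key_b, (size, size))
    value = float(matmul_sum(a, b))  # float() waits for the computation to finish
    return {
        "backend": jax.default_backend(),
        "devices": [str(device) for device in jax.devices()],
        "value": value,
    }


if __name__ == "__main__":
    print(jax_smoke_test())
```

If the reported backend is `cpu`, JAX has fallen back to the CPU build, which usually points at a CUDA or cuDNN version mismatch.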