Install notes from BDL2022f env install on 2022-11-28
- For CUDA TOOLKIT 11.3, which can be used on older devices but may not be optimal
- set up basic conda env without any torch or jax packages, via
conda env create -f bdl_2022f.yml
- install pytorch with specific cudatoolkit version
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
- install jax with same specific cudatoolkit version
conda install -c "nvidia/label/cuda-11.3.1" cuda-nvcc
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
The first line (the conda install of cuda-nvcc) is needed to be sure that the correct version of NVIDIA's cuda-nvcc tools is installed. Without that line, I got an error about missing "ptxas", which I solved via this helpful thread: jax-ml/jax#6843
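To confirm that the cuda-nvcc step took effect, a minimal sketch like the following can check whether nvcc and ptxas are reachable on PATH inside the activated env. The helper name check_tool is hypothetical (it is not part of conda, CUDA, or jax), and the sketch assumes the conda env from the steps above is already activated.

```shell
#!/usr/bin/env bash
# Sketch: report whether a command is reachable on PATH.
# check_tool is a hypothetical helper name used only for illustration.
check_tool() {
    if command -v "$1" >/dev/null 2>&1; then
        # Print where the tool was found
        echo "found: $1 -> $(command -v "$1")"
        return 0
    else
        echo "missing: $1"
        return 1
    fi
}

# Both tools should be found after the cuda-nvcc install step.
# "|| true" keeps the script going so both results are printed.
check_tool nvcc || true
check_tool ptxas || true
```

If ptxas is reported missing, the cuda-nvcc install probably did not land in the active env. Once both are found, running python -c "import torch; print(torch.cuda.is_available())" and python -c "import jax; print(jax.devices())" is a reasonable next check that each framework sees the GPU.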