Before following the steps below, check the status of issue jax-ml/jax#5795. I think that if it has been closed, then
pip install jaxlib
should work on Windows without this workaround.
This workaround is based on cloudhan/jax-windows-builder
Pre-install note: if you have a prior installation of jax, remove it first.
- Check if jax is installed using
pip show jax
- Uninstall using
pip uninstall jax
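The same check can be done from Python, which is handy when you are unsure which environment is active. This is a minimal sketch, assuming Python 3.8 or newer (for `importlib.metadata`); the helper name `installed_version` is hypothetical and not part of jax or pip.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the state of both packages in the currently active environment.
for name in ("jax", "jaxlib"):
    print(name, installed_version(name) or "not installed")
```

If either line prints a version number, uninstall that package before continuing.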
Steps for Installation
- Download a jaxlib wheel from https://whls.blob.core.windows.net/unstable/index.html. Choose the wheel that matches your CUDA version; if you don't have CUDA, choose a CPU wheel.
- Open PowerShell or cmd in the folder that contains the wheel file and install it with pip install <jaxlib_whl>. Copy the file to a different location first if you plan to install it for a specific environment. The cp37 part of jaxlib-0.3.2-cp37-none-win_amd64.whl is the Python version tag (here CPython 3.7), so pick the wheel that matches the Python version in your virtualenv. Then install it, for example:
pip install jaxlib-0.3.2-cp37-none-win_amd64.whl
- Once that is done, install jax:
pip install jax
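A mismatched Python tag is an easy mistake to make when picking a wheel, so it can help to compare the cp37-style part of the filename against the interpreter in your environment before running pip. The sketch below assumes the standard wheel filename layout (name, version, optional build tag, Python tag, ABI tag, platform tag, separated by hyphens); the function names are hypothetical helpers, not part of pip or jax.

```python
import sys


def interpreter_tag():
    """Return the CPython tag of the running interpreter, e.g. 'cp37' for Python 3.7."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheel_python_tag(wheel_filename):
    """Extract the Python tag from a wheel filename.

    The last three hyphen-separated fields are the Python tag, ABI tag,
    and platform tag, so the Python tag is third from the end.
    """
    stem = wheel_filename
    if stem.endswith(".whl"):
        stem = stem[: -len(".whl")]
    return stem.split("-")[-3]


def wheel_matches_interpreter(wheel_filename):
    """True if the wheel's Python tag equals the running interpreter's tag."""
    return wheel_python_tag(wheel_filename) == interpreter_tag()


wheel = "jaxlib-0.3.2-cp37-none-win_amd64.whl"
print("wheel tag:", wheel_python_tag(wheel))
print("interpreter tag:", interpreter_tag())
print("compatible:", wheel_matches_interpreter(wheel))
```

Run it with the Python from the virtualenv you intend to install into; if the two tags differ, download the wheel built for your Python version instead.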
Hi @Dewwww, unfortunately I never installed it with CUDA, but this worked perfectly when I used it earlier.