Skip to content

Instantly share code, notes, and snippets.

@rsimd
Last active July 20, 2023 14:13
Show Gist options
  • Save rsimd/515e610fccc064c0ff0b99dfd9e134f9 to your computer and use it in GitHub Desktop.
Save rsimd/515e610fccc064c0ff0b99dfd9e134f9 to your computer and use it in GitHub Desktop.
This pyproject.toml was written for rye to install the CUDA version of jax. The CUDA and cuDNN versions are hard-coded, so you need to rewrite the "jax" and "jaxlib" lines in the dependencies list to match your environment.
[project]
name = "topics"
version = "0.1.0"
description = "Add a short description here"
readme = "README.md"
# The jaxlib wheel pinned below is tagged cp311, so it installs only on
# CPython 3.11. Keep this bound in sync with the wheel's cpXY tag.
requires-python = ">=3.11,<3.12"
dependencies = [
    "numpy>=1.24.4",
    "scipy>=1.10.1",
    "pandas>=2.0.3",
    "matplotlib>=3.7.2",
    "plotly>=5.15.0",
    "tqdm>=4.65.0",
    "seaborn>=0.12.2",
    "ipykernel>=6.24.0",
    "scikit-learn>=1.3.0",
    "opencv-python>=4.8.0.74",
    "Pillow>=10.0.0",
    # CUDA build of jaxlib, referenced directly by wheel URL. The CUDA major
    # version (cuda12), cuDNN version (cudnn89), Python tag (cp311) and
    # platform (manylinux2014_x86_64) are all hard-coded in the file name;
    # edit the URL to match your environment.
    "jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.13+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl",
    # jax is pinned to the same release as the jaxlib wheel above.
    "jax==0.4.13",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true

[tool.hatch.metadata]
# Required because the jaxlib dependency is a direct URL reference.
allow-direct-references = true
@rsimd
Copy link
Author

rsimd commented Jul 20, 2023

jaxlibをwhlから直接追加するだけでもいけそうだが,cudnnなどのバージョンが合わずにerrorが出る場合がある.よってこれらもインストールした方が良い.

@rsimd
Copy link
Author

rsimd commented Jul 20, 2023

[project]
name = "topics"
version = "0.1.0"
description = "Add a short description here"
readme = "README.md"
# The jaxlib wheel pinned below is tagged cp311, so it installs only on
# CPython 3.11. Keep this bound in sync with the wheel's cpXY tag.
requires-python = ">=3.11,<3.12"
dependencies = [
    "numpy>=1.24.4",
    "scipy>=1.10.1",
    "pandas>=2.0.3",
    "matplotlib>=3.7.2",
    "plotly>=5.15.0",
    "tqdm>=4.65.0",
    "seaborn>=0.12.2",
    "ipykernel>=6.24.0",
    "scikit-learn>=1.3.0",
    "opencv-python>=4.8.0.74",
    "Pillow>=10.0.0",
    # NVIDIA runtime libraries installed from PyPI so that the CUDA/cuDNN
    # versions seen by jaxlib do not depend on what is installed system-wide.
    # All of these are CUDA 12 ("-cu12") builds.
    "nvidia-cublas-cu12>=12.2.1.16",
    "nvidia-cuda-cupti-cu12>=12.2.60",
    "nvidia-cuda-nvcc-cu12>=12.2.91",
    "nvidia-cuda-runtime-cu12>=12.2.53",
    "nvidia-cudnn-cu12>=8.9.2.26",
    "nvidia-cufft-cu12>=11.0.8.15",
    "nvidia-cusolver-cu12>=11.5.0.53",
    "nvidia-cusparse-cu12>=12.1.1.53",
    # CUDA build of jaxlib, referenced directly by wheel URL. It must target
    # the same CUDA major version and cuDNN series as the nvidia-*-cu12
    # packages above (CUDA 12, cuDNN 8.9); a cuda11/cudnn86 wheel would not
    # match them. The Python tag (cp311) and platform are also hard-coded in
    # the file name; edit the URL to match your environment.
    "jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.13+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl",
    # jax is pinned to the same release as the jaxlib wheel above.
    "jax==0.4.13",
    "dm-haiku>=0.0.10",
    "typing-extensions>=4.7.1",
    "jaxtyping>=0.2.19",
    "optax>=0.1.5",
    "jmp>=0.0.4",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = [
    "jupyterlab>=4.0.3",
]

[tool.hatch.metadata]
# Required because the jaxlib dependency is a direct URL reference.
allow-direct-references = true

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment