Last active
January 30, 2025 20:36
-
-
Save jvmncs/f0f32dcbb38e7bccd5fb076f0ae840ee to your computer and use it in GitHub Desktop.
uv-friendly devShell for CUDA-enabled PyTorch/JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rename this file to flake.nix and put it in your project folder next to an
# .envrc file containing the line "use flake".
{
  description = "uv-friendly devShell for CUDA-enabled PyTorch/JAX";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  };

  outputs = { self, nixpkgs }:
    let
      system = "x86_64-linux"; # adjust if needed
      pkgs = import nixpkgs {
        inherit system;
        # The CUDA toolkit and cuDNN are unfree packages.
        config.allowUnfree = true;
      };
      inherit (pkgs) lib;

      # Pin the CUDA release here; adjust to match the CUDA version your
      # Python wheels were built against.
      cuda = pkgs.cudaPackages_11_8;

      # CUDA packages put on PATH etc. inside the shell.
      cudaPkgs = with cuda; [
        cudatoolkit
        cuda_nvrtc
        cuda_cupti
        cudnn
      ];

      # Colon-separated "<pkg>/lib" directories for LD_LIBRARY_PATH.
      # makeLibraryPath selects each package's split "lib" output when it has
      # one (falling back to "out"), so the entry points at the directory that
      # actually holds the shared objects rather than a possibly-empty "out".
      # Order matches the original: toolkit, cuDNN, NVRTC, CUPTI.
      cudaLibPath = lib.makeLibraryPath (with cuda; [
        cudatoolkit
        cudnn
        cuda_nvrtc
        cuda_cupti
      ]);
    in {
      devShells.${system}.default = pkgs.mkShell {
        # `packages` is mkShell's idiomatic alias that is merged into
        # nativeBuildInputs.
        packages = cudaPkgs;

        shellHook = ''
          export CUDA_PATH=${cuda.cudatoolkit}
          export CUDA_HOME=${cuda.cudatoolkit}
          export XLA_FLAGS="--xla_gpu_cuda_data_dir=${cuda.cudatoolkit}"

          # Host driver directories come first, then the Nix CUDA libraries,
          # then whatever was already set. The ''${VAR:+:$VAR} form (escaped
          # for a Nix indented string) appends the old value only when it is
          # non-empty. A bare trailing ":" would add an empty entry, which the
          # dynamic loader treats as the current working directory.
          export LD_LIBRARY_PATH=/run/opengl-driver/lib64:/run/opengl-driver/lib:${cudaLibPath}''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}

          echo "CUDA ${cuda.cudatoolkit.version} environment activated"
        '';
      };
    };
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment