Skip to content

Instantly share code, notes, and snippets.

@jvmncs
Last active January 30, 2025 20:36
Show Gist options
  • Save jvmncs/f0f32dcbb38e7bccd5fb076f0ae840ee to your computer and use it in GitHub Desktop.
Save jvmncs/f0f32dcbb38e7bccd5fb076f0ae840ee to your computer and use it in GitHub Desktop.
uv-friendly Nix devShell for CUDA-enabled PyTorch/JAX
# Rename this file to flake.nix and place it in your project folder, next to a direnv .envrc file containing "use flake".
{
  description = "uv-friendly devShell for CUDA-enabled PyTorch/JAX";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  };

  outputs = { self, nixpkgs }:
    let
      # Only a single platform is wired up; adjust if needed.
      system = "x86_64-linux";

      pkgs = import nixpkgs {
        inherit system;
        # The CUDA toolkit and cuDNN are unfree; evaluation fails without this.
        config.allowUnfree = true;
      };

      cuda = pkgs.cudaPackages_11_8;

      # CUDA components exposed in the shell. This list is used both for the
      # shell's inputs and for LD_LIBRARY_PATH, so the two stay in sync.
      # Its order is the library search order among these packages.
      cudaPkgs = with cuda; [
        cudatoolkit
        cudnn
        cuda_nvrtc
        cuda_cupti
      ];

      # makeLibraryPath takes each package's `lib` output (falling back to
      # `out`) and appends /lib. Split-output packages such as cudnn then
      # resolve to the directory that actually holds their shared objects,
      # which is not guaranteed by interpolating "${pkg}/lib" directly.
      cudaLibPath = pkgs.lib.makeLibraryPath cudaPkgs;

      # NixOS exposes the host NVIDIA driver libraries (libcuda.so, ...) here.
      # They come first so the real driver wins over anything else on the
      # path. lib64 precedes lib, matching the original prepend order.
      driverLibPath = "/run/opengl-driver/lib64:/run/opengl-driver/lib";
    in {
      devShells.${system}.default = pkgs.mkShell {
        nativeBuildInputs = cudaPkgs;

        # The previous LD_LIBRARY_PATH is appended via ${VAR:+:$VAR}, written
        # as ''${...} to escape Nix interpolation. When the variable is unset
        # or empty, this avoids a trailing ':' (an empty entry), which the
        # dynamic loader would treat as the current working directory.
        shellHook = ''
          export CUDA_PATH=${cuda.cudatoolkit}
          export CUDA_HOME=${cuda.cudatoolkit}
          export XLA_FLAGS="--xla_gpu_cuda_data_dir=${cuda.cudatoolkit}"
          export LD_LIBRARY_PATH="${driverLibPath}:${cudaLibPath}''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
          echo "CUDA ${cuda.cudatoolkit.version} environment activated"
        '';
      };
    };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment