Skip to content

Instantly share code, notes, and snippets.

@crackcomm
Created October 22, 2021 06:28
Show Gist options
  • Save crackcomm/978374641b7fe678d05aeb808500250a to your computer and use it in GitHub Desktop.
diff --git a/src/config/discover.ml b/src/config/discover.ml
index 290a3d2..37b2ddc 100644
--- a/src/config/discover.ml
+++ b/src/config/discover.ml
@@ -27,6 +27,7 @@ let torch_flags () =
[ Printf.sprintf "-Wl,-rpath,%s" lib_dir
; Printf.sprintf "-L%s" lib_dir
; "-lc10"
+ ; "-lc10_cuda"
; "-ltorch_cpu"
; "-ltorch"
]
@@ -71,15 +72,15 @@ let torch_flags () =
| None -> empty_flags))
let libcuda_flags ~lcuda ~lnvrtc =
- let cudadir = "/usr/local/cuda/lib64" in
+ let cudadir = "/usr/local/cuda" in
if file_exists cudadir && Caml.Sys.is_directory cudadir
then (
let libs =
- [ Printf.sprintf "-Wl,-rpath,%s" cudadir; Printf.sprintf "-L%s" cudadir ]
+ [ Printf.sprintf "-Wl,-rpath,%s/lib64" cudadir; Printf.sprintf "-L%s/lib64" cudadir ]
in
let libs = if lcuda then libs @ [ "-lcudart" ] else libs in
let libs = if lnvrtc then libs @ [ "-lnvrtc" ] else libs in
- { C.Pkg_config.cflags = []; libs })
+ { C.Pkg_config.cflags = [ Printf.sprintf "-I%s/include" cudadir ]; libs })
else empty_flags
let () =
diff --git a/src/stubs/torch_bindings.ml b/src/stubs/torch_bindings.ml
index 068b81b..13517f4 100644
--- a/src/stubs/torch_bindings.ml
+++ b/src/stubs/torch_bindings.ml
@@ -195,6 +195,7 @@ module C (F : Cstubs.FOREIGN) = struct
let is_available = foreign "atc_cuda_is_available" (void @-> returning int)
let cudnn_is_available = foreign "atc_cudnn_is_available" (void @-> returning int)
let set_benchmark_cudnn = foreign "atc_set_benchmark_cudnn" (int @-> returning void)
+ let empty_cache = foreign "atc_cuda_empty_cache" (void @-> returning void)
end
module Ivalue = struct
diff --git a/src/torch/cuda.mli b/src/torch/cuda.mli
index 06adb7e..f88412f 100644
--- a/src/torch/cuda.mli
+++ b/src/torch/cuda.mli
@@ -2,3 +2,4 @@ val device_count : unit -> int
val is_available : unit -> bool
val cudnn_is_available : unit -> bool
val set_benchmark_cudnn : bool -> unit
+val empty_cache : unit -> unit
diff --git a/src/wrapper/torch_api.cpp b/src/wrapper/torch_api.cpp
index 84f6ccc..2cdf21a 100644
--- a/src/wrapper/torch_api.cpp
+++ b/src/wrapper/torch_api.cpp
@@ -1,6 +1,7 @@
#include<torch/csrc/autograd/engine.h>
#include<torch/torch.h>
#include<ATen/autocast_mode.h>
+#include<c10/cuda/CUDACachingAllocator.h>
#include<torch/script.h>
#include<vector>
#include<caml/fail.h>
@@ -572,6 +573,10 @@ void atc_set_benchmark_cudnn(int b) {
at::globalContext().setBenchmarkCuDNN(b);
}
+void atc_cuda_empty_cache() {
+ c10::cuda::CUDACachingAllocator::emptyCache();
+}
+
module atm_load(char *filename) {
PROTECT(
return new torch::jit::script::Module(torch::jit::load(filename));
diff --git a/src/wrapper/torch_api.h b/src/wrapper/torch_api.h
index 0b60ac0..2f51c43 100644
--- a/src/wrapper/torch_api.h
+++ b/src/wrapper/torch_api.h
@@ -116,6 +116,7 @@ int atc_cuda_device_count();
int atc_cuda_is_available();
int atc_cudnn_is_available();
void atc_set_benchmark_cudnn(int b);
+void atc_cuda_empty_cache();
module atm_load(char *);
tensor atm_forward(module, tensor *tensors, int ntensors);
diff --git a/src/wrapper/wrapper.ml b/src/wrapper/wrapper.ml
index 5abba3f..f14e8a0 100644
--- a/src/wrapper/wrapper.ml
+++ b/src/wrapper/wrapper.ml
@@ -313,6 +313,7 @@ module Cuda = struct
let is_available () = is_available () <> 0
let cudnn_is_available () = cudnn_is_available () <> 0
let set_benchmark_cudnn b = set_benchmark_cudnn (if b then 1 else 0)
+ let empty_cache = empty_cache
end
module Ivalue = struct
diff --git a/src/wrapper/wrapper.mli b/src/wrapper/wrapper.mli
index 9491f92..ff9090e 100644
--- a/src/wrapper/wrapper.mli
+++ b/src/wrapper/wrapper.mli
@@ -102,6 +102,7 @@ module Cuda : sig
val is_available : unit -> bool
val cudnn_is_available : unit -> bool
val set_benchmark_cudnn : bool -> unit
+ val empty_cache : unit -> unit
end
module Ivalue : sig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment