Created October 22, 2021 06:28
-
-
Save crackcomm/978374641b7fe678d05aeb808500250a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/config/discover.ml b/src/config/discover.ml | |
index 290a3d2..37b2ddc 100644 | |
--- a/src/config/discover.ml | |
+++ b/src/config/discover.ml | |
@@ -27,6 +27,7 @@ let torch_flags () = | |
[ Printf.sprintf "-Wl,-rpath,%s" lib_dir | |
; Printf.sprintf "-L%s" lib_dir | |
; "-lc10" | |
+ ; "-lc10_cuda" | |
; "-ltorch_cpu" | |
; "-ltorch" | |
] | |
@@ -71,15 +72,15 @@ let torch_flags () = | |
| None -> empty_flags)) | |
let libcuda_flags ~lcuda ~lnvrtc = | |
- let cudadir = "/usr/local/cuda/lib64" in | |
+ let cudadir = "/usr/local/cuda" in | |
if file_exists cudadir && Caml.Sys.is_directory cudadir | |
then ( | |
let libs = | |
- [ Printf.sprintf "-Wl,-rpath,%s" cudadir; Printf.sprintf "-L%s" cudadir ] | |
+ [ Printf.sprintf "-Wl,-rpath,%s/lib64" cudadir; Printf.sprintf "-L%s/lib64" cudadir ] | |
in | |
let libs = if lcuda then libs @ [ "-lcudart" ] else libs in | |
let libs = if lnvrtc then libs @ [ "-lnvrtc" ] else libs in | |
- { C.Pkg_config.cflags = []; libs }) | |
+ { C.Pkg_config.cflags = [ Printf.sprintf "-I%s/include" cudadir ]; libs }) | |
else empty_flags | |
let () = | |
diff --git a/src/stubs/torch_bindings.ml b/src/stubs/torch_bindings.ml | |
index 068b81b..13517f4 100644 | |
--- a/src/stubs/torch_bindings.ml | |
+++ b/src/stubs/torch_bindings.ml | |
@@ -195,6 +195,7 @@ module C (F : Cstubs.FOREIGN) = struct | |
let is_available = foreign "atc_cuda_is_available" (void @-> returning int) | |
let cudnn_is_available = foreign "atc_cudnn_is_available" (void @-> returning int) | |
let set_benchmark_cudnn = foreign "atc_set_benchmark_cudnn" (int @-> returning void) | |
+ let empty_cache = foreign "atc_cuda_empty_cache" (void @-> returning void) | |
end | |
module Ivalue = struct | |
diff --git a/src/torch/cuda.mli b/src/torch/cuda.mli | |
index 06adb7e..f88412f 100644 | |
--- a/src/torch/cuda.mli | |
+++ b/src/torch/cuda.mli | |
@@ -2,3 +2,4 @@ val device_count : unit -> int | |
val is_available : unit -> bool | |
val cudnn_is_available : unit -> bool | |
val set_benchmark_cudnn : bool -> unit | |
+val empty_cache : unit -> unit | |
diff --git a/src/wrapper/torch_api.cpp b/src/wrapper/torch_api.cpp | |
index 84f6ccc..2cdf21a 100644 | |
--- a/src/wrapper/torch_api.cpp | |
+++ b/src/wrapper/torch_api.cpp | |
@@ -1,6 +1,7 @@ | |
#include<torch/csrc/autograd/engine.h> | |
#include<torch/torch.h> | |
#include<ATen/autocast_mode.h> | |
+#include<c10/cuda/CUDACachingAllocator.h> | |
#include<torch/script.h> | |
#include<vector> | |
#include<caml/fail.h> | |
@@ -572,6 +573,10 @@ void atc_set_benchmark_cudnn(int b) { | |
at::globalContext().setBenchmarkCuDNN(b); | |
} | |
+void atc_cuda_empty_cache() { | |
+ c10::cuda::CUDACachingAllocator::emptyCache(); | |
+} | |
+ | |
module atm_load(char *filename) { | |
PROTECT( | |
return new torch::jit::script::Module(torch::jit::load(filename)); | |
diff --git a/src/wrapper/torch_api.h b/src/wrapper/torch_api.h | |
index 0b60ac0..2f51c43 100644 | |
--- a/src/wrapper/torch_api.h | |
+++ b/src/wrapper/torch_api.h | |
@@ -116,6 +116,7 @@ int atc_cuda_device_count(); | |
int atc_cuda_is_available(); | |
int atc_cudnn_is_available(); | |
void atc_set_benchmark_cudnn(int b); | |
+void atc_cuda_empty_cache(); | |
module atm_load(char *); | |
tensor atm_forward(module, tensor *tensors, int ntensors); | |
diff --git a/src/wrapper/wrapper.ml b/src/wrapper/wrapper.ml | |
index 5abba3f..f14e8a0 100644 | |
--- a/src/wrapper/wrapper.ml | |
+++ b/src/wrapper/wrapper.ml | |
@@ -313,6 +313,7 @@ module Cuda = struct | |
let is_available () = is_available () <> 0 | |
let cudnn_is_available () = cudnn_is_available () <> 0 | |
let set_benchmark_cudnn b = set_benchmark_cudnn (if b then 1 else 0) | |
+ let empty_cache = empty_cache | |
end | |
module Ivalue = struct | |
diff --git a/src/wrapper/wrapper.mli b/src/wrapper/wrapper.mli | |
index 9491f92..ff9090e 100644 | |
--- a/src/wrapper/wrapper.mli | |
+++ b/src/wrapper/wrapper.mli | |
@@ -102,6 +102,7 @@ module Cuda : sig | |
val is_available : unit -> bool | |
val cudnn_is_available : unit -> bool | |
val set_benchmark_cudnn : bool -> unit | |
+ val empty_cache : unit -> unit | |
end | |
module Ivalue : sig |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.