Created December 31, 2020 08:06
-
-
Save genzj/34b2d9813813c145239ec1d16b651e99 to your computer and use it in GitHub Desktop.
Migrate DAIN (https://github.com/baowenbo/DAIN) to PyTorch 1.7.1 with CUDA 10.2.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/PWCNet/correlation_package_pytorch1_0/correlation.py b/PWCNet/correlation_package_pytorch1_0/correlation.py | |
index 80a8b09..fe8ab06 100644 | |
--- a/PWCNet/correlation_package_pytorch1_0/correlation.py | |
+++ b/PWCNet/correlation_package_pytorch1_0/correlation.py | |
@@ -4,19 +4,16 @@ from torch.autograd import Function | |
import correlation_cuda | |
class CorrelationFunction(Function): | |
- | |
- def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): | |
- super(CorrelationFunction, self).__init__() | |
- self.pad_size = pad_size | |
- self.kernel_size = kernel_size | |
- self.max_displacement = max_displacement | |
- self.stride1 = stride1 | |
- self.stride2 = stride2 | |
- self.corr_multiply = corr_multiply | |
- # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) | |
- | |
- def forward(self, input1, input2): | |
- self.save_for_backward(input1, input2) | |
+ @staticmethod | |
+ def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): | |
+ ctx.pad_size = pad_size | |
+ ctx.kernel_size = kernel_size | |
+ ctx.max_displacement = max_displacement | |
+ ctx.stride1 = stride1 | |
+ ctx.stride2 = stride2 | |
+ ctx.corr_multiply = corr_multiply | |
+ # ctx.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) | |
+ ctx.save_for_backward(input1, input2) | |
with torch.cuda.device_of(input1): | |
rbot1 = input1.new() | |
@@ -24,12 +21,13 @@ class CorrelationFunction(Function): | |
output = input1.new() | |
correlation_cuda.forward(input1, input2, rbot1, rbot2, output, | |
- self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) | |
+ ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply) | |
return output | |
- def backward(self, grad_output): | |
- input1, input2 = self.saved_tensors | |
+ @staticmethod | |
+ def backward(ctx, grad_output): | |
+ input1, input2 = ctx.saved_tensors | |
with torch.cuda.device_of(input1): | |
rbot1 = input1.new() | |
@@ -39,7 +37,7 @@ class CorrelationFunction(Function): | |
grad_input2 = input2.new() | |
correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, | |
- self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) | |
+ ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply) | |
return grad_input1, grad_input2 | |
@@ -56,7 +54,7 @@ class Correlation(Module): | |
def forward(self, input1, input2): | |
- result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) | |
+ result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) | |
return result | |
diff --git a/demo_MiddleBury.py b/demo_MiddleBury.py | |
index 82a6b71..21f7d63 100644 | |
--- a/demo_MiddleBury.py | |
+++ b/demo_MiddleBury.py | |
@@ -10,7 +10,7 @@ import numpy | |
import networks | |
from my_args import args | |
-from scipy.misc import imread, imsave | |
+from skimage.io import imread, imsave | |
from AverageMeter import * | |
torch.backends.cudnn.benchmark = True # to speed up the | |
diff --git a/demo_MiddleBury_slowmotion.py b/demo_MiddleBury_slowmotion.py | |
index 2bb4293..1b01bac 100644 | |
--- a/demo_MiddleBury_slowmotion.py | |
+++ b/demo_MiddleBury_slowmotion.py | |
@@ -7,7 +7,7 @@ import numpy as np | |
import numpy | |
import networks | |
from my_args import args | |
-from scipy.misc import imread, imsave | |
+from skimage.io import imread, imsave | |
from AverageMeter import * | |
import shutil | |
@@ -183,4 +183,4 @@ if DO_MiddleBurryOther: | |
count = count + 1 | |
- | |
\ No newline at end of file | |
+ | |
diff --git a/my_package/compiler_args.py b/my_package/compiler_args.py | |
index 7451c0c..de8c256 100644 | |
--- a/my_package/compiler_args.py | |
+++ b/my_package/compiler_args.py | |
@@ -4,35 +4,35 @@ nvcc_args = [ | |
# Quadro: (None) | |
# NVIDIA NVS: (None) | |
# Jetson: (None) | |
- '-gencode', 'arch=compute_37,code=sm_37', | |
+ # '-gencode', 'arch=compute_37,code=sm_37', | |
# Tesla: (None) | |
# Quadro: K1200, K620, M1200, M520, M5000M, M4000M, M3000M, M2000M, M1000M, K620M, M600M, M500M | |
# NVIDIA NVS: 810 | |
# GeForce / Titan: GTX 750 Ti, GTX 750, GTX 960M, GTX 950M, 940M, 930M, GTX 860M, GTX 850M, 840M, 830M | |
# Jetson: (None) | |
- '-gencode', 'arch=compute_50,code=sm_50', | |
+ # '-gencode', 'arch=compute_50,code=sm_50', | |
# Tesla: M60, M40 | |
# Quadro: M6000 24GB, M6000, M5000, M4000, M2000, M5500M, M2200, M620 | |
# NVIDIA NVS: (None) | |
# GeForce / Titan: GTX TITAN X, GTX 980 Ti, GTX 980, GTX 970, GTX 960, GTX 950, GTX 980, GTX 980M, GTX 970M, GTX 965M, 910M | |
# Jetson: (None) | |
- '-gencode', 'arch=compute_52,code=sm_52', | |
+ # '-gencode', 'arch=compute_52,code=sm_52', | |
# Tesla: P100 | |
# Quadro: GP100 | |
# NVIDIA: NVS: (None) | |
# GeForce / Titan: (None) | |
# Jetson: (None) | |
- '-gencode', 'arch=compute_60,code=sm_60', | |
+ # '-gencode', 'arch=compute_60,code=sm_60', | |
# Tesla: P40, P4 | |
# Quadro: P6000, P5000, P4000, P2200, P2000, P1000, P620, P600, P400, P620, P520, P5200, P4200, P3200, P5000, P4000, P3000, P2000, P1000, P600, P500 | |
# NVIDIA NVS: (None) | |
# GeForce / Titan: TITAN Xp, TITAN X, GTX 1080 Ti, GTX 1080, GTX 1070, GTX 1060, GTX 1050, GTX 1080, GTX 1070, GTX 1060 | |
# Jetson: (None) | |
- '-gencode', 'arch=compute_61,code=sm_61', | |
+ # '-gencode', 'arch=compute_61,code=sm_61', | |
# Tesla: T4 | |
# Quadro: RTX 8000, RTX 6000, RTX 5000, RTX 4000, RTX 5000, RTX 4000, RTX 3000, T2000, T1000 | |
@@ -47,4 +47,4 @@ nvcc_args = [ | |
'-w' # Ignore compiler warnings. | |
] | |
-cxx_args = ['-std=c++11', '-w'] | |
\ No newline at end of file | |
+cxx_args = ['-std=c++14', '-w'] | |
diff --git a/my_package/test_module.py b/my_package/test_module.py | |
index c1797ec..ca9ccda 100755 | |
--- a/my_package/test_module.py | |
+++ b/my_package/test_module.py | |
@@ -7,7 +7,7 @@ from torch.autograd import gradcheck | |
#from modules.InterpolationModule import InterpolationModule | |
#from modules.FilterInterpolationModule import FilterInterpolationModule | |
#from modules.FlowProjectionModule import FlowProjectionModule | |
-from my_package.DepthFlowProjection import DepthFlowProjectionModule | |
+from DepthFlowProjection import DepthFlowProjectionModule | |
#from modules.FilterInterpolationModule import AdaptiveWeightInterpolationModule | |
#from modules.SeparableConvModule import SeparableConvModule |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.