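// NOTE: the includes, type aliases, Result struct, and CUTLASS_CHECK macro below
// are assumptions added for completeness -- the original gist defines them elsewhere.
// They follow the conventions of the CUTLASS Conv2dFprop examples
// (FP16 NHWC operands, FP32 accumulation, targeting Sm75 Tensor Cores).
#include <cstdio>
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/host/tensor_foreach.h"

using ElementInputA = cutlass::half_t;      // Activations (assumed FP16)
using ElementInputB = cutlass::half_t;      // Filters (assumed FP16)
using ElementOutput = cutlass::half_t;      // Output (assumed FP16)
using ElementAccumulator = float;           // Accumulate in FP32
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;

// Minimal stand-ins for the gist's Result / CUTLASS_CHECK helpers (assumed).
struct Result {
  cutlass::Status status = cutlass::Status::kSuccess;
};

#define CUTLASS_CHECK(status)                                           \
  {                                                                     \
    cutlass::Status error = (status);                                   \
    if (error != cutlass::Status::kSuccess) {                           \
      std::cerr << "CUTLASS error: " << cutlassGetStatusString(error)   \
                << std::endl;                                           \
      return;                                                           \
    }                                                                   \
  }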
void test_sobel() {
  // Configure the convolution kernel
  using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 128, 32>,              // Threadblock tile shape
    cutlass::gemm::GemmShape<64, 64, 32>,                // Warp tile shape
    cutlass::gemm::GemmShape<16, 8, 8>,                  // TensorCore instruction shape
    cutlass::epilogue::thread::LinearCombination<
      ElementOutput,                                     // Data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // Elements per vectorized memory access
      ElementAccumulator,                                // Data type of accumulator
      float>,                                            // Data type for alpha/beta in linear combination
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,                                                   // Number of pipeline stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kAnalytic
  >::Kernel;
  // Define the implicit GEMM with the previously defined convolution kernel
  using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;

  // Read the input image (raw 390x390 8-bit grayscale)
  std::vector<unsigned char> im_buildings;
  im_buildings.resize(390 * 390);
  std::ifstream fin("buildings_original.bin", std::ios::binary);
  if (fin.is_open()) {
    // Read at most 390*390 bytes so an oversized file cannot overrun the buffer
    fin.read(reinterpret_cast<char *>(im_buildings.data()), im_buildings.size());
    fin.close();
  }
  // Define tensor shapes: activations/output are NHWC (N, H, W, C), filters are (K, R, S, C)
  cutlass::Tensor4DCoord input_size(1, 390, 390, 32),
                         filter_size(32, 3, 3, 32),
                         output_size(1, 390, 390, 32);

  // Define the "tensors" that will be holding all the data
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(input_size);
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(filter_size);
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(output_size);
  // Load the input image into channel 0 of the activation tensor; all other channels stay zero
  cutlass::reference::host::TensorForEachLambda(
    tensor_a.extent(),
    [&](cutlass::Tensor4DCoord const &coord) {
      const auto &view = tensor_a.host_view();
      int pitch = input_size.w();
      auto im = im_buildings.data();
      if (coord.n() == 0 && coord.c() == 0) {
        const auto src_offset = coord.h() * pitch + coord.w();
        view.at(coord) = static_cast<cutlass::half_t>(im[src_offset]);
      }
      else {
        view.at(coord) = static_cast<cutlass::half_t>(0.f);
      }
    }
  );
  // Define and load the Sobel filter: output channel 0 gets the horizontal-gradient
  // kernel below, output channel 1 gets its transpose (vertical gradient), and the
  // remaining 30 output channels are all zero
  const float sobel[9] = {
    1.f, 0.f, -1.f,
    2.f, 0.f, -2.f,
    1.f, 0.f, -1.f
  };
  cutlass::reference::host::TensorForEachLambda(
    tensor_b.extent(),
    [&](cutlass::Tensor4DCoord const &coord) {
      const auto &view = tensor_b.host_view();
      // For the filter tensor, n() is the output channel, h()/w() index the 3x3 window,
      // and c() is the input channel (only channel 0 carries the image)
      if (coord.n() < 2 && coord.c() == 0) {
        switch (coord.n()) {
          case 0:
          {
            const auto src_offset = coord.h() * 3 + coord.w();
            view.at(coord) = static_cast<cutlass::half_t>(sobel[src_offset]);
          }
          break;
          case 1:
          {
            // Transposed indexing yields the vertical-gradient Sobel kernel
            const auto src_offset = coord.w() * 3 + coord.h();
            view.at(coord) = static_cast<cutlass::half_t>(sobel[src_offset]);
          }
          break;
          default:
            view.at(coord) = static_cast<cutlass::half_t>(0.f);
            break;
        }
      }
      else {
        view.at(coord) = static_cast<cutlass::half_t>(0.f);
      }
    }
  );
  // Transfer to the GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();

  // Describe the convolution problem size
  cutlass::conv::Conv2dProblemSize problem_size(
    input_size,
    filter_size,
    cutlass::Tensor4DCoord(1, 1, 1, 1),   // Padding
    cutlass::MatrixCoord(1, 1),           // Stride
    cutlass::MatrixCoord(1, 1),           // Dilation
    output_size,
    cutlass::conv::Mode::kCrossCorrelation,
    1                                     // Split the K dimension into 1 partition
  );
  // Define the implicit GEMM arguments; tensor_c serves as both the C (source) and D (destination) operand
  typename ImplicitGemm::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_c.device_ref(),
    tensor_c.device_ref(),
    {1, 0},                               // alpha = 1, beta = 0
  };

  // Instantiate the implicit GEMM operation
  ImplicitGemm implicit_gemm_op;

  // Any other device memory allocation will go in this "workspace"
  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Run the implicit GEMM convolution
  Result result;
  result.status = implicit_gemm_op.can_implement(arguments);
  CUTLASS_CHECK(result.status);
  result.status = implicit_gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(result.status);
  result.status = implicit_gemm_op();
  CUTLASS_CHECK(result.status);

  // Transfer back to the host
  tensor_c.sync_host();
  // Write each output channel out as a raw binary frame
  for (int c_out = 0; c_out < 32; ++c_out) {
    std::vector<unsigned char> im_buildings_out;
    im_buildings_out.resize(390 * 390);
    cutlass::reference::host::TensorForEachLambda(
      tensor_c.extent(),
      [&](cutlass::Tensor4DCoord const &coord) {
        const auto &view = tensor_c.host_view();
        const int pitch = output_size.w();
        auto im = im_buildings_out.data();
        if (coord.n() == 0 && coord.c() == c_out) {
          const auto dst_offset = coord.h() * pitch + coord.w();
          // Clamp to the displayable 8-bit range before narrowing
          float pixel_val = view.at(coord);
          if (pixel_val > 255.f) {
            pixel_val = 255.f;
          }
          else if (pixel_val < 0.f) {
            pixel_val = 0.f;
          }
          im[dst_offset] = static_cast<unsigned char>(pixel_val);
        }
      }
    );
    char fname_out[500];
    snprintf(fname_out, sizeof(fname_out), "buildings_filter.%02d.bin", c_out);
    std::ofstream fout(fname_out, std::ios::binary);
    if (fout.is_open()) {
      std::copy(std::begin(im_buildings_out), std::end(im_buildings_out), std::ostreambuf_iterator<char>(fout));
      fout.close();
    }
  }
}
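
// A minimal driver sketch (not part of the original gist). It assumes test_sobel()
// is compiled together with the definitions above, a CUTLASS checkout on the include
// path (e.g. -I<cutlass>/include -I<cutlass>/tools/util/include), and
// "buildings_original.bin" present in the working directory; build for sm_75 or newer.
int main() {
  test_sobel();
  return 0;
}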