conv_tbc_op.h, a gist by @jamesr66a
Created November 1, 2017 17:41
#ifndef CONV_TBC_OP_H
#define CONV_TBC_OP_H

#include <algorithm>
#include <functional>
#include <vector>

#include <ATen/ATen.h>
#include <caffe2/core/context.h>
#include <caffe2/core/operator.h>

namespace caffe2 {

using at::Half;

// Return a deleter that holds a handle to t until it is called, keeping the
// ATen storage alive for as long as Caffe2 still points at it. Marked inline
// since this is a header that may be included from several translation units.
inline std::function<void(void*)> deleterFor(at::Tensor t) {
  return [t](void* /* unused */) mutable { t.reset(); };
}

template <class Context>
class ConvTBCOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ConvTBCOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        pad_(OperatorBase::GetSingleArgument<int>("pad", 0)) {}
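
  // Expected shapes (TBC layout = time, batch, channel), as read below:
  //   Input(0)  input:  (ilen, batchSize, inputPlanes)
  //   Input(1)  weight: (kw, inputPlanes, outputPlanes)
  //   Input(2)  bias:   (outputPlanes)
  //   Output(0) output: (olen, batchSize, outputPlanes),
  //             where olen = ilen - kw + 1 + 2 * pad_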
  bool RunOnDevice() override {
    at::Tensor input = tensorWrapping(Input(0));
    at::Tensor weight = tensorWrapping(Input(1));
    at::Tensor bias = tensorWrapping(Input(2));

    auto input_size = input.sizes();
    auto ilen = input_size[0];
    auto batchSize = input_size[1];
    auto inputPlanes = input_size[2];
    auto kw = weight.sizes()[0];

    // auto keeps olen the same int64_t type as ilen, so the std::min call
    // in the loop below deduces a single type.
    auto olen = ilen - kw + 1 + pad_ * 2;
    // Algebraically this simplifies to pad_; kept in this form to mirror
    // the shift arithmetic in the loop below.
    int pad = (olen - ilen + kw - 1) / 2;

    Output(0)->Resize(olen, batchSize, weight.sizes()[2]);
    // Allocate the output buffer before wrapping it; this op assumes float data.
    Output(0)->template mutable_data<float>();
    at::Tensor output = tensorWrapping(*Output(0));
    auto output_size = output.sizes();
    auto outputPlanes = output_size[2];

    // input * weights + bias -> output_features
    output.copy_(bias.expand(output.sizes()));
    for (int k = 0; k < kw; k++) {
      int iShift = std::max(0, k - pad);
      int oShift = std::max(0, pad - k);
      int t = std::min(ilen + pad - k, olen) - oShift;
      // Note: gemm assumes column-major matrices
      //   input  is l*m (row-major)
      //   weight is m*r (row-major)
      //   output is l*r (row-major)
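      // Worked example with hypothetical sizes (not from the gist): with
      // t = 7, batchSize = 4, inputPlanes = 16, outputPlanes = 32, I below
      // is viewed as 28x16, W = weight[k] is 16x32, and O is viewed as
      // 28x32, so O += I * W is a single row-major matmul via addmm_.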
      if (t > 0) {
        auto W = weight[k];
        auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
        auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
        O.addmm_(I, W);
      }
    }
    assignTo(Output(0), output);
    return true;
  }

 private:
  // Map an ATen tensor's scalar type to the corresponding Caffe2 TypeMeta.
  TypeMeta typeMetaFor(const at::Tensor& t) {
    return typeMetaFor(t.type().scalarType());
  }
  TypeMeta typeMetaFor(at::ScalarType st) {
#define DEFINE_CASE(ctype, aten_name, _) \
  case at::k##aten_name:                 \
    return TypeMeta::Make<ctype>();
    switch (st) {
      AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
      default:
        CAFFE_THROW("Unknown ATen Type");
    }
#undef DEFINE_CASE
  }

  // Map a Caffe2 TypeMeta back to the corresponding ATen scalar type.
  at::ScalarType atScalarTypeFor(const TypeMeta& meta) {
#define DEFINE_IF(ctype, aten_name, _) \
  if (meta.Match<ctype>()) {           \
    return at::k##aten_name;           \
  }
    AT_FORALL_SCALAR_TYPES(DEFINE_IF)
#undef DEFINE_IF
    CAFFE_THROW("Unknown type meta"); // TODO: improve error message...
  }

  // Hardcodes the CPU backend, so this wrapper is CPU-only.
  at::Type& typeFor(const Tensor<Context>& ten) {
    return at::getType(at::Backend::CPU, atScalarTypeFor(ten.meta()));
  }

  // Wrap a Caffe2 tensor's existing buffer in an ATen tensor (no copy).
  const at::Tensor tensorWrapping(const Tensor<Context>& ten) {
    return typeFor(ten).tensorFromBlob(
        const_cast<void*>(ten.raw_data()), ten.dims());
  }

  // Hand the ATen result's buffer back to the Caffe2 output (no copy);
  // deleterFor keeps the ATen storage alive for the output's lifetime.
  void assignTo(Tensor<Context>* dst, const at::Tensor& src_) {
    at::Tensor src = src_.contiguous();
    auto at_sizes = src.sizes();
    std::vector<int64_t> dims(at_sizes.begin(), at_sizes.end());
    dst->Resize(dims);
    dst->ShareExternalPointer(
        src.data_ptr(), typeMetaFor(src), 0, deleterFor(src));
  }

  int pad_;
};

} // namespace caffe2

#endif /* CONV_TBC_OP_H */
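
A minimal registration sketch for completeness; this is an assumption, since
the gist ships only the header. The operator name "ConvTBC" and the companion
file name are hypothetical:

// conv_tbc_op.cc (hypothetical companion file)
#include "conv_tbc_op.h"

namespace caffe2 {

// Register the CPU variant and describe the (input, weight, bias) -> output
// schema; the "pad" argument matches the one read in the constructor.
REGISTER_CPU_OPERATOR(ConvTBC, ConvTBCOp<CPUContext>);

OPERATOR_SCHEMA(ConvTBC)
    .NumInputs(3)
    .NumOutputs(1)
    .Arg("pad", "(int, default 0) temporal padding added on both sides");

} // namespace caffe2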