Created
April 27, 2019 03:40
-
-
Save wkcn/d0a4c8d8afd5935f2a2addf4699f2b8c to your computer and use it in GitHub Desktop.
Custom Operator Design
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <utility>
using namespace std;
// A single argument value slot. The three members share storage, so only the
// one most recently written is meaningful. Which member is meant is not stored
// here; presumably a parallel type code (cf. TArgs::type_codes) selects it —
// TODO confirm.
union TValue {
  int64_t v_int64;    // integer payload
  double v_float64;   // floating-point payload
  void* v_handle;     // opaque pointer payload
};
class TArgs { | |
const TValue* values; | |
const int* type_codes; | |
const int num_args; | |
}; | |
// Shape of a tensor: an owned, dynamically sized list of dimension sizes.
//
// Storage is a raw int array plus its length. The pointer/length pair is kept
// (rather than a std::vector) because the free operator<< reads `data_` and
// `ndim_` directly.
//
// Copies are deep and moves transfer ownership, so a TShape can be passed and
// returned by value safely. The previous version had neither a copy
// constructor nor copy assignment, so any copy double-freed the buffer. Its
// destructor also used `delete` on a `new[]` allocation.
class TShape {
 public:
  // Empty shape: ndim() == 0.
  TShape() = default;

  // Builds a shape from a braced list, e.g. TShape{3, 9}. Intentionally
  // implicit so functions returning TShape can write `return {1, 2, 3};`.
  TShape(std::initializer_list<int> dims) {
    CopyFrom(dims.begin(), static_cast<int>(dims.size()));
  }

  // Deep copy: every TShape owns its own buffer.
  TShape(const TShape& other) { CopyFrom(other.data_, other.ndim_); }

  // Takes over other's buffer; other is left as an empty shape.
  TShape(TShape&& other) noexcept { swap(other); }

  // Unified copy/move assignment (copy-and-swap). `other` is already a copy or
  // a moved-in value, so swapping finishes the job and self-assignment is safe.
  TShape& operator=(TShape other) noexcept {
    swap(other);
    return *this;
  }

  // Replaces the dimensions with a braced list. The new buffer is built before
  // the old one is released, so a failed allocation leaves *this unchanged.
  TShape& operator=(std::initializer_list<int> dims) {
    TShape replacement(dims);
    swap(replacement);
    return *this;
  }

  // Array delete matches the new[] in CopyFrom.
  ~TShape() { delete[] data_; }

  // Number of dimensions.
  int ndim() const { return ndim_; }

  // Size of dimension i. Precondition: 0 <= i < ndim(); not bounds-checked.
  int operator[](int i) const { return data_[i]; }

  // Exchanges contents with other; never throws.
  void swap(TShape& other) noexcept {
    std::swap(data_, other.data_);
    std::swap(ndim_, other.ndim_);
  }

  friend std::ostream& operator<<(std::ostream& os, const TShape& shape);

 private:
  // Allocates n ints and copies them from src. Only called from constructors,
  // while data_ is still null, so nothing needs to be released first.
  void CopyFrom(const int* src, int n) {
    if (n <= 0) return;
    data_ = new int[n];
    for (int i = 0; i < n; ++i) {
      data_[i] = src[i];
    }
    ndim_ = n;
  }

  int* data_ = nullptr;  // owned buffer of ndim_ ints; null when empty
  int ndim_ = 0;
};
ostream& operator<<(ostream& os, const TShape& shape) { | |
os << '('; | |
bool first = true; | |
for (int i = 0; i < shape.ndim_; ++i) { | |
if (first) first = false; | |
else os << ", "; | |
os << shape.data_[i]; | |
} | |
os << ')'; | |
return os; | |
} | |
class CustomOp { | |
public: | |
virtual ~CustomOp() {}; | |
virtual int Forward(TArgs args) = 0; | |
virtual int Backward(TArgs args) = 0; | |
virtual TShape InferShape(const TShape) = 0; | |
}; | |
class AdditionOp: public CustomOp { | |
public: | |
AdditionOp() { | |
cout << "Init AdditionOp" << endl; | |
} | |
~AdditionOp() { | |
cout << "Delete AdditionOp" << endl; | |
} | |
int Forward(TArgs args) { | |
return 0; | |
} | |
int Backward(TArgs args) { | |
return 0; | |
} | |
TShape InferShape(const TShape in_shape) { | |
return {1, 2, 3}; | |
} | |
}; | |
int main() { | |
// Deep Learning Framework creates the CustomOp Instance | |
CustomOp *op = new AdditionOp(); | |
// Infer Shape | |
TShape shape = op->InferShape(TShape({3, 9})); | |
// Deep Learning Framework deletes the CustomOp Instance | |
delete op; | |
cout << shape << endl; | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.