Last active: December 16, 2015 17:51
-
-
Save kevinrobinson/16ef10d9a20e32b510cc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MatMulOp : public OpKernel { | |
public: | |
explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { | |
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); | |
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); | |
} | |
void Compute(OpKernelContext* ctx) override { | |
const Tensor& a = ctx->input(0); | |
const Tensor& b = ctx->input(1); | |
// ... Validated the dimensions of the two matrices ... | |
// ... support arguments for transposing the matrices ... | |
// ... some short-circuit optimizations ... | |
// ... allocate the output tensor... | |
LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out); | |
} | |
private: | |
bool transpose_a_; | |
bool transpose_b_; | |
}; | |
// The LaunchMatMulCPU::launch implementation has a few layers of templates,
// but ultimately ends up calling a specialization of MatMul, which uses Eigen
// to compute out = in0 * in1 on device "d".
//
// `dim_pair` is passed straight through to in0.contract(); for Eigen tensors
// it names which dimension(s) of in0 are contracted against in1. The result
// is assigned through out.device(d), so the evaluation runs on device `d`.
template <typename Device, typename In0, typename In1, typename Out,
          typename DimPair>
void MatMul(const Device& d, Out out, In0 in0, In1 in1,
            const DimPair& dim_pair) {
  out.device(d) = in0.contract(in1, dim_pair);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.