template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    // ... validate the dimensions of the two matrices ...
    // ... support arguments for transposing the matrices ...
    // ... some short-circuit optimizations ...
    // ... allocate the output tensor ...
    LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};
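
// For context (not in the original gist): a sketch of how such a kernel is
// bound to the "MatMul" op, modeled on the registration in matmul_op.cc.
// REGISTER_KERNEL_BUILDER associates the op name with an OpKernel subclass for
// a given device and dtype; CPUDevice is TensorFlow's usual typedef for
// Eigen::ThreadPoolDevice, and the final template argument disables the
// cuBLAS path on CPU.
REGISTER_KERNEL_BUILDER(
    Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    MatMulOp<CPUDevice, float, false /* don't use cublas on CPU */>);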
// The LaunchMatMulCPU::launch implementation has a few layers of templates, but
// ultimately ends up calling a specialization of MatMul, which uses Eigen to
// compute out = in0 * in1 on device "d".
template <typename Device, typename In0, typename In1, typename Out,
          typename DimPair>
void MatMul(const Device& d, Out out, In0 in0, In1 in1,
            const DimPair& dim_pair) {
  out.device(d) = in0.contract(in1, dim_pair);
}
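
// For illustration only (not part of the gist): a self-contained sketch of the
// same contraction written against Eigen's Tensor module directly, assuming
// the unsupported CXX11 Tensor headers are available. A dim_pair of (1, 0)
// contracts axis 1 of the left operand with axis 0 of the right operand,
// which is exactly a plain matrix multiply.
#include <unsupported/Eigen/CXX11/Tensor>

void ContractionSketch() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
  a.setRandom();
  b.setRandom();
  // Contract a's second dimension against b's first dimension: out = a * b.
  Eigen::array<Eigen::IndexPair<int>, 1> dim_pair = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> out = a.contract(b, dim_pair);  // shape: 2 x 4
}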