|
#include <cstdio>
#include <mutex>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"

using namespace tensorflow;
|
|
|
// A custom object to use for testing resources: a plain integer counter that
// prints its value before and after every increment.
//
// The counter is shared through a ResourceMgr, and kernels may call Compute
// concurrently, so all access to the count is serialized by a mutex.
// Note: the mutex makes Counter non-copyable (it is only ever held by value
// inside CounterResource, which is itself non-copyable).
class Counter {
  int count_ = 0;
  // mutable so the const getter can lock it.
  mutable std::mutex mu_;

 public:
  Counter() = default;

  // Returns the current count.
  int get_count() const {
    std::lock_guard<std::mutex> lock(mu_);
    return count_;
  }

  // Adds one to the count, logging the old and new values to stdout.
  // Both printfs happen under the lock so the two halves of a message are not
  // interleaved with another thread's increment.
  void increment() {
    std::lock_guard<std::mutex> lock(mu_);
    printf("count is %d. ", count_);
    count_++;
    printf("now it's %d\n", count_);
  }
};
|
|
|
// a tf resource object for the counter |
|
class CounterResource : public ResourceBase { |
|
public: |
|
Counter counter; |
|
std::string DebugString() const { return "CounterResource"; } |
|
}; |
|
|
|
// a tf op to create the counter |
|
class CreateCounterOp : public OpKernel { |
|
public: |
|
|
|
// TODO add OP_REQUIRES_OK stuff |
|
explicit CreateCounterOp(OpKernelConstruction* context) : OpKernel(context) { |
|
// Allocate handle tensor for the counter resource |
|
context->allocate_persistent(DT_RESOURCE, TensorShape({}), &counter_handle, nullptr); |
|
} |
|
|
|
// TODO add cinfo_, cf lookup_table_op.h |
|
void Compute(OpKernelContext* context) override { |
|
// Create our Counter instance on heap |
|
CounterResource* counter_r = new CounterResource(); |
|
// Add it to the look up table |
|
context->resource_manager()->Create("zero_out", "counter", counter_r); |
|
// Create scalar accessor for handle tensor |
|
auto h = counter_handle.AccessTensor(context)->template scalar<ResourceHandle>(); |
|
// Instantiate the resource handle with keys |
|
h() = MakeResourceHandle<CounterResource>(context, "zero_out", "counter"); |
|
// Provide handle tensor as first output |
|
context->set_output(0, *counter_handle.AccessTensor(context)); |
|
} |
|
private: |
|
PersistentTensor counter_handle; |
|
}; |
|
|
|
REGISTER_KERNEL_BUILDER(Name("CreateCounter").Device(DEVICE_CPU), CreateCounterOp); |
|
|
|
// a tf op to increment the counter |
|
class IncrementCounterOp : public OpKernel { |
|
public: |
|
explicit IncrementCounterOp(OpKernelConstruction* context) : OpKernel(context) {} |
|
|
|
void Compute(OpKernelContext* context) override { |
|
// Declare ref to the counter resource |
|
CounterResource *counter_r; |
|
// Look it up |
|
context->resource_manager()->Lookup("zero_out", "counter", &counter_r); |
|
// Decrement reference count when done |
|
core::ScopedUnref unref_me(counter_r); |
|
// Increment our counter |
|
counter_r->counter.increment(); |
|
} // unref_me goes out of scope, decref'ing the resource |
|
}; |
|
|
|
REGISTER_KERNEL_BUILDER(Name("IncrementCounter").Device(DEVICE_CPU), IncrementCounterOp); |