Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active June 30, 2021 14:13
Show Gist options
  • Save maedoc/37202e3e1b3dbff1b77dc6106e11917a to your computer and use it in GitHub Desktop.
Save maedoc/37202e3e1b3dbff1b77dc6106e11917a to your computer and use it in GitHub Desktop.
TF op with C++ resource

mini-example of a TF op with a custom C++ resource

There's a TF tutorial on creating an op, which covers anything written from scratch. But when you've got existing C++ classes (or anything callable from C++, really) that do the heavy lifting and are stateful, or at least require initialization/instantiation, the tutorial only recommends (in a side note about ops needing to be reentrant) that a ResourceMgr be used, linking to the corresponding header but giving no example.

That's sort of a cliff-hanger given how nice the rest of the tutorial is. Two related Stack Overflow questions (answered by the same person) provide a few breadcrumbs.

As it turns out, reading the various headers and .cc files for the LookupTable is instructive for using this part of the TF infrastructure, hence the various files in this gist. These files are a toy example, but when you paste them into the custom-op repo and run make, it builds and the test passes.

#include <cstdio>
#include <mutex>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
using namespace tensorflow;
// A custom object to use for testing resources: a plain stateful C++ class
// with no knowledge of TensorFlow, standing in for "existing C++ code".
//
// Thread-safety: op kernels can be invoked concurrently, and every kernel that
// looks up the shared resource touches the same Counter, so all access to
// `count` is serialized by `mu`.
class Counter {
  mutable std::mutex mu;  // guards count; mutable so const readers can lock
  int count = 0;
public:
  Counter() = default;

  // Returns the current count.
  int get_count() const {
    std::lock_guard<std::mutex> lock(mu);
    return count;
  }

  // Adds one to the count, printing the value before and after. The lock is
  // held across both printf calls so the two halves of the message always
  // describe the same increment, even when increments race.
  void increment() {
    std::lock_guard<std::mutex> lock(mu);
    printf("count is %d. ", count);
    count++;
    printf("now it's %d\n", count);
  }
};
// a tf resource object for the counter
class CounterResource : public ResourceBase {
public:
Counter counter;
std::string DebugString() const { return "CounterResource"; }
};
// a tf op to create the counter
class CreateCounterOp : public OpKernel {
public:
// TODO add OP_REQUIRES_OK stuff
explicit CreateCounterOp(OpKernelConstruction* context) : OpKernel(context) {
// Allocate handle tensor for the counter resource
context->allocate_persistent(DT_RESOURCE, TensorShape({}), &counter_handle, nullptr);
}
// TODO add cinfo_, cf lookup_table_op.h
void Compute(OpKernelContext* context) override {
// Create our Counter instance on heap
CounterResource* counter_r = new CounterResource();
// Add it to the look up table
context->resource_manager()->Create("zero_out", "counter", counter_r);
// Create scalar accessor for handle tensor
auto h = counter_handle.AccessTensor(context)->template scalar<ResourceHandle>();
// Instantiate the resource handle with keys
h() = MakeResourceHandle<CounterResource>(context, "zero_out", "counter");
// Provide handle tensor as first output
context->set_output(0, *counter_handle.AccessTensor(context));
}
private:
PersistentTensor counter_handle;
};
REGISTER_KERNEL_BUILDER(Name("CreateCounter").Device(DEVICE_CPU), CreateCounterOp);
// a tf op to increment the counter
class IncrementCounterOp : public OpKernel {
public:
explicit IncrementCounterOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Declare ref to the counter resource
CounterResource *counter_r;
// Look it up
context->resource_manager()->Lookup("zero_out", "counter", &counter_r);
// Decrement reference count when done
core::ScopedUnref unref_me(counter_r);
// Increment our counter
counter_r->counter.increment();
} // unref_me goes out of scope, decref'ing the resource
};
REGISTER_KERNEL_BUILDER(Name("IncrementCounter").Device(DEVICE_CPU), IncrementCounterOp);
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("CreateCounter")
.Output("counter_handle: resource");
REGISTER_OP("IncrementCounter")
.Input("counter_handle: resource");
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
# Locate the compiled kernel library next to this module, load it, and
# re-export the generated Python op wrappers under short names.
_counter_ops_so = resource_loader.get_path_to_datafile('_counter_ops.so')
counter_ops = load_library.load_op_library(_counter_ops_so)
create_counter, increment_counter = (counter_ops.create_counter,
                                     counter_ops.increment_counter)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.platform import test
from counter_ops import create_counter, increment_counter
class CounterTest(test.TestCase):
    """Smoke test: create a counter resource and increment it several times."""

    def test_increment(self):
        # cached_session() replaces TestCase.test_session(), which is
        # deprecated in TF 2.x.
        with self.cached_session():
            counter = create_counter()
            # Each call prints "count is N. now it's N+1" from the C++ kernel.
            for _ in range(4):
                increment_counter(counter)
# outputs the following:
'''
count is 0. now it's 1
count is 1. now it's 2
count is 2. now it's 3
count is 3. now it's 4
INFO:tensorflow:time(__main__.CounterTest.test_increment): 0.01s
I0630 13:58:10.023093 140402655590144 test_util.py:2103] time(__main__.CounterTest.test_increment): 0.01s
[ OK ] CounterTest.test_increment
[ RUN ] CounterTest.test_session
[ SKIPPED ] CounterTest.test_session
[ RUN ] ZeroOutTest.testZeroOut
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment