Implementing a custom Dataset involves adding a new TensorFlow op together with a kernel that implements it.
Created
November 17, 2017 20:03
-
-
Save asimshankar/f11e79a5b7947e716ae2242387162ba2 to your computer and use it in GitHub Desktop.
TensorFlow Custom Datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Builds dataset.cc into the shared library libmydataset.so.
#
# PATHS holds the compiler/linker flags for the installed TensorFlow package:
# its include directory, the nsync headers under that directory, and the
# library directory used to resolve -ltensorflow_framework.
PATHS=$(python -c "import tensorflow as tf; print('-I{} -I{}/external/nsync/public -L{}'.format(tf.sysconfig.get_include(), tf.sysconfig.get_include(), tf.sysconfig.get_lib()))")
# Unfortunately, not all header files are included in the pip package yet, so
# for now clone the TensorFlow repository and check out the branch matching the
# installed version to get the additional files. TF_SRC points at that checkout
# (override it in the environment; /tmp/tensorflow_src is only the default).
TF_SRC="${TF_SRC:-/tmp/tensorflow_src}"
# NOTE(review): if g++ >= 5 is used against a pip package built with the older
# libstdc++ ABI, -D_GLIBCXX_USE_CXX11_ABI=0 may also be required — confirm for
# your TensorFlow build.
# ${PATHS} is intentionally unquoted so that it word-splits into separate flags.
g++ -shared -fPIC ${PATHS} -I"${TF_SRC}" -std=c++11 dataset.cc -ltensorflow_framework -o libmydataset.so
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <string> | |
#include <vector> | |
#include "tensorflow/core/framework/common_shape_fns.h" | |
#include "tensorflow/core/framework/op.h" | |
#include "tensorflow/core/framework/partial_tensor_shape.h" | |
#include "tensorflow/core/framework/tensor.h" | |
#include "tensorflow/core/kernels/dataset.h" | |
using std::string; | |
using tensorflow::Status; | |
using tensorflow::Tensor; | |
// Registers the "MyDataset" op. It has a single int32 input `value` and one
// output `handle` of type variant, which the shape function (ScalarShape)
// declares to be a scalar.
// NOTE(review): `handle` presumably carries the Dataset built by MyDatasetOp.
// SetIsStateful() presumably stops graph rewrites such as constant folding from
// treating this source op as a constant. Confirm both against the TF version in use.
REGISTER_OP("MyDataset") | |
.Input("value: int32") | |
.Output("handle: variant") | |
.SetIsStateful() | |
.SetShapeFn(tensorflow::shape_inference::ScalarShape) | |
.Doc(R"doc( | |
Silly dataset that produces 'value' once. | |
)doc"); | |
class MyDatasetOp : public tensorflow::DatasetOpKernel { | |
public: | |
explicit MyDatasetOp(tensorflow::OpKernelConstruction* ctx) | |
: tensorflow::DatasetOpKernel(ctx) {} | |
void MakeDataset(tensorflow::OpKernelContext* ctx, | |
tensorflow::DatasetBase** output) override { | |
tensorflow::OpInputList inputs; | |
OP_REQUIRES_OK(ctx, ctx->input_list("value", &inputs)); | |
std::vector<Tensor> components; | |
components.reserve(inputs.size()); | |
for (const Tensor& t : inputs) { | |
components.push_back(t); | |
} | |
*output = new Dataset(std::move(components)); | |
} | |
private: | |
class Dataset : public tensorflow::DatasetBase { | |
public: | |
explicit Dataset(std::vector<Tensor> tensors) | |
: tensors_(std::move(tensors)) { | |
for (const Tensor& t : tensors_) { | |
dtypes_.push_back(t.dtype()); | |
shapes_.emplace_back(t.shape().dim_sizes()); | |
} | |
} | |
std::unique_ptr<tensorflow::IteratorBase> MakeIterator( | |
const string& prefix) const override { | |
return std::unique_ptr<tensorflow::IteratorBase>( | |
new Iterator({this, prefix + "::MyDataset"})); | |
} | |
const tensorflow::DataTypeVector& output_dtypes() const override { | |
return dtypes_; | |
} | |
const std::vector<tensorflow::PartialTensorShape>& output_shapes() | |
const override { | |
return shapes_; | |
} | |
string DebugString() override { return "MyDatasetOp::Dataset"; } | |
private: | |
class Iterator : public tensorflow::DatasetIterator<Dataset> { | |
public: | |
explicit Iterator(const Params& params) | |
: tensorflow::DatasetIterator<Dataset>(params), produced_(false) {} | |
Status GetNextInternal(tensorflow::IteratorContext* ctx, | |
std::vector<Tensor>* out_tensors, | |
bool* end_of_sequence) override { | |
tensorflow::mutex_lock l(mu_); | |
if (!produced_) { | |
*out_tensors = dataset()->tensors_; | |
produced_ = true; | |
*end_of_sequence = false; | |
return Status::OK(); | |
} else { | |
*end_of_sequence = true; | |
return Status::OK(); | |
} | |
} | |
private: | |
tensorflow::mutex mu_; | |
bool produced_ GUARDED_BY(mu_); | |
}; | |
const std::vector<Tensor> tensors_; | |
tensorflow::DataTypeVector dtypes_; | |
std::vector<tensorflow::PartialTensorShape> shapes_; | |
}; | |
}; | |
// Registers MyDatasetOp as the kernel that runs the "MyDataset" op on CPU.
// NOTE(review): the registration sits inside namespace tensorflow, presumably
// so that the unqualified names in the macro argument (Name, DEVICE_CPU)
// resolve. Keep it there unless those names are qualified.
namespace tensorflow { | |
REGISTER_KERNEL_BUILDER(Name("MyDataset").Device(DEVICE_CPU), MyDatasetOp); | |
} // namespace tensorflow |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment