Last active
May 17, 2016 11:38
-
-
Save botev/9addde28ba6810cb23e9157a920b6ce0 to your computer and use it in GitHub Desktop.
Iterator for arrayfire
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "arrayfire.h"
class AbstractDataSource{ | |
protected: | |
unsigned instance_dim; | |
long n; | |
std::map<std::string, af::array> data; | |
std::vector<std::string> current_datums; | |
public: | |
AbstractDataSource(unsigned instance_dim = 0): instance_dim(instance_dim) {}; | |
void set_datums(std::vector<std::string> datums){ | |
current_datums = datums; | |
} | |
virtual void add_data(std::string key, af::array value){ | |
if(data.size() == 0){ | |
n = value.dims(instance_dim); | |
} | |
if(value.dims(instance_dim) != n){ | |
throw 2; | |
} | |
if(data.find(key) != data.end()){ | |
throw 1; | |
} | |
data[key] = value; | |
} | |
virtual void shuffle_data(af::array order){ | |
if(instance_dim == 0){ | |
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(order, af::span, af::span, af::span);}; | |
std::for_each(data.begin(), data.end(), f); | |
} else if(instance_dim == 1){ | |
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, order, af::span, af::span);}; | |
std::for_each(data.begin(), data.end(), f); | |
} else if(instance_dim == 2){ | |
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, af::span, order, af::span);}; | |
std::for_each(data.begin(), data.end(), f); | |
} else if(instance_dim == 3){ | |
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, af::span, af::span, order);}; | |
std::for_each(data.begin(), data.end(), f); | |
} | |
} | |
virtual void random_shuffle(){ | |
af::array rand = af::randu(n); | |
af::array order; | |
af::sort(rand, order, rand, 0); | |
shuffle_data(order); | |
} | |
class iterator{ | |
private: | |
AbstractDataSource & source; | |
int batch_size; | |
bool full_batches; | |
int index; | |
std::vector<af::array::array_proxy> slice; | |
bool updated; | |
public: | |
typedef std::vector<af::array::array_proxy> value_type; | |
typedef std::vector<af::array::array_proxy>& ref_type; | |
typedef std::vector<af::array::array_proxy>* ptr_type; | |
iterator(AbstractDataSource & source, int index, | |
int batch_size, | |
bool full_batches): | |
source(source), index(index), | |
batch_size(batch_size), | |
full_batches(full_batches), updated(false) {}; | |
iterator(iterator const & ref): | |
source(ref.source), index(index), | |
batch_size(ref.batch_size), | |
full_batches(ref.full_batches), updated(false) { | |
}; | |
iterator& operator++() { | |
index+=batch_size; | |
if(index > source.n or (full_batches and index + batch_size >= source.n)){ | |
index = source.n; | |
} | |
updated = false; | |
return *this; | |
} | |
iterator operator++(int) { | |
iterator copy(*this); | |
++(*this); | |
return copy; | |
} | |
bool operator==(iterator const & ref){ | |
return &source == &ref.source and batch_size == ref.batch_size and | |
index == ref.index; | |
} | |
bool operator!=(iterator const & ref){ | |
return &source != &ref.source or batch_size != ref.batch_size or | |
index != ref.index; | |
} | |
ref_type operator*() { | |
fetch_slice(); | |
return slice; | |
} | |
ptr_type operator->(){ | |
fetch_slice(); | |
return &slice; | |
} | |
void fetch_slice(){ | |
if(not updated) { | |
slice.clear(); | |
int last = (index+batch_size-1) < source.n ? (index+batch_size-1) : source.n-1; | |
for (auto i = 0; i < source.current_datums.size(); ++i) { | |
if(source.instance_dim == 0){ | |
slice.push_back(source.data[source.current_datums[i]](af::seq(index, last), af::span, af::span, af::span)); | |
} else if(source.instance_dim == 1){ | |
slice.push_back(source.data[source.current_datums[i]](af::span, af::seq(index, last), af::span, af::span)); | |
} else if(source.instance_dim == 2){ | |
slice.push_back(source.data[source.current_datums[i]](af::span, af::span, af::seq(index, last), af::span)); | |
} else if(source.instance_dim == 3){ | |
slice.push_back(source.data[source.current_datums[i]](af::span, af::span, af::span, af::seq(index, last))); | |
} | |
} | |
updated = true; | |
} | |
} | |
}; | |
iterator begin( int batch_size, bool full_batches = false){ | |
return iterator(*this, 0, batch_size, full_batches); | |
} | |
iterator end(int batch_size){ | |
return iterator(*this, n, batch_size, false); | |
} | |
void print(){ | |
for(auto i = data.begin(); i != data.end(); ++i){ | |
std::cout << i->first << std::endl; | |
af_print(i->second); | |
} | |
} | |
}; | |
int main() | |
{ | |
auto source = AbstractDataSource(1); | |
source.add_data("d1", af::randu(5, 7)); | |
source.add_data("d2", af::randu(3, 7)); | |
source.add_data("d3", af::randu(1, 7)); | |
source.print(); | |
source.random_shuffle(); | |
source.print(); | |
source.set_datums({"d2", "d3"}); | |
for(auto i = source.begin(2, true); i != source.end(2); ++i){ | |
std::cout << "Iteration " << std::endl; | |
for(auto j = 0; j < i->size(); ++j){ | |
af_print((*i)[j]); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment