Created
September 19, 2018 00:46
-
-
Save goldsborough/4fbae0d83c49a26faf09cf8a1b32ef26 to your computer and use it in GitHub Desktop.
Stream/Random policy data loader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace torch { | |
namespace data { | |
template <typename D = torch::Tensor, typename L = torch::Tensor> | |
struct Example { | |
D data; | |
L label; | |
}; | |
template <typename D> | |
struct Example<D, void> { | |
D data; | |
}; | |
namespace datasets { | |
// can this just be an enum class? | |
namespace access_policy {
/// Tag type: allows next_batch(size_t batch_size).
struct Stream {};

/// Tag type: allows next_batch(ArrayRef<size_t> indices).
/// Publicly derives from Stream, so a Random tag is convertible to Stream.
struct Random : public Stream {};
} // namespace access_policy
/// Forward declaration of the adaptor type returned by Dataset::map(); it is
/// needed here because Dataset names Map in a member signature before Map is
/// defined.
template <typename SourceDataset, typename TransformType>
struct Map;
// Trait class | |
template < | |
typename S, | |
typename B = std::vector<Example<>>, | |
typename A = access_policy::Random> | |
struct Dataset { | |
using Self = S; | |
using BatchType = B; | |
using AccessPolicy = A; | |
template <typename TransformType, typename... Args> | |
Map<Self, TransformType> map(Args&&... args) &&; | |
}; | |
// Map | |
template <typename S, typename T, typename AccessPolicy> | |
struct MapBase : Dataset<Map<S, T>, typename T::OutputType, typename S::AccessPolicy> { | |
MapBase(S&& dataset, T&& transform) | |
: dataset(std::move(dataset)), transform(std::move(transform)) {} | |
S dataset; | |
T transform; | |
}; | |
/// Primary template, declared but never defined. Only the partial
/// specializations for access_policy::Stream and access_policy::Random have
/// bodies.
template <typename SourceDataset, typename TransformType, typename Policy>
struct MapImpl;
template <typename S, typename T> | |
struct MapImpl<S, T, access_policy::Stream> : MapBase<S, T, access_policy::Stream> { | |
using MapBase<S, T, access_policy::Stream>::MapBase; | |
typename T::OutputType next(size_t count) { | |
return this->transform(this->dataset.next(count)); | |
} | |
}; | |
template <typename S, typename T> | |
struct MapImpl<S, T, access_policy::Random> : MapBase<S, T, access_policy::Random> { | |
using MapBase<S, T, access_policy::Random>::MapBase; | |
typename T::OutputType next(std::vector<size_t>&& indices) { | |
return this->transform.apply(this->dataset.next(std::move(indices))); | |
} | |
}; | |
template<typename S, typename T> | |
struct Map : MapImpl<S, T, typename S::AccessPolicy> { | |
using MapImpl<S, T, typename S::AccessPolicy>::MapImpl; | |
}; | |
// End Map | |
template <typename S, typename B, typename A> | |
template <typename TransformType, typename... Args> | |
Map<S, TransformType> Dataset<S, B, A>::map(Args&&... args) && { | |
// static_assert( | |
// std::is_same<B, typename TransformType::InputType>::value, | |
// "Batch type of dataset does not match input type of transform"); | |
return {std::move(*static_cast<S*>(this)), TransformType(std::forward<Args>(args)...)}; | |
} | |
/// Dataset using Dataset's default batch type (std::vector<Example<>>) and
/// default access policy (access_policy::Random).
class MNIST : public Dataset<MNIST> {
 public:
  /// Initializes the dataset with 100 default-constructed examples.
  /// NOTE(review): `root_path` and `train` are accepted but not read by this
  /// constructor — confirm whether loading from disk is still to be written.
  explicit MNIST(const std::string& root_path, bool train = true)
      : data_(100) {}

  /// Returns copies of the stored examples at `indices`, in the order given.
  ///
  /// Throws std::out_of_range (from std::vector::at) if any index is not
  /// smaller than size(), instead of reading out of bounds.
  std::vector<Example<>> next(std::vector<size_t>&& indices) {
    std::vector<Example<>> examples;
    // The output size is known up front; allocate once.
    examples.reserve(indices.size());
    for (const size_t index : indices) {
      examples.push_back(data_.at(index));
    }
    return examples;
  }

  /// Number of stored examples.
  size_t size() const noexcept {
    return data_.size();
  }

 private:
  std::vector<Example<>> data_;
};
/// Batch type used by HiveDataset; an aggregate holding a single count.
struct RowBatch {
  size_t count;
};
/// Dataset with batch type RowBatch and access policy access_policy::Stream.
class HiveDataset
    : public Dataset<HiveDataset, RowBatch, access_policy::Stream> {
 public:
  HiveDataset() = default;

  /// Returns a RowBatch whose `count` equals the requested `count`.
  RowBatch next(size_t count) {
    RowBatch batch{count};
    return batch;
  }

  /// Always reports 12345.
  size_t size() const noexcept {
    return 12345;
  }
};
} // namespace datasets | |
namespace transforms { | |
/// Base that publishes a transform's input and output types as the member
/// aliases InputType and OutputType.
template <typename Input, typename Output>
struct Transform {
  using InputType = Input;
  using OutputType = Output;
};
template<typename L = torch::Tensor> | |
struct TensorTransform : Transform<std::vector<Example<torch::Tensor, L>>, std::vector<Example<torch::Tensor, L>>> { | |
virtual ~TensorTransform() = default; | |
virtual Tensor apply(const Tensor& tensor) = 0; | |
Example<torch::Tensor, L> apply(Example<torch::Tensor, L>&& batch) const { | |
for (const auto& example : batch) { | |
apply(example.data); | |
} | |
return std::move(batch); | |
} | |
}; | |
/// Tensor transform computing (tensor - mean) / stddev.
struct Normalize : TensorTransform<> {
  /// Stores the two constants used by apply().
  Normalize(double mean_value, double stddev_value)
      : mean(mean_value), stddev(stddev_value) {}

  /// Subtracts `mean` from `tensor` and divides the difference by `stddev`.
  Tensor apply(const Tensor& tensor) override {
    return (tensor - mean) / stddev;
  }

  /// Replaces `data` of every example in `batch` with its normalized value
  /// and returns the batch.
  template <typename L = torch::Tensor>
  std::vector<Example<torch::Tensor, L>> apply(
      std::vector<Example<torch::Tensor, L>>&& batch) {
    for (auto it = batch.begin(); it != batch.end(); ++it) {
      it->data = apply(it->data);
    }
    return std::move(batch);
  }

  double mean = 0;
  double stddev = 0;
};
} // namespace transforms | |
} // namespace data | |
} // namespace torch |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment