Created
August 26, 2019 18:33
-
-
Save bwasti/865f9d0149f4e920598fc7341f8bba96 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Type-level list of booleans. Intentionally left incomplete: it is only ever
// used as a template argument, never instantiated as an object.
template <bool... Flags>
struct bool_pack;

// `all_true<Flags...>::value` is true iff every flag is true (and is true for
// an empty list). Appending `true` and prepending `true` yield the same
// bool_pack type only when no flag is false.
template <bool... Flags>
using all_true = std::is_same<bool_pack<Flags..., true>, bool_pack<true, Flags...>>;

// `are_all_constructible<Target, Sources...>::value` is true iff `Target` is
// constructible from each type in `Sources` on its own.
template <class Target, class... Sources>
using are_all_constructible =
    all_true<std::is_constructible<Target, Sources>::value...>;
template<typename... Ts> | |
struct ivalue_constructible_tuple { | |
constexpr static bool value = are_all_constructible<c10::IValue, Ts...>::value; | |
}; | |
template<typename... Ts> | |
struct ivalue_constructible_tuple<std::tuple<Ts...>> { | |
constexpr static bool value = are_all_constructible<c10::IValue, Ts...>::value; | |
}; | |
// Generic wrapper around a std::function<R(Args...)> for return types other
// than Tensor (Tensor has a separate specialization). Each call is passed
// straight through to f_. s_ is stored but not used by this primary template.
// NOTE(review): s_ is a raw, non-owning pointer. It presumably points at a
// string literal or other static name; confirm it outlives the wrapper.
template <typename R, typename... Args>
struct lazy_wrap {
  lazy_wrap(std::function<R(Args...)> f, const char* s) : s_(s) {
    // f is already this object's own copy; swap it in instead of copying again.
    f_.swap(f);
    std::cerr << "Return type is not Tensor: " << typeid(R).name() << "\n";
  }

  std::function<R(Args...)> f_;
  const char* s_;

  // Invokes f_, forwarding each argument with its declared value category.
  // static_cast<Args&&>(a) is exactly std::forward<Args>(a):
  // - by-value arguments are moved, so move-only types such as
  //   std::unique_ptr work;
  // - T& and T&& parameters keep their reference kind.
  // Passing plain `a...` would copy every argument, and would fail to compile
  // for move-only or rvalue-reference Args.
  R operator()(Args... a) {
    return f_(static_cast<Args&&>(a)...);
  }
};
template <typename T> | |
void debugIValue(std::ostream& out, T t) { | |
out << std::is_constructible<c10::IValue, T>::value << " " << typeid(t).name(); | |
} | |
// Debug helper for two or more values. Writes the single-value description of
// each argument, separated by ',' with no trailing separator or newline.
// Example output shape: "1 <name>,0 <name>".
// The U parameter guarantees at least two values, so overload resolution never
// confuses this overload with the single-value one.
// Arguments are taken by const reference; only their types are inspected.
template <typename T, typename U, typename... Args>
void debugIValue(std::ostream& out, const T& t, const U& u, const Args&... args) {
  // Reuse the single-value overload rather than duplicating its formatting.
  debugIValue(out, t);
  out << ',';
  debugIValue(out, u, args...);
}
template<typename... Args> | |
struct lazy_wrap<Tensor, Args...> { | |
lazy_wrap(std::function<Tensor(Args...)> f, const char* s) : f_(f), s_(s) { | |
std::cerr << "Return type is Tensor\n"; | |
} | |
std::function<Tensor(Args...)> f_; | |
const char* s_; | |
template <typename T = std::tuple<Args...>, | |
typename std::enable_if<!ivalue_constructible_tuple<T>::value>::type* = nullptr> | |
Tensor operator()(Args... as) { | |
std::cerr << "Args are NOT IValue-able " << | |
ivalue_constructible_tuple<T>::value << "\n"; | |
debugIValue(std::cerr, as...); | |
return f_(as...); | |
} | |
template <typename T = std::tuple<Args...>, | |
typename std::enable_if<ivalue_constructible_tuple<T>::value>::type* = nullptr> | |
Tensor operator()(Args... as) { | |
std::cerr << "Args ARE IValue-able\n"; | |
std::vector<c10::IValue> inps = {{ c10::IValue(as)... }}; | |
return at::single_output(s_, inps); | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment