Last active
November 27, 2018 01:00
-
-
Save csullivan/fde426ab32ae1f36b24bdbc70890bf51 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/ngraph/runtime/gpu/cuda_emitter.cpp b/src/ngraph/runtime/gpu/cuda_emitter.cpp | |
index a9ef0e00..09fe458b 100644                                                   | |
--- a/src/ngraph/runtime/gpu/cuda_emitter.cpp                                     | |
+++ b/src/ngraph/runtime/gpu/cuda_emitter.cpp                                     | |
@@ -3096,11 +3096,17 @@ void* runtime::gpu::CUDAEmitter::get_init_reduce_val(std::string reduce_op, std: | |
     {                                                                           | |
         if (reduce_op == "max")                                                 | |
         {                                                                       | |
-            return m_host_parameters->min_by_datatype(data_type);                | |
+            // Max-reduction: start from the lowest representable value.         | |
+            // NOTE(review): the previous getMin() helper returned -inf for      | |
+            // floating types; lowest() is finite, so an all -inf input now      | |
+            // reduces to lowest() rather than -inf. Confirm this is intended.   | |
+            return TypeInfo::Get(data_type)->lowest_ptr();                       | |
         }                                                                       | |
         else if (reduce_op == "min")                                            | |
         {                                                                       | |
-            return m_host_parameters->max_by_datatype(data_type);                | |
+            // Min-reduction: start from the largest representable value.        | |
+            // Like lowest() above, max() is finite rather than +inf.            | |
+            return TypeInfo::Get(data_type)->max_ptr();                          | |
         }                                                                       | |
         else if (reduce_op == "mul" || reduce_op == "and")                      | |
         {                                                                       | |
diff --git a/src/ngraph/runtime/gpu/gpu_host_parameters.hpp b/src/ngraph/runtime/gpu/gpu_host_parameters.hpp | |
index 04c297e4..4b93a9d8 100644 | |
--- a/src/ngraph/runtime/gpu/gpu_host_parameters.hpp | |
+++ b/src/ngraph/runtime/gpu/gpu_host_parameters.hpp | |
@@ -20,6 +20,9 @@ | |
#include <limits> | |
#include <list> | |
+#include "ngraph/except.hpp" | |
+#include "ngraph/runtime/gpu/type_info.hpp" | |
+ | |
namespace ngraph | |
{ | |
namespace runtime | |
@@ -87,124 +90,12 @@ namespace ngraph | |
return &m_uint64_t_params.back(); | |
} | |
- template <typename T> | |
- void* getMin() | |
- { | |
- return cache(std::numeric_limits<T>::has_infinity | |
- ? -std::numeric_limits<T>::infinity() | |
- : std::numeric_limits<T>::min()); | |
- } | |
- | |
- template <typename T> | |
- void* getMax() | |
- { | |
- return cache(std::numeric_limits<T>::has_infinity | |
- ? std::numeric_limits<T>::infinity() | |
- : std::numeric_limits<T>::max()); | |
- } | |
- | |
template <typename T1, typename T2> | |
void* getVal(T2 val) | |
{ | |
return cache(static_cast<T1>(val)); | |
} | |
- void* min_by_datatype(const std::string& type) | |
- { | |
- if (type == "char") | |
- { | |
- return getMin<char>(); | |
- } | |
- else if (type == "float") | |
- { | |
- return getMin<float>(); | |
- } | |
- else if (type == "double") | |
- { | |
- return getMin<double>(); | |
- } | |
- else if (type == "int8_t") | |
- { | |
- return getMin<int8_t>(); | |
- } | |
- else if (type == "int16_t") | |
- { | |
- return getMin<int16_t>(); | |
- } | |
- else if (type == "int32_t") | |
- { | |
- return getMin<int32_t>(); | |
- } | |
- else if (type == "int64_t") | |
- { | |
- return getMin<int64_t>(); | |
- } | |
- else if (type == "uint8_t") | |
- { | |
- return getMin<uint8_t>(); | |
- } | |
- else if (type == "uint16_t") | |
- { | |
- return getMin<uint16_t>(); | |
- } | |
- else if (type == "uint32_t") | |
- { | |
- return getMin<uint32_t>(); | |
- } | |
- else if (type == "uint64_t") | |
- { | |
- return getMin<uint64_t>(); | |
- } | |
- } | |
- | |
- void* max_by_datatype(const std::string& type) | |
- { | |
- if (type == "char") | |
- { | |
- return getMax<char>(); | |
- } | |
- else if (type == "float") | |
- { | |
- return getMax<float>(); | |
- } | |
- else if (type == "double") | |
- { | |
- return getMax<double>(); | |
- } | |
- else if (type == "int8_t") | |
- { | |
- return getMax<int8_t>(); | |
- } | |
- else if (type == "int16_t") | |
- { | |
- return getMax<int16_t>(); | |
- } | |
- else if (type == "int32_t") | |
- { | |
- return getMax<int32_t>(); | |
- } | |
- else if (type == "int64_t") | |
- { | |
- return getMax<int64_t>(); | |
- } | |
- else if (type == "uint8_t") | |
- { | |
- return getMax<uint8_t>(); | |
- } | |
- else if (type == "uint16_t") | |
- { | |
- return getMax<uint16_t>(); | |
- } | |
- else if (type == "uint32_t") | |
- { | |
- return getMax<uint32_t>(); | |
- } | |
- else if (type == "uint64_t") | |
- { | |
- return getMax<uint64_t>(); | |
- } | |
- } | |
- | |
void* val_by_datatype(const std::string& type, double val) | |
{ | |
if (type == "char") | |
@@ -251,6 +142,7 @@ namespace ngraph                                               | |
                 {                                                               | |
                     return getVal<uint64_t>(val);                               | |
                 }                                                               | |
+            throw ngraph_error("Cast requested for invalid dtype: " + type);     | |
             }                                                                   | |
         void* val_by_datatype(const std::string& type, int64_t val)                      | |
@@ -299,6 +191,7 @@ namespace ngraph                                               | |
                 {                                                               | |
                     return getVal<uint64_t>(val);                               | |
                 }                                                               | |
+            throw ngraph_error("Cast requested for invalid dtype: " + type);     | |
             }                                                                   | |
         private:                                                                | |
diff --git a/src/ngraph/runtime/gpu/type_info.hpp b/src/ngraph/runtime/gpu/type_info.hpp | |
index f98e81c0..ca858cc0 100644 | |
--- a/src/ngraph/runtime/gpu/type_info.hpp | |
+++ b/src/ngraph/runtime/gpu/type_info.hpp | |
@@ -21,6 +21,7 @@ | |
#include <sstream> | |
#include <string> | |
#include <unordered_map> | |
+#include <list> | |
#include "ngraph/type/element_type.hpp" | |
@@ -40,6 +41,10 @@ namespace ngraph | |
virtual std::string min() const = 0; | |
virtual std::string max() const = 0; | |
+ virtual void* lowest_ptr() = 0; | |
+ virtual void* min_ptr() = 0; | |
+ virtual void* max_ptr() = 0; | |
+ | |
using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>; | |
static const std::shared_ptr<TypeInfo>& Get(const element::Type& type) | |
{ | |
@@ -80,6 +85,27 @@ namespace ngraph                                                | |
         {                                                                       | |
             return to_string<T>(std::numeric_limits<T>::max());                 | |
         }                                                                       | |
+        // Host pointers to this instance's copies of std::numeric_limits<T>     | |
+        // values. Each value is a constant, so it is stored once and every      | |
+        // call returns the same address; callers must not write through it.     | |
+        void* lowest_ptr() override                                              | |
+        {                                                                        | |
+            return &m_lowest;                                                    | |
+        }                                                                        | |
+        // For floating-point T, numeric_limits<T>::min() is the smallest        | |
+        // positive normalized value, not the most negative (see lowest_ptr).    | |
+        void* min_ptr() override                                                 | |
+        {                                                                        | |
+            return &m_min;                                                       | |
+        }                                                                        | |
+        void* max_ptr() override                                                 | |
+        {                                                                        | |
+            return &m_max;                                                       | |
+        }                                                                        | |
+    private:                                                                     | |
+        T m_lowest = std::numeric_limits<T>::lowest();                           | |
+        T m_min = std::numeric_limits<T>::min();                                 | |
+        T m_max = std::numeric_limits<T>::max();                                 | |
     };                                                                          | |
     }                                                                           | |
}                                                                             |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment