Skip to content

Instantly share code, notes, and snippets.

@csullivan
Last active November 27, 2018 01:00
Show Gist options
  • Save csullivan/fde426ab32ae1f36b24bdbc70890bf51 to your computer and use it in GitHub Desktop.
Save csullivan/fde426ab32ae1f36b24bdbc70890bf51 to your computer and use it in GitHub Desktop.
diff --git a/src/ngraph/runtime/gpu/cuda_emitter.cpp b/src/ngraph/runtime/gpu/cuda_emitter.cpp
index a9ef0e00..09fe458b 100644
--- a/src/ngraph/runtime/gpu/cuda_emitter.cpp
+++ b/src/ngraph/runtime/gpu/cuda_emitter.cpp
@@ -3096,11 +3096,15 @@ void* runtime::gpu::CUDAEmitter::get_init_reduce_val(std::string reduce_op, std:
{
if (reduce_op == "max")
{
- return m_host_parameters->min_by_datatype(data_type);
+ // A max-reduction is seeded with the smallest (most negative) finite value of the type.
+ // NOTE(review): for floating types lowest() is finite, not -inf; confirm that is acceptable.
+ return TypeInfo::Get(data_type)->lowest_ptr();
}
else if (reduce_op == "min")
{
- return m_host_parameters->max_by_datatype(data_type);
+ // A min-reduction is seeded with the largest finite value of the type.
+ // NOTE(review): for floating types max() is finite, not +inf; confirm that is acceptable.
+ return TypeInfo::Get(data_type)->max_ptr();
}
else if (reduce_op == "mul" || reduce_op == "and")
{
diff --git a/src/ngraph/runtime/gpu/gpu_host_parameters.hpp b/src/ngraph/runtime/gpu/gpu_host_parameters.hpp
index 04c297e4..4b93a9d8 100644
--- a/src/ngraph/runtime/gpu/gpu_host_parameters.hpp
+++ b/src/ngraph/runtime/gpu/gpu_host_parameters.hpp
@@ -20,6 +20,9 @@
#include <limits>
#include <list>
+#include "ngraph/except.hpp"
+#include "ngraph/runtime/gpu/type_info.hpp"
+
namespace ngraph
{
namespace runtime
@@ -87,124 +90,12 @@ namespace ngraph
return &m_uint64_t_params.back();
}
- template <typename T>
- void* getMin()
- {
- return cache(std::numeric_limits<T>::has_infinity
- ? -std::numeric_limits<T>::infinity()
- : std::numeric_limits<T>::min());
- }
-
- template <typename T>
- void* getMax()
- {
- return cache(std::numeric_limits<T>::has_infinity
- ? std::numeric_limits<T>::infinity()
- : std::numeric_limits<T>::max());
- }
-
template <typename T1, typename T2>
void* getVal(T2 val)
{
return cache(static_cast<T1>(val));
}
- void* min_by_datatype(const std::string& type)
- {
- if (type == "char")
- {
- return getMin<char>();
- }
- else if (type == "float")
- {
- return getMin<float>();
- }
- else if (type == "double")
- {
- return getMin<double>();
- }
- else if (type == "int8_t")
- {
- return getMin<int8_t>();
- }
- else if (type == "int16_t")
- {
- return getMin<int16_t>();
- }
- else if (type == "int32_t")
- {
- return getMin<int32_t>();
- }
- else if (type == "int64_t")
- {
- return getMin<int64_t>();
- }
- else if (type == "uint8_t")
- {
- return getMin<uint8_t>();
- }
- else if (type == "uint16_t")
- {
- return getMin<uint16_t>();
- }
- else if (type == "uint32_t")
- {
- return getMin<uint32_t>();
- }
- else if (type == "uint64_t")
- {
- return getMin<uint64_t>();
- }
- }
-
- void* max_by_datatype(const std::string& type)
- {
- if (type == "char")
- {
- return getMax<char>();
- }
- else if (type == "float")
- {
- return getMax<float>();
- }
- else if (type == "double")
- {
- return getMax<double>();
- }
- else if (type == "int8_t")
- {
- return getMax<int8_t>();
- }
- else if (type == "int16_t")
- {
- return getMax<int16_t>();
- }
- else if (type == "int32_t")
- {
- return getMax<int32_t>();
- }
- else if (type == "int64_t")
- {
- return getMax<int64_t>();
- }
- else if (type == "uint8_t")
- {
- return getMax<uint8_t>();
- }
- else if (type == "uint16_t")
- {
- return getMax<uint16_t>();
- }
- else if (type == "uint32_t")
- {
- return getMax<uint32_t>();
- }
- else if (type == "uint64_t")
- {
- return getMax<uint64_t>();
- }
- }
-
void* val_by_datatype(const std::string& type, double val)
{
if (type == "char")
@@ -251,6 +142,7 @@ namespace ngraph
{
return getVal<uint64_t>(val);
}
+ throw ngraph_error("Cast requested for invalid dtype: " + type);
}
void* val_by_datatype(const std::string& type, int64_t val)
@@ -299,6 +191,7 @@ namespace ngraph
{
return getVal<uint64_t>(val);
}
+ throw ngraph_error("Cast requested for invalid dtype: " + type);
}
private:
diff --git a/src/ngraph/runtime/gpu/type_info.hpp b/src/ngraph/runtime/gpu/type_info.hpp
index f98e81c0..ca858cc0 100644
--- a/src/ngraph/runtime/gpu/type_info.hpp
+++ b/src/ngraph/runtime/gpu/type_info.hpp
@@ -40,6 +40,13 @@ namespace ngraph
virtual std::string min() const = 0;
virtual std::string max() const = 0;
+ // Host pointers to std::numeric_limits<T>::lowest()/min()/max() for this element type.
+ // The storage belongs to the TypeInfo instance, so repeated calls return the same
+ // address; callers must treat the pointee as read-only.
+ virtual void* lowest_ptr() = 0;
+ virtual void* min_ptr() = 0;
+ virtual void* max_ptr() = 0;
+
using TypeDispatch = std::unordered_map<std::string, std::shared_ptr<TypeInfo>>;
static const std::shared_ptr<TypeInfo>& Get(const element::Type& type)
{
@@ -80,6 +87,24 @@ namespace ngraph
{
return to_string<T>(std::numeric_limits<T>::max());
}
+ void* lowest_ptr() override
+ {
+ return &m_lowest;
+ }
+ void* min_ptr() override
+ {
+ return &m_min;
+ }
+ void* max_ptr() override
+ {
+ return &m_max;
+ }
+ private:
+ // Fixed storage: the returned pointers stay valid for the instance's lifetime
+ // without growing a container or mutating the object on every call.
+ T m_lowest = std::numeric_limits<T>::lowest();
+ T m_min = std::numeric_limits<T>::min();
+ T m_max = std::numeric_limits<T>::max();
};
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment