@csullivan
Created September 25, 2018 16:27
diff --git a/src/ngraph/runtime/gpu/CMakeLists.txt b/src/ngraph/runtime/gpu/CMakeLists.txt
index 04d96608..aaad210c 100644
--- a/src/ngraph/runtime/gpu/CMakeLists.txt
+++ b/src/ngraph/runtime/gpu/CMakeLists.txt
@@ -42,7 +42,6 @@ set(SRC
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp
- op/lstm.cpp
op/rnn.cpp
)
diff --git a/src/ngraph/runtime/gpu/cudnn_descriptors.hpp b/src/ngraph/runtime/gpu/cudnn_descriptors.hpp
index 614e4101..89d1ca73 100644
--- a/src/ngraph/runtime/gpu/cudnn_descriptors.hpp
+++ b/src/ngraph/runtime/gpu/cudnn_descriptors.hpp
@@ -135,6 +135,7 @@ namespace ngraph
}
};
+#if CUDNN_VERSION >= 7200
template <>
struct cudnn_descriptor<cudnnRNNDataDescriptor_t>
{
@@ -147,7 +148,7 @@ namespace ngraph
CUDNN_SAFE_CALL_NO_THROW(cudnnDestroyRNNDataDescriptor(desc));
}
};
-
+#endif
template <>
struct cudnn_descriptor<cudnnPoolingDescriptor_t>
{
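
Context for the guard above: cudnnRNNDataDescriptor_t and the cudnnRNNForward*Ex entry points only exist from cuDNN 7.2 onward, so this descriptor specialization must be compiled out on older toolkits. The create() half of the specialization is elided by the hunk; a minimal sketch of the full guarded form, assuming create() mirrors destroy() the way the other descriptor specializations in this header do:

    #if CUDNN_VERSION >= 7200 // cudnnRNNDataDescriptor_t was added in cuDNN 7.2
    template <>
    struct cudnn_descriptor<cudnnRNNDataDescriptor_t>
    {
        static void create(cudnnRNNDataDescriptor_t& desc)
        {
            CUDNN_SAFE_CALL(cudnnCreateRNNDataDescriptor(&desc));
        }
        static void destroy(cudnnRNNDataDescriptor_t& desc)
        {
            CUDNN_SAFE_CALL_NO_THROW(cudnnDestroyRNNDataDescriptor(desc));
        }
    };
    #endif
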
diff --git a/src/ngraph/runtime/gpu/cudnn_emitter.cpp b/src/ngraph/runtime/gpu/cudnn_emitter.cpp
index 31a8b40e..b63d89f8 100644
--- a/src/ngraph/runtime/gpu/cudnn_emitter.cpp
+++ b/src/ngraph/runtime/gpu/cudnn_emitter.cpp
@@ -945,6 +945,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::Min* node)
return primitive_index;
}
+#if CUDNN_VERSION >= 7200
size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::gpu::Rnn* node)
{
auto& args = node->get_inputs();
@@ -952,7 +953,7 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::gpu::Rnn* node)
auto dtype = out[0].get_element_type().c_type_string();
std::stringstream ss;
- ss << "rnn_";
+ ss << "rnn_psz" << shape_size(args[2].get_shape());
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
@@ -1134,61 +1135,48 @@ size_t runtime::gpu::CUDNNEmitter::build_primitive(const op::gpu::Rnn* node)
std::vector<cudnnTensorDescriptor_t> seq_descriptors(seq_length, temp_input_desc);
CUDNN_SAFE_CALL(cudnnGetRNNWorkspaceSize(
*m_ctx->cudnn_handle, rnn_desc, seq_length, seq_descriptors.data(), &workspace_size));
+
size_t workspace_idx = allocator.reserve_workspace(workspace_size);
- auto wx_size = args[1].get_element_type().size() * shape_size(args[1].get_shape());
- auto wh_size = args[3].get_element_type().size() * shape_size(args[3].get_shape());
- auto bx_size = args[4].get_element_type().size() * shape_size(args[4].get_shape());
- auto bh_size = args[5].get_element_type().size() * shape_size(args[5].get_shape());
auto recurrent_index = num_tensors_per_layer / 2;
- std::unique_ptr<gpu::primitive> kernel_launch(new gpu::primitive{[=](void** inputs,
- void** outputs) {
- void* w_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, w_idx);
- void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
-
- // pack the weight and bias parameter data
- cuda_memcpyDtD(static_cast<uint8_t*>(w_ptr) + weight_offsets[0].first, inputs[1], wx_size);
- cuda_memcpyDtD(static_cast<uint8_t*>(w_ptr) + weight_offsets[recurrent_index].first,
- inputs[3],
- wh_size);
- cuda_memcpyDtD(static_cast<uint8_t*>(w_ptr) + bias_offsets[0].first, inputs[4], bx_size);
- cuda_memcpyDtD(
- static_cast<uint8_t*>(w_ptr) + bias_offsets[recurrent_index].first, inputs[4], bh_size);
-
- CUDNN_SAFE_CALL(cudnnRNNForwardInferenceEx(*m_ctx->cudnn_handle,
- rnn_desc,
- x_desc,
- inputs[0],
- hx_desc,
- inputs[2],
- cx_desc,
- inputs[6],
- w_desc,
- w_ptr,
- y_desc, // h_i
- outputs[0],
- hy_desc, // h_t
- outputs[1],
- cy_desc, // c_t
- outputs[2],
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL,
- workspace_ptr,
- workspace_size));
- debug_sync();
- }});
+ std::unique_ptr<gpu::primitive> kernel_launch(
+ new gpu::primitive{[=](void** inputs, void** outputs) {
+ void* workspace_ptr = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
+ CUDNN_SAFE_CALL(cudnnRNNForwardInferenceEx(*m_ctx->cudnn_handle,
+ rnn_desc,
+ x_desc,
+ inputs[0],
+ hx_desc,
+ inputs[1],
+ cx_desc,
+ inputs[3],
+ w_desc,
+ inputs[2],
+ y_desc, // h_i
+ outputs[0],
+ hy_desc, // h_t
+ outputs[1],
+ cy_desc, // c_t
+ outputs[2],
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ workspace_ptr,
+ workspace_size));
+ debug_sync();
+ }});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
+#endif
size_t runtime::gpu::CUDNNEmitter::build_convolution(const std::string& dtype,
const Shape& input_tensor_shape,
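
The substance of the hunk above: the old kernel packed four separate weight/bias tensors into a cuDNN-owned buffer on every invocation (the deleted cuda_memcpyDtD calls, one of which copied inputs[4] for both bias sections), while the new kernel receives a single pre-concatenated parameter tensor as inputs[2] and passes it directly as the w argument of cudnnRNNForwardInferenceEx. A hedged sketch of the byte layout the graph-side concatenation (see compute_lstm_params in the fusion pass below) produces for one single-direction LSTM layer; the helper and its names are illustrative, not part of the diff:

    #include <cstddef>

    // Offsets into the flat parameter blob {W_x | W_h | B_x | B_h}.
    struct LstmParamOffsets
    {
        std::size_t w_x, w_h, b_x, b_h, total;
    };

    LstmParamOffsets lstm_param_offsets(std::size_t input_size,  // c
                                        std::size_t hidden_size, // h
                                        std::size_t elem_bytes)  // e.g. sizeof(float)
    {
        const std::size_t gates = 4; // LSTM gates: i, f, g, o
        LstmParamOffsets o;
        o.w_x = 0;                                                      // W_x: (4h x c)
        o.w_h = o.w_x + gates * hidden_size * input_size * elem_bytes;  // W_h: (4h x h)
        o.b_x = o.w_h + gates * hidden_size * hidden_size * elem_bytes; // B_x: (4h)
        o.b_h = o.b_x + gates * hidden_size * elem_bytes;               // B_h: (4h)
        o.total = o.b_h + gates * hidden_size * elem_bytes;
        return o;
    }
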
diff --git a/src/ngraph/runtime/gpu/cudnn_emitter.hpp b/src/ngraph/runtime/gpu/cudnn_emitter.hpp
index 02f36efb..32d7f3d7 100644
--- a/src/ngraph/runtime/gpu/cudnn_emitter.hpp
+++ b/src/ngraph/runtime/gpu/cudnn_emitter.hpp
@@ -35,7 +35,6 @@
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
-#include "ngraph/runtime/gpu/op/lstm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
namespace ngraph
diff --git a/src/ngraph/runtime/gpu/gpu_emitter.cpp b/src/ngraph/runtime/gpu/gpu_emitter.cpp
index 33b1b60c..bfd759d6 100644
--- a/src/ngraph/runtime/gpu/gpu_emitter.cpp
+++ b/src/ngraph/runtime/gpu/gpu_emitter.cpp
@@ -123,7 +123,7 @@ function<void(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Nod
// ...
#define NGRAPH_OP(a, b) {type_index(typeid(b::a)), runtime::gpu::GPU_Emitter::emit_##a},
static const map<type_index, function<void(EMIT_ARGS)>> typeid_map{
-#include "ngraph/op/op_tbl.hpp"
+#include "ngraph/runtime/gpu/op/op_tbl.hpp"
};
#undef NGRAPH_OP
auto it = typeid_map.find(type_index(typeid(node)));
@@ -1399,6 +1399,24 @@ void runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
writer.block_end();
}
+#if CUDNN_VERSION >= 7200
+void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
+{
+ auto rnn = static_cast<const ngraph::op::gpu::Rnn*>(node);
+
+ auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
+ size_t index = cudnn_emitter->build_primitive(rnn);
+
+ writer.block_begin();
+ {
+ writer << "void* input[] = {" << node_names(args) << "};\n";
+ writer << "void* output[] = {" << node_names(out) << "};\n";
+ writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
+ }
+ writer.block_end();
+}
+#endif
+
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
@@ -1548,31 +1566,6 @@ void runtime::gpu::GPU_Emitter::emit_Sum(EMIT_ARGS)
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << sum_index << ", input, output);\n";
}
-
- template <>
- void GPU_Emitter::EMITTER_DECL(ngraph::op::gpu::Rnn)
- {
- auto rnn = static_cast<const ngraph::op::gpu::Rnn*>(node);
-
- auto& cudnn_emitter =
- external_function->get_primitive_emitter()->get_cudnn_emitter();
- size_t rnn_index = cudnn_emitter->build_primitive(rnn);
-
- writer << "gpu::invoke_primitive(ctx, " << rnn_index << ", ";
- writer << "std::vector<void*>{" << args.front().get_name();
- for (size_t i = 1; i < args.size(); i++)
- {
- writer << ", " << args[i].get_name();
- }
- writer << "}.data(), ";
- writer << "std::vector<void*>{" << out.front().get_name();
- for (size_t i = 1; i < out.size(); i++)
- {
- writer << ", " << out[i].get_name();
- }
- writer << "}.data()";
- writer << ");\n";
- }
}
}
writer.block_end();
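
Both the typeid dispatch map above and the emit_* declarations in gpu_emitter.hpp are generated by redefining NGRAPH_OP around an include of the op table, so pointing the include at the new GPU-local table is what makes emit_Rnn reachable at dispatch time. A self-contained toy version of the X-macro pattern (the real NGRAPH_OP takes (op, namespace); this sketch collapses the table into a macro so it compiles standalone):

    #include <cstdio>

    // Stands in for ngraph/runtime/gpu/op/op_tbl.hpp: one entry per op.
    #define OP_TABLE \
        NGRAPH_OP(Add) \
        NGRAPH_OP(Rnn)

    // Expansion 1: define one emitter function per table entry.
    #define NGRAPH_OP(a) void emit_##a() { std::printf("emit_" #a "\n"); }
    OP_TABLE
    #undef NGRAPH_OP

    int main()
    {
        // Expansion 2: reuse the same table to build the dispatch sequence.
    #define NGRAPH_OP(a) emit_##a();
        OP_TABLE
    #undef NGRAPH_OP
        return 0;
    }
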
diff --git a/src/ngraph/runtime/gpu/gpu_emitter.hpp b/src/ngraph/runtime/gpu/gpu_emitter.hpp
index a663bbaa..dca1741f 100644
--- a/src/ngraph/runtime/gpu/gpu_emitter.hpp
+++ b/src/ngraph/runtime/gpu/gpu_emitter.hpp
@@ -39,7 +39,7 @@ namespace ngraph
// static void emit_Abs(EMIT_ARGS);
// static void emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static void emit_##a(EMIT_ARGS);
-#include "ngraph/op/op_tbl.hpp"
+#include "ngraph/runtime/gpu/op/op_tbl.hpp"
#undef NGRAPH_OP
template <typename T>
diff --git a/src/ngraph/runtime/gpu/gpu_external_function.cpp b/src/ngraph/runtime/gpu/gpu_external_function.cpp
index bd2bc8fc..32375747 100644
--- a/src/ngraph/runtime/gpu/gpu_external_function.cpp
+++ b/src/ngraph/runtime/gpu/gpu_external_function.cpp
@@ -112,7 +112,6 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view_wrapper.hpp"
-#include "ngraph/runtime/gpu/op/lstm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
@@ -569,27 +568,23 @@ void runtime::gpu::GPU_ExternalFunction::compile()
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
+#if CUDNN_VERSION >= 7200
// recurrent network fusion
m_pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
-
m_pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
-
m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
-
m_pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
-
+#else
+ m_pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
+#endif
m_pass_manager.register_pass<ngraph::pass::LikeReplacement>();
m_pass_manager
.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
-
m_pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
m_pass_manager.register_pass<ngraph::pass::Liveness>();
-
m_pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
-
m_pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
allocator, m_tensor_memory_buffers);
-
std::string common_function_string;
auto femitter = bind(&ngraph::runtime::gpu::GPU_ExternalFunction::emit_op_as_function,
this,
@@ -597,7 +592,6 @@ void runtime::gpu::GPU_ExternalFunction::compile()
placeholders::_2);
m_pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, m_node_function_map, common_function_string);
-
string dump_filename = file_util::path_join(s_output_dir, m_function_name + "_ops.txt");
m_pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
diff --git a/src/ngraph/runtime/gpu/op/lstm.cpp b/src/ngraph/runtime/gpu/op/lstm.cpp
deleted file mode 100644
index b45213f1..00000000
--- a/src/ngraph/runtime/gpu/op/lstm.cpp
+++ /dev/null
@@ -1,158 +0,0 @@
-//*****************************************************************************
-// Copyright 2018 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#include "ngraph/runtime/gpu/op/lstm.hpp"
-#include "ngraph/log.hpp"
-#include "ngraph/util.hpp"
-
-using namespace std;
-using namespace ngraph;
-
-shared_ptr<Node> op::gpu::Lstm::copy_with_new_args(const NodeVector& new_args) const
-{
- if (!m_fused_inputs)
- {
- NGRAPH_ASSERT(new_args.size() == 7) << "Incorrect number of new arguments";
-
- return make_shared<Lstm>(new_args.at(0),
- new_args.at(1),
- new_args.at(2),
- new_args.at(3),
- new_args.at(4),
- new_args.at(5),
- new_args.at(6));
- }
- else
- {
- NGRAPH_ASSERT(new_args.size() == 5 && m_fused_inputs)
- << "Incorrect number of new arguments";
-
- return make_shared<Lstm>(
- new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
- }
-}
-
-op::gpu::Lstm::Lstm(std::shared_ptr<Node> input_xt_1,
- std::shared_ptr<Node> i2h_weights,
- std::shared_ptr<Node> hidden_state_ht_1,
- std::shared_ptr<Node> h2h_weights,
- std::shared_ptr<Node> i2h_bias,
- std::shared_ptr<Node> h2h_bias,
- std::shared_ptr<Node> cell_state_ct_1)
- : Op("Lstm",
- {input_xt_1,
- i2h_weights,
- hidden_state_ht_1,
- h2h_weights,
- i2h_bias,
- h2h_bias,
- cell_state_ct_1})
- , m_output_tensor_shape(hidden_state_ht_1->get_shape())
- , m_output_cell_shape(cell_state_ct_1->get_shape())
- , m_num_timesteps(1)
- , m_num_gates_per_cell(4)
- , m_src_sequence_length(1)
- , m_src_layer_feature_size(static_cast<int>(input_xt_1->get_shape()[1]))
- , m_src_iter_feature_size(static_cast<int>(hidden_state_ht_1->get_shape()[1]))
- , m_num_cell_states(2)
- , m_direction(1)
- , m_num_fused_layers(1)
- , m_fused_inputs(false)
-{
- NGRAPH_ASSERT(input_xt_1->get_shape().size() == i2h_weights->get_shape().size())
- << "input_xt_1 and i2h weights size dont match";
- NGRAPH_ASSERT(hidden_state_ht_1->get_shape().size() == h2h_weights->get_shape().size())
- << "hidden_state_ht_1 and h2h weights size dont match";
- NGRAPH_ASSERT(input_xt_1->get_shape().size() == 2) << "input_xt_1 doesnt have a rank 2";
-
- m_batch_size = static_cast<int>(input_xt_1->get_shape()[0]);
-
- NGRAPH_ASSERT(shape_size(input_xt_1->get_shape()) ==
- m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
- << "input_xt_1 size is not equal t*n*c";
- NGRAPH_ASSERT(i2h_bias->get_shape()[0] == i2h_weights->get_shape()[0] &&
- h2h_bias->get_shape()[0] == h2h_weights->get_shape()[0])
- << "bias and weights_shape are not compatible";
-
- NGRAPH_ASSERT(m_output_tensor_shape == m_output_cell_shape)
- << "shape of recurrent input and cell state are not the same";
-
- auto et = input_xt_1->get_element_type();
- for (auto& lstm_input : get_arguments())
- {
- if (lstm_input->get_element_type() != et)
- {
- throw ngraph_error("all rnn inputs must have the same element type");
- }
- }
-
- set_output_size(2);
- set_output_type(0, hidden_state_ht_1->get_element_type(), hidden_state_ht_1->get_shape());
- set_output_type(1, cell_state_ct_1->get_element_type(), cell_state_ct_1->get_shape());
-}
-
-op::gpu::Lstm::Lstm(std::shared_ptr<Node> src_layer,
- std::shared_ptr<Node> src_iter,
- std::shared_ptr<Node> weights_layer,
- std::shared_ptr<Node> weights_iter,
- std::shared_ptr<Node> bias)
- : Op("Lstm", {src_layer, src_iter, weights_layer, weights_iter, bias})
- , m_output_tensor_shape(src_layer->get_shape())
- , m_output_cell_shape(src_iter->get_shape())
- , m_num_timesteps(1)
- , m_num_gates_per_cell(4)
- , m_src_sequence_length(1)
- , m_src_layer_feature_size(static_cast<int>(src_layer->get_shape()[1]))
- , m_src_iter_feature_size(static_cast<int>(src_iter->get_shape()[1]))
- , m_num_cell_states(2)
- , m_direction(1)
- , m_num_fused_layers(1)
- , m_fused_inputs(true)
-{
- NGRAPH_ASSERT(src_layer->get_shape().size() == weights_layer->get_shape().size())
- << "src_layer and i2h weights size dont match";
- NGRAPH_ASSERT(src_iter->get_shape().size() == weights_iter->get_shape().size())
- << "src_iter and h2h weights size dont match";
- NGRAPH_ASSERT(src_layer->get_shape().size() == 2) << "src_layer doesnt have a rank 2";
-
- m_batch_size = static_cast<int>(src_layer->get_shape()[0] / m_num_timesteps);
-
- NGRAPH_ASSERT(shape_size(src_layer->get_shape()) ==
- m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
- << "src_layer size is not equal t*n*c";
- NGRAPH_ASSERT(bias->get_shape()[0] == weights_layer->get_shape()[0] &&
- bias->get_shape()[0] == weights_iter->get_shape()[0])
- << "bias and weights_shape are not compatible";
-
- auto et = src_layer->get_element_type();
- for (auto& rnn_input : get_arguments())
- {
- if (rnn_input->get_element_type() != et)
- {
- throw ngraph_error("all rnn inputs must have the same element type");
- }
- }
-
- set_output_size(2);
- set_output_type(0,
- src_layer->get_element_type(),
- Shape{static_cast<unsigned long>(m_num_timesteps * m_batch_size),
- static_cast<unsigned long>(m_src_iter_feature_size)});
- set_output_type(1,
- src_layer->get_element_type(),
- Shape{static_cast<unsigned long>(m_num_cell_states * m_batch_size),
- static_cast<unsigned long>(m_src_iter_feature_size)});
-}
diff --git a/src/ngraph/runtime/gpu/op/lstm.hpp b/src/ngraph/runtime/gpu/op/lstm.hpp
deleted file mode 100644
index 2de711f5..00000000
--- a/src/ngraph/runtime/gpu/op/lstm.hpp
+++ /dev/null
@@ -1,103 +0,0 @@
-//*****************************************************************************
-// Copyright 2018 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************
-
-#pragma once
-
-#include "ngraph/op/op.hpp"
-#include "ngraph/util.hpp"
-
-namespace ngraph
-{
- namespace op
- {
- namespace gpu
- {
- class Lstm : public Op
- {
- public:
- // INPUTS:
- // [0] - xt, input tensor of layout TNC, Shape{sequence length*batch_size, feature_size}
- // [1] - initializer for the input weights matrix, used for the linear transformation of the inputs.
- // [2] - ht_1, hidden state of shape (batch_size, feature_size)
- // [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state.
- // [4] - Initializer for the bias vector w.r.to inputs.
- // [5] - Initializer for the bias vector w.r.to hidden state
- // [6] - ct_1, cell state of shape (batch_size, feature_size)
-
- // OUTPUT VALUE: A tuple with the following structure:
- // [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden) .
- // [1] - ct, output recurrent state tensor with the same shape as cell state
-
- // This version of the LSTM op is only used to simplify recurrent RNN cell(LSTM) fusion across
- // horizontal time steps.
- Lstm(std::shared_ptr<Node> input_xt_1,
- std::shared_ptr<Node> i2h_weights,
- std::shared_ptr<Node> hidden_state_ht_1,
- std::shared_ptr<Node> h2h_weights,
- std::shared_ptr<Node> i2h_bias,
- std::shared_ptr<Node> h2h_bias,
- std::shared_ptr<Node> cell_state_ct_1);
-
- // INPUTS:
- // [0] - {Xt} input tensor of layout TNC, Shape{sequence length*batch_size, feature_size}
- // [1] - recurrent state tensors {ht_1 | ct_1} of Shape{sequence length*batch_size, feature_size}
- // [2] - initializer for the input weights matrix, used for the linear transformation of the inputs.
- // [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state.
- // [4] - Initializer for the bias vector w.r.to inputs + hidden state (ibh_bias + hbh_bias)
-
- // OUTPUT VALUE: A tuple with the following structure:
- // [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden) .
- // [1] - {ht | ct} output recurrent state tensor with the same shape as states
-
- // This version of the LSTM op supports emitter code, this can be used standalone for computing RNN
- // without fusing RNN cell (LSTM)'s across time steps.
- Lstm(std::shared_ptr<Node> src_layer,
- std::shared_ptr<Node> src_iter,
- std::shared_ptr<Node> weights_layer,
- std::shared_ptr<Node> weights_iter,
- std::shared_ptr<Node> bias);
- Shape get_output_tensor_shape() const { return m_output_tensor_shape; }
- Shape get_output_cell_shape() const { return m_output_cell_shape; }
- int get_num_timesteps() const { return m_num_timesteps; }
- int get_src_sequence_length() const { return m_src_sequence_length; }
- int get_gates_per_cell() const { return m_num_gates_per_cell; }
- int get_batch_size() const { return m_batch_size; }
- int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
- int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
- int get_num_cell_states() const { return m_num_cell_states; }
- int get_direction() const { return m_direction; }
- int get_num_fused_layers() const { return m_num_fused_layers; }
- int get_fused_inputs() const { return m_fused_inputs; }
- virtual std::shared_ptr<Node>
- copy_with_new_args(const NodeVector& new_args) const override;
-
- private:
- Shape m_output_tensor_shape;
- Shape m_output_cell_shape;
- int m_num_timesteps;
- int m_num_gates_per_cell;
- int m_src_sequence_length;
- int m_batch_size;
- int m_src_layer_feature_size;
- int m_src_iter_feature_size;
- int m_num_cell_states;
- int m_direction;
- int m_num_fused_layers;
- bool m_fused_inputs; // True if node gets fused inputs/weights
- };
- }
- }
-}
diff --git a/src/ngraph/runtime/gpu/op/op_tbl.hpp b/src/ngraph/runtime/gpu/op/op_tbl.hpp
new file mode 100644
index 00000000..e52031ed
--- /dev/null
+++ b/src/ngraph/runtime/gpu/op/op_tbl.hpp
@@ -0,0 +1,20 @@
+//*****************************************************************************
+// Copyright 2017-2018 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/op/op_tbl.hpp"
+#if CUDNN_VERSION >= 7200
+NGRAPH_OP(Rnn, ngraph::op::gpu)
+#endif
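
For reference, cudnn.h composes the version macro from its components, so the guard reads as "cuDNN 7.2.0 or newer":

    // From cudnn.h (cuDNN 7.x):
    //   #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
    // Hence 7200 <=> 7.2.0, and the Rnn op is only registered when the
    // cuDNN headers provide cudnnRNNForwardInferenceEx and the RNN data
    // descriptor API it depends on.
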
diff --git a/src/ngraph/runtime/gpu/op/rnn.cpp b/src/ngraph/runtime/gpu/op/rnn.cpp
index d2de218e..10dcc7c0 100644
--- a/src/ngraph/runtime/gpu/op/rnn.cpp
+++ b/src/ngraph/runtime/gpu/op/rnn.cpp
@@ -23,15 +23,12 @@ using namespace ngraph;
shared_ptr<Node> op::gpu::Rnn::copy_with_new_args(const NodeVector& new_args) const
{
- NGRAPH_ASSERT(new_args.size() == 7) << "Incorrect number of new arguments";
+ NGRAPH_ASSERT(new_args.size() == 4) << "Incorrect number of new arguments";
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
- new_args[4],
- new_args[5],
- new_args[6],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
@@ -43,10 +40,7 @@ shared_ptr<Node> op::gpu::Rnn::copy_with_new_args(const NodeVector& new_args) co
op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
- std::shared_ptr<Node> weights_layer,
- std::shared_ptr<Node> weights_iter,
- std::shared_ptr<Node> bias_layer,
- std::shared_ptr<Node> bias_iter,
+ std::shared_ptr<Node> params,
std::shared_ptr<Node> state_iter,
const int num_timesteps,
const int num_gates_per_cell,
@@ -55,8 +49,7 @@ op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
const int src_iter_feature_size,
const int direction,
const int num_fused_layers)
- : Op("Rnn",
- {src_layer, src_iter, weights_layer, weights_iter, bias_layer, bias_iter, state_iter})
+ : Op("Rnn", {src_layer, src_iter, params, state_iter})
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
, m_src_sequence_length(src_sequence_length)
@@ -65,10 +58,6 @@ op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
{
- NGRAPH_ASSERT(src_layer->get_shape().size() == weights_layer->get_shape().size())
- << "src_layer and i2h weights size dont match";
- NGRAPH_ASSERT(src_iter->get_shape().size() == weights_iter->get_shape().size())
- << "src_iter and h2h weights size dont match";
NGRAPH_ASSERT(src_layer->get_shape().size() == 2) << "src_layer doesnt have a rank 2";
m_batch_size = static_cast<int>(src_layer->get_shape()[0] / num_timesteps);
@@ -76,11 +65,6 @@ op::gpu::Rnn::Rnn(std::shared_ptr<Node> src_layer,
NGRAPH_ASSERT(shape_size(src_layer->get_shape()) ==
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
<< "src_layer size is not equal t*n*c";
- NGRAPH_ASSERT(bias_layer->get_shape() == bias_iter->get_shape())
- << "bias tensor shapes do not match";
- NGRAPH_ASSERT(bias_layer->get_shape()[0] == weights_layer->get_shape()[0] &&
- bias_iter->get_shape()[0] == weights_iter->get_shape()[0])
- << "bias and weights_shape are not compatible";
auto et = src_layer->get_element_type();
for (auto& rnn_input : get_arguments())
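
With the per-tensor weight/bias assertions deleted, the only structural check left is on src_layer. If a size check on the flat parameter tensor is still wanted, the expected element count for the single-layer, single-direction case (the case the fusion pass constructs) follows from the gate count and feature sizes. A hypothetical assertion, not part of the diff, written as if inside the constructor scope:

    // W_x: 4h*c, W_h: 4h*h, B_x: 4h, B_h: 4h  =>  4h * (c + h + 2)
    // (multi-layer/bidirectional packing would need a different formula)
    const size_t expected_params =
        m_num_gates_per_cell * m_src_iter_feature_size *
        (m_src_layer_feature_size + m_src_iter_feature_size + 2);
    NGRAPH_ASSERT(shape_size(params->get_shape()) == expected_params)
        << "flat RNN parameter tensor has unexpected size";
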
diff --git a/src/ngraph/runtime/gpu/op/rnn.hpp b/src/ngraph/runtime/gpu/op/rnn.hpp
index 5dbc9ddf..842fc76b 100644
--- a/src/ngraph/runtime/gpu/op/rnn.hpp
+++ b/src/ngraph/runtime/gpu/op/rnn.hpp
@@ -29,12 +29,11 @@ namespace ngraph
// across multiple time slices
// INPUTS:
- // [0] - {X0, X1...., Xt} input tensor of layout TNC, Shape{sequence length*batch_size, feature_size}
- // [1] - recurrent state tensors {ht_1 | ct_1} of Shape{sequence length*batch_size, feature_size}
- // [2] - initializer for the input weights matrix, used for the linear transformation of the inputs.
- // [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state.
- // [4] - Initializer for the bias vector w.r.to inputs
- // [5] - Initializer for the bias vector w.r.to hidden state
+ // [0] - {X0, X1...., Xt} input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
+ // [1] - recurrent input tensor ht_1 of Shape{num_fused_layers*batch_size, feature_size}
+ // [2] - flat parameter tensor consisting of weights and biases for each layer
+ // {W_x^0 | W_h^0 | W_x^1 | W_h^1 | ... | B_x^0 | B_h^0 | B_x^1 | B_h^1 }
+ // [3] - recurrent cell state tensor ct_1 with same shape as ht_1
// number_of_timesteps - number of unrolled cells up to timestep t.
// num_gates_per_cell - number of gates per RNN cell, LSTM = 4, GRU = 3, vanilla RNN = 1
// src_sequence_length - this will be same as number_of_timesteps
@@ -43,19 +42,17 @@ namespace ngraph
// num_cell_states - number of recurrent state tensor states , LSTM = 2, GRU = 1, vanilla RNN = 1
// OUTPUT VALUE: A tuple with the following structure:
- // [0] - ht, output tensor with shape (sequence_length*batch_size, feature_size) .
- // [1] - {ht | ct} output recurrent state tensor with the same shape as states i.e (sequence_length*batch_size, feature_size)
+ // [0] - ht, sequence-wise output tensor with shape (sequence_length*batch_size, feature_size).
+ // [1] - hf, layer-wise output tensor with shape (num_fused_layers*batch_size, feature_size).
+ // [2] - ct, output cell state tensor with the same shape as the states, i.e. (sequence_length*batch_size, feature_size)
class Rnn : public Op
{
public:
- Rnn(std::shared_ptr<Node> src_layer, // x
- std::shared_ptr<Node> src_iter, // hx
- std::shared_ptr<Node> weights_layer, // wx
- std::shared_ptr<Node> weights_iter, // whx
- std::shared_ptr<Node> bias_layer, // bx
- std::shared_ptr<Node> bias_iter, // bhx
- std::shared_ptr<Node> state_iter, // cx
+ Rnn(std::shared_ptr<Node> src_layer, // x
+ std::shared_ptr<Node> src_iter, // hx
+ std::shared_ptr<Node> params,
+ std::shared_ptr<Node> state_iter, // cx
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
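
Concretely, the four-input signature plus scalar attributes means a single LSTM cell (one timestep, four gates, one layer) is now constructed as below; this mirrors the call sites in the fusion pass later in the diff, assuming xt, ht_1, params, and ct_1 are the matched graph nodes:

    auto lstm = std::make_shared<op::gpu::Rnn>(
        xt,     // [0] src_layer: input tensor, Shape{t*n, c}
        ht_1,   // [1] src_iter: recurrent hidden input
        params, // [2] flat {W_x | W_h | B_x | B_h} parameter tensor
        ct_1,   // [3] state_iter: recurrent cell state
        /*num_timesteps=*/1,
        /*num_gates_per_cell=*/4, // LSTM
        /*src_sequence_length=*/1,
        /*src_layer_feature_size=*/xt->get_shape()[1],
        /*src_iter_feature_size=*/ht_1->get_shape()[1],
        /*direction=*/1,
        /*num_fused_layers=*/1);
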
diff --git a/src/ngraph/runtime/gpu/pass/gpu_rnn_fusion.cpp b/src/ngraph/runtime/gpu/pass/gpu_rnn_fusion.cpp
index 6b00040c..29384380 100644
--- a/src/ngraph/runtime/gpu/pass/gpu_rnn_fusion.cpp
+++ b/src/ngraph/runtime/gpu/pass/gpu_rnn_fusion.cpp
@@ -45,9 +45,15 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
-#include "ngraph/runtime/gpu/op/lstm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
+#define RETURN_IF_FALSE(cond, message) \
+ if (!(cond)) \
+ { \
+ NGRAPH_DEBUG << message; \
+ return false; \
+ }
+
using namespace ngraph;
void ngraph::runtime::gpu::pass::LSTMFusion::construct_sigmoid()
{
@@ -92,6 +98,34 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_sigmoid()
this->add_matcher(m);
}
+static std::shared_ptr<Node> compute_lstm_params(const std::shared_ptr<Node>& w_x,
+ const std::shared_ptr<Node>& w_h,
+ const std::shared_ptr<Node>& b_x,
+ const std::shared_ptr<Node>& b_h)
+{
+ // check if concat of params exists already
+ // if so, use it
+ for (auto& node : w_x->get_users())
+ {
+ for (auto& possible_concat : node->get_users())
+ {
+ if (auto concat = std::dynamic_pointer_cast<op::Concat>(possible_concat))
+ {
+ return concat;
+ }
+ }
+ }
+
+ NodeVector flat_params;
+ for (auto& param : NodeVector{w_x, w_h, b_x, b_h})
+ {
+ auto shape = param->get_shape();
+ flat_params.push_back(std::make_shared<op::Reshape>(
+ param, get_default_order(shape), Shape{shape_size(shape)}));
+ }
+ return std::make_shared<op::Concat>(flat_params, 0);
+}
+
void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
{
auto input_xt = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
@@ -174,19 +208,17 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
return false;
}
- NGRAPH_ASSERT(bias_i2h->get_shape().size() == 1 && bias_h2h->get_shape().size() == 1)
- << "Bias should have rank of 1 for Rnn op";
+ RETURN_IF_FALSE(bias_i2h->get_shape().size() == 1 && bias_h2h->get_shape().size() == 1,
+ "Bias should have rank of 1 for Rnn op");
// Determine which is ht_1 and xt. but if both xt and ht_1 have the same shape we need to capture this
// reliably in the RNN fusion.
- std::shared_ptr<op::gpu::Lstm> lstm = nullptr;
+ std::shared_ptr<op::gpu::Rnn> lstm = nullptr;
bool intermediate_lstm = false;
-
if (std::dynamic_pointer_cast<op::GetOutputElement>(pattern_map[ct_1]))
{
intermediate_lstm = true;
}
-
// if the matched LSTM is the first cell we need to check if symbol input_xt corresponds
// to the input data tensor, or the hidden (recurrent) data tensor
if (!intermediate_lstm &&
@@ -194,26 +226,42 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
std::dynamic_pointer_cast<op::Constant>(pattern_map[hidden_ht]->get_argument(0))))
// label input_xt is the input data to the first LSTM
{
- lstm = std::make_shared<op::gpu::Lstm>(pattern_map[input_xt],
- pattern_map[weights_i2h],
- pattern_map[hidden_ht],
- pattern_map[weights_h2h],
- pattern_map[bias_i2h],
- pattern_map[bias_h2h],
- pattern_map[ct_1]);
+ auto params = compute_lstm_params(pattern_map[weights_i2h],
+ pattern_map[weights_h2h],
+ pattern_map[bias_i2h],
+ pattern_map[bias_h2h]);
+ lstm = std::make_shared<op::gpu::Rnn>(pattern_map[input_xt],
+ pattern_map[hidden_ht],
+ params,
+ pattern_map[ct_1],
+ 1,
+ 4,
+ 1,
+ pattern_map[input_xt]->get_shape()[1],
+ pattern_map[hidden_ht]->get_shape()[1],
+ 1,
+ 1);
}
else if (!intermediate_lstm &&
(std::dynamic_pointer_cast<op::Broadcast>(pattern_map[input_xt]) &&
std::dynamic_pointer_cast<op::Constant>(pattern_map[input_xt]->get_argument(0))))
// label hidden_ht is the input data to the first LSTM
{
- lstm = std::make_shared<op::gpu::Lstm>(pattern_map[hidden_ht],
- pattern_map[weights_h2h],
- pattern_map[input_xt],
- pattern_map[weights_i2h],
- pattern_map[bias_h2h],
- pattern_map[bias_i2h],
- pattern_map[ct_1]);
+ auto params = compute_lstm_params(pattern_map[weights_h2h],
+ pattern_map[weights_i2h],
+ pattern_map[bias_h2h],
+ pattern_map[bias_i2h]);
+ lstm = std::make_shared<op::gpu::Rnn>(pattern_map[hidden_ht],
+ pattern_map[input_xt],
+ params,
+ pattern_map[ct_1],
+ 1,
+ 4,
+ 1,
+ pattern_map[hidden_ht]->get_shape()[1],
+ pattern_map[input_xt]->get_shape()[1],
+ 1,
+ 1);
}
else if (pattern_map[hidden_ht]->get_arguments().size() &&
pattern_map[ct_1]->get_arguments().at(0)->get_instance_id() ==
@@ -224,30 +272,46 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
// label input_xt is the output data from the previous LSTM cell
NGRAPH_DEBUG << "ct_shape : " << join(pattern_map[ct_1]->get_shape())
<< " hidden state shape: " << join(pattern_map[hidden_ht]->get_shape());
- lstm = std::make_shared<op::gpu::Lstm>(pattern_map[input_xt],
- pattern_map[weights_i2h],
- pattern_map[hidden_ht],
- pattern_map[weights_h2h],
- pattern_map[bias_i2h],
- pattern_map[bias_h2h],
- pattern_map[ct_1]);
+ auto params = compute_lstm_params(pattern_map[weights_i2h],
+ pattern_map[weights_h2h],
+ pattern_map[bias_i2h],
+ pattern_map[bias_h2h]);
+ lstm = std::make_shared<op::gpu::Rnn>(pattern_map[input_xt],
+ pattern_map[hidden_ht],
+ params,
+ pattern_map[ct_1],
+ 1,
+ 4,
+ 1,
+ pattern_map[input_xt]->get_shape()[1],
+ pattern_map[hidden_ht]->get_shape()[1],
+ 1,
+ 1);
}
else
{
// label hidden_ht is the output data from the previous LSTM cell
NGRAPH_DEBUG << "ct_shape: " << join(pattern_map[ct_1]->get_shape())
<< " hidden state shape: " << join(pattern_map[input_xt]->get_shape());
- lstm = std::make_shared<op::gpu::Lstm>(pattern_map[hidden_ht],
- pattern_map[weights_h2h],
- pattern_map[input_xt],
- pattern_map[weights_i2h],
- pattern_map[bias_h2h],
- pattern_map[bias_i2h],
- pattern_map[ct_1]);
+ auto params = compute_lstm_params(pattern_map[weights_h2h],
+ pattern_map[weights_i2h],
+ pattern_map[bias_h2h],
+ pattern_map[bias_i2h]);
+ lstm = std::make_shared<op::gpu::Rnn>(pattern_map[hidden_ht],
+ pattern_map[input_xt],
+ params,
+ pattern_map[ct_1],
+ 1,
+ 4,
+ 1,
+ pattern_map[hidden_ht]->get_shape()[1],
+ pattern_map[input_xt]->get_shape()[1],
+ 1,
+ 1);
}
auto ht_output = std::make_shared<op::GetOutputElement>(lstm, 0);
- auto ct_output = std::make_shared<op::GetOutputElement>(lstm, 1);
+ auto ct_output = std::make_shared<op::GetOutputElement>(lstm, 2);
NGRAPH_ASSERT(lstm->get_outputs().at(0).get_inputs().size() == 2)
<< "Lstm node doesnt have two outputs";
@@ -309,239 +373,233 @@ static std::shared_ptr<ngraph::Node>
void ngraph::runtime::gpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
- auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
- auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto xt = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
- auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
- auto bias_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
- auto bias_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
+ auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
+ auto params_label = std::make_shared<pattern::op::Label>(
+ element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
auto rpattern_ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
- auto lstm = std::make_shared<op::gpu::Lstm>(
- xt, weights_i2h, ht_1, weights_h2h, bias_i2h, bias_h2h, rpattern_ct_1);
+ auto lstm = std::make_shared<op::gpu::Rnn>(xt,
+ ht_1,
+ params_label,
+ rpattern_ct_1,
+ 1,
+ 4,
+ 1,
+ xt->get_shape()[1],
+ ht_1->get_shape()[1],
+ 1,
+ 1);
auto goe = std::make_shared<op::GetOutputElement>(lstm, 0); // hidden output
auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});
- pattern::recurrent_graph_rewrite_callback callback =
- [lstm_node_label, xt, weights_h2h, ht_1, weights_i2h, bias_i2h, bias_h2h, rpattern_ct_1](
- pattern::RecurrentMatcher& m) {
+ pattern::recurrent_graph_rewrite_callback callback = [lstm_node_label,
+ xt,
+ ht_1,
+ params_label,
+ rpattern_ct_1](
+ pattern::RecurrentMatcher& m) {
- NGRAPH_DEBUG << " In RNN fusion callback";
+ NGRAPH_DEBUG << " In RNN fusion callback";
- auto ht_1_label = m.get_bound_nodes_for_pattern(ht_1);
+ auto ht_1_label = m.get_bound_nodes_for_pattern(ht_1);
+ auto params_bound = m.get_bound_nodes_for_pattern(params_label);
- // determine the ht and xt
- std::shared_ptr<ngraph::Node> src_layer = nullptr;
- std::shared_ptr<ngraph::Node> src_iter = nullptr;
+ // determine the ht and xt
+ std::shared_ptr<ngraph::Node> src_layer = nullptr;
+ std::shared_ptr<ngraph::Node> src_iter = nullptr;
- auto xt_node_array = m.get_bound_nodes_for_pattern(xt);
- auto hidden_ht_array = m.get_bound_nodes_for_pattern(ht_1);
+ auto xt_node_array = m.get_bound_nodes_for_pattern(xt);
+ auto hidden_ht_array = m.get_bound_nodes_for_pattern(ht_1);
- // since we dont have metadata to differentiate between xt and ht_1
- // we will be using the broadcasted constant initilization of the first LSTM cell
- // in the RNN layer to identify ht_1
- if (std::dynamic_pointer_cast<op::Broadcast>(xt_node_array[xt_node_array.size() - 1]) &&
- std::dynamic_pointer_cast<op::Constant>(
- xt_node_array[xt_node_array.size() - 1]->get_argument(0)))
- // here xt is determined to be the hidden (recurrent) input data and so ht is the feedforward input
- {
- // concatenate the sequence inputs for a given layer
- std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{ht_1};
- src_layer = compute_rnn_args(src_layer_labels, m, true);
+ // since we don't have metadata to differentiate between xt and ht_1
+ // we will be using the broadcasted constant initialization of the first LSTM cell
+ // in the RNN layer to identify ht_1
+ if (std::dynamic_pointer_cast<op::Broadcast>(xt_node_array[xt_node_array.size() - 1]) &&
+ std::dynamic_pointer_cast<op::Constant>(
+ xt_node_array[xt_node_array.size() - 1]->get_argument(0)))
+ // here xt is determined to be the hidden (recurrent) input data and so ht is the feedforward input
+ {
+ // concatenate the sequence inputs for a given layer
+ std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{ht_1};
+ src_layer = compute_rnn_args(src_layer_labels, m, true);
- // concatenate the hidden (recurrent) input with the cell
- std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{xt};
- src_iter = compute_rnn_args(src_iter_labels, m);
- }
- else if (std::dynamic_pointer_cast<op::Broadcast>(
- hidden_ht_array[hidden_ht_array.size() - 1]) &&
- std::dynamic_pointer_cast<op::Constant>(
- hidden_ht_array[hidden_ht_array.size() - 1]->get_argument(0)))
- // here ht is determined to be the hidden (recurrent) input data and so xt is the feedforward input
- {
- std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{xt};
- src_layer = compute_rnn_args(src_layer_labels, m, true);
+ // concatenate the hidden (recurrent) input with the cell
+ std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{xt};
+ src_iter = compute_rnn_args(src_iter_labels, m);
+ }
+ else if (std::dynamic_pointer_cast<op::Broadcast>(
+ hidden_ht_array[hidden_ht_array.size() - 1]) &&
+ std::dynamic_pointer_cast<op::Constant>(
+ hidden_ht_array[hidden_ht_array.size() - 1]->get_argument(0)))
+ // here ht is determined to be the hidden (recurrent) input data and so xt is the feedforward input
+ {
+ std::vector<std::shared_ptr<pattern::op::Label>> src_layer_labels{xt};
+ src_layer = compute_rnn_args(src_layer_labels, m, true);
- std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{ht_1};
- src_iter = compute_rnn_args(src_iter_labels, m);
- }
- else
- {
- // dont fuse, if the PM didn't discover all the cells belonging to RNN layer.
- // we dont want to throw an assertion, if pattern matcher cannot discover all
- // nodes belonging to RNN, instead we will return and can compute LSTM cell wise
- return false;
- }
+ std::vector<std::shared_ptr<pattern::op::Label>> src_iter_labels{ht_1};
+ src_iter = compute_rnn_args(src_iter_labels, m);
+ }
+ else
+ {
+ // don't fuse if the PM didn't discover all the cells belonging to the RNN layer.
+ // we don't want to throw an assertion if the pattern matcher cannot discover all
+ // the nodes belonging to the RNN; instead we return and compute the LSTM cell-wise
+ return false;
+ }
- std::vector<std::shared_ptr<pattern::op::Label>> weights_layer_labels{weights_i2h};
- auto weights_layer = compute_rnn_args(weights_layer_labels, m);
- std::vector<std::shared_ptr<pattern::op::Label>> weights_iter_labels{weights_h2h};
- auto weights_iter = compute_rnn_args(weights_iter_labels, m);
-
- std::vector<std::shared_ptr<pattern::op::Label>> bias_layer_labels{bias_i2h};
- auto bias_layer = compute_rnn_args(bias_layer_labels, m);
- std::vector<std::shared_ptr<pattern::op::Label>> bias_iter_labels{bias_h2h};
- auto bias_iter = compute_rnn_args(bias_iter_labels, m);
-
- std::vector<std::shared_ptr<pattern::op::Label>> state_iter_labels{rpattern_ct_1};
- auto state_iter = compute_rnn_args(state_iter_labels, m);
-
- auto num_of_lstm_matched = m.get_number_of_recurrent_matches();
- size_t num_gates_in_lstm = 4;
- // TODO: assert for batch_size, sequence length and num_of_lstm's fused
- size_t batch_size = src_layer->get_shape()[0] / num_of_lstm_matched;
- size_t sequence_len = num_of_lstm_matched;
- size_t src_layer_feature_size = src_layer->get_shape()[1];
- size_t feature_size = ht_1_label[0]->get_shape()[1];
- // number of states for LSTM is 2
- size_t direction = 1;
- size_t num_fused_rnn_layers = 1;
+ std::vector<std::shared_ptr<pattern::op::Label>> params_labels{params_label};
+ auto params = compute_rnn_args(params_labels, m);
- NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
- NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
- NGRAPH_DEBUG << "weights_layer: " << join(weights_layer->get_shape());
- NGRAPH_DEBUG << "weights_iter: " << join(weights_iter->get_shape());
- NGRAPH_DEBUG << "bias_layer: " << join(bias_layer->get_shape());
- NGRAPH_DEBUG << "bias_iter: " << join(bias_iter->get_shape());
- NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
- NGRAPH_DEBUG << "batch_size: " << batch_size;
- NGRAPH_DEBUG << "feature_size: " << feature_size;
+ std::vector<std::shared_ptr<pattern::op::Label>> state_iter_labels{rpattern_ct_1};
+ auto state_iter = compute_rnn_args(state_iter_labels, m);
+
+ auto num_of_lstm_matched = m.get_number_of_recurrent_matches();
+ if (num_of_lstm_matched <= 1)
+ {
+ return false;
+ }
- NGRAPH_ASSERT(src_layer->get_arguments().size() == sequence_len ||
- std::dynamic_pointer_cast<op::Parameter>(src_layer))
- << "number of lstm inputs captured in the RNN fusion is not equal to "
- "src_sequence_length";
- NGRAPH_ASSERT(!std::dynamic_pointer_cast<op::Parameter>(src_layer) || sequence_len == 1)
- << "number of lstm inputs captured in the RNN fusion is not equal to "
- "src_sequence_length";
-
- auto src_layer_rank = src_layer->get_shape().size();
- auto src_iter_rank = src_iter->get_shape().size();
- auto weights_layer_rank = weights_layer->get_shape().size();
- auto weights_iter_rank = weights_iter->get_shape().size();
- auto bias_rank = bias_layer->get_shape().size();
- NGRAPH_ASSERT(src_layer_rank == 2 && src_iter_rank == 2 && weights_layer_rank == 2 &&
- weights_iter_rank == 2)
- << "Pattern matcher error src_layer, weights_layer, src_iter, weights_iter should "
- "have rank 2 for RNN op";
- NGRAPH_ASSERT(bias_rank == 1) << "Bias should have rank of 1 for Rnn op";
- NGRAPH_ASSERT(src_layer->get_element_type() == element::f32 &&
- src_iter->get_element_type() == element::f32)
- << "input tensor type and input recurrent state tensor type for RNN op should "
- "be float32";
+ size_t num_gates_in_lstm = 4;
+ size_t batch_size = src_layer->get_shape()[0] / num_of_lstm_matched;
+ size_t sequence_len = num_of_lstm_matched;
+ size_t src_layer_feature_size = src_layer->get_shape()[1];
+ size_t feature_size = ht_1_label[0]->get_shape()[1];
+ // number of states for LSTM is 2
+ size_t direction = 1;
+ size_t num_fused_rnn_layers = 1;
- auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
- src_iter,
- weights_layer,
- weights_iter,
- bias_layer,
- bias_iter,
- state_iter,
- num_of_lstm_matched,
- num_gates_in_lstm,
- sequence_len,
- src_layer_feature_size,
- feature_size,
- direction,
- num_fused_rnn_layers);
+ NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
+ NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
+ NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
+ NGRAPH_DEBUG << "batch_size: " << batch_size;
+ NGRAPH_DEBUG << "feature_size: " << feature_size;
- std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(num_of_lstm_matched,
- nullptr);
- auto rnn_ht_out = std::make_shared<op::GetOutputElement>(rnn, 0);
- auto rnn_ct_out = std::make_shared<op::GetOutputElement>(rnn, 1);
+ RETURN_IF_FALSE(src_layer->get_arguments().size() == sequence_len ||
+ std::dynamic_pointer_cast<op::Parameter>(src_layer),
+ "number of lstm inputs captured in the RNN fusion is not equal to "
+ "src_sequence_length");
+ RETURN_IF_FALSE(!std::dynamic_pointer_cast<op::Parameter>(src_layer) || sequence_len == 1,
+ "number of lstm inputs captured in the RNN fusion is not equal to "
+ "src_sequence_length");
+
+ auto src_layer_rank = src_layer->get_shape().size();
+ auto src_iter_rank = src_iter->get_shape().size();
+ RETURN_IF_FALSE(src_layer_rank == 2 && src_iter_rank == 2,
+ "Pattern matcher error src_layer, src_iter, have rank 2 for RNN op");
+ RETURN_IF_FALSE(src_layer->get_element_type() == element::f32 &&
+ src_iter->get_element_type() == element::f32,
+ "input tensor type and input recurrent state tensor type for RNN op should "
+ "be float32");
- //slice the rnn ht's
- size_t start_index = 0;
- size_t end_index = batch_size;
- // capture the slices in the reverse order, so it corrosponds to lstm_goes order captured by the Pattern matcher
- for (size_t i = 0; i < num_of_lstm_matched; i++)
- {
- ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
- rnn_ht_out, Coordinate{start_index, 0}, Coordinate{end_index, feature_size}));
- start_index += batch_size;
- end_index += batch_size;
- }
- std::reverse(ht_slice_per_timestep.begin(), ht_slice_per_timestep.end());
+ auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
+ src_iter,
+ params,
+ state_iter,
+ num_of_lstm_matched,
+ num_gates_in_lstm,
+ sequence_len,
+ src_layer_feature_size,
+ feature_size,
+ direction,
+ num_fused_rnn_layers);
- NGRAPH_DEBUG << "rnn_time_slice: " << ht_slice_per_timestep.size();
+ std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(num_of_lstm_matched, nullptr);
+ auto rnn_ht_out = std::make_shared<op::GetOutputElement>(rnn, 0);
+ auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 1);
+ auto layer_rnn_ct = std::make_shared<op::GetOutputElement>(rnn, 2);
+
+ //slice the rnn ht's
+ size_t start_index = 0;
+ size_t end_index = batch_size;
+ // capture the slices in reverse order, so they correspond to the lstm_goes order captured by the pattern matcher
+ for (size_t i = 0; i < num_of_lstm_matched; i++)
+ {
+ ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
+ rnn_ht_out, Coordinate{start_index, 0}, Coordinate{end_index, feature_size}));
+ start_index += batch_size;
+ end_index += batch_size;
+ }
+ std::reverse(ht_slice_per_timestep.begin(), ht_slice_per_timestep.end());
- // find the lstm's nodes captured in PM
- auto lstm_goes = m.get_bound_nodes_for_pattern(lstm_node_label);
- std::vector<std::shared_ptr<ngraph::Node>> lstm_nodes;
+ NGRAPH_DEBUG << "rnn_time_slice: " << ht_slice_per_timestep.size();
- // we need to collect LSTM from GOE's, in order to determine
- // the individaual time slice output ht. lstm_goes will hold the GOE in the decreasing
- // order of the time slices
- for (size_t i = 0; i < lstm_goes.size(); i++)
- {
- // lstm's will be the input to GOE's
- lstm_nodes.push_back(lstm_goes[i]->get_arguments()[0]);
- }
+ // find the lstm's nodes captured in PM
+ auto lstm_goes = m.get_bound_nodes_for_pattern(lstm_node_label);
+ std::vector<std::shared_ptr<ngraph::Node>> lstm_nodes;
- NGRAPH_ASSERT(sequence_len == lstm_nodes.size())
- << " Number of lstm nodes in RNN layer is not equal to time slices";
- NGRAPH_ASSERT(lstm_nodes.size() == lstm_goes.size() ||
- lstm_goes.size() == ht_slice_per_timestep.size())
- << "Number of slices of rnn output ht is not equal to the time slices in RNN layer";
-
- // collect all the consumers of LSTM goe's (ht)
- std::set<std::shared_ptr<ngraph::Node>> lstm_goe0_user;
- std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> map_goe_to_lstm_slices;
- std::shared_ptr<Node> goe_0;
- for (size_t index = 0; index < lstm_nodes.size(); index++)
+ // we need to collect LSTM from GOE's, in order to determine
+ // the individual time slice output ht. lstm_goes will hold the GOE in the decreasing
+ // order of the time slices
+ for (size_t i = 0; i < lstm_goes.size(); i++)
+ {
+ // lstm's will be the input to GOE's
+ lstm_nodes.push_back(lstm_goes[i]->get_arguments()[0]);
+ }
+
+ RETURN_IF_FALSE(sequence_len == lstm_nodes.size(),
+ " Number of lstm nodes in RNN layer is not equal to time slices");
+ RETURN_IF_FALSE(
+ lstm_nodes.size() == lstm_goes.size() ||
+ lstm_goes.size() == ht_slice_per_timestep.size(),
+ "Number of slices of rnn output ht is not equal to the time slices in RNN layer");
+
+ // collect all the consumers of LSTM goe's (ht)
+ std::set<std::shared_ptr<ngraph::Node>> lstm_goe0_user;
+ std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> map_goe_to_lstm_slices;
+ std::shared_ptr<Node> goe_0;
+ for (size_t index = 0; index < lstm_nodes.size(); index++)
+ {
+ // now get the GOE0 which is the first output of lstm (ht)
+ for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs())
{
- // now get the GOE0 which is the first output of lstm (ht)
- for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs())
+ auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
+ // first output node of lstm
+ if (goe_node->get_n() == 0)
{
- auto goe_node =
- std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
- // first output node of lstm
- if (goe_node->get_n() == 0)
+ goe_0 = goes->get_node();
+ for (auto goe0_user : goe_0->get_users())
{
- goe_0 = goes->get_node();
- for (auto goe0_user : goe_0->get_users())
+ if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
+ lstm_nodes.end() &&
+ ngraph::is_used(goe0_user.get()))
{
- if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
- lstm_nodes.end() &&
- ngraph::is_used(goe0_user.get()))
- {
- lstm_goe0_user.insert(goe0_user);
- map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
- NGRAPH_DEBUG
- << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
- << " goe0_user " << goe0_user->get_name() << " ";
- }
+ lstm_goe0_user.insert(goe0_user);
+ map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
+ NGRAPH_DEBUG << "ht_slice: " << ht_slice_per_timestep[index]->get_name()
+ << " goe0_user " << goe0_user->get_name() << " ";
}
}
- // we need to only check the last LSTM cell Ct user and replace if needed.
- if ((index == 0) && (goe_node->get_n() == 1))
- {
- // check if the last LSTM cell has any consumers
- auto n_time_step_lstm_ct_goe = goes->get_node();
- ngraph::replace_node(n_time_step_lstm_ct_goe, rnn_ct_out);
- }
+ }
+ // we need to only check the last LSTM cell Ct user and replace if needed.
+ if ((index == 0) && (goe_node->get_n() == 1))
+ {
+ // check if the last LSTM cell has any consumers
+ auto n_time_step_lstm_ct_goe = goes->get_node();
+ ngraph::replace_node(n_time_step_lstm_ct_goe, layer_rnn_ct);
}
}
+ }
- //now go through the lstm goe_0 consumers and replace them with the slice
- for (auto& node : lstm_goe0_user)
+ //now go through the lstm goe_0 consumers and replace them with the slice
+ for (auto& node : lstm_goe0_user)
+ {
+ for (size_t i = 0; i < node->get_input_size(); i++)
{
- for (size_t i = 0; i < node->get_input_size(); i++)
+ if (map_goe_to_lstm_slices.find(node->get_argument(i)) !=
+ map_goe_to_lstm_slices.end())
{
- if (map_goe_to_lstm_slices.find(node->get_argument(i)) !=
- map_goe_to_lstm_slices.end())
- {
- node->get_inputs().at(i).replace_output(
- map_goe_to_lstm_slices[node->get_argument(i)]->get_outputs().at(0));
- }
+ node->get_inputs().at(i).replace_output(
+ map_goe_to_lstm_slices[node->get_argument(i)]->get_outputs().at(0));
}
}
+ }
- NGRAPH_DEBUG << "End of recurrent fusion call back "
- << "matched_node: " << m.get_match_root()->get_name();
- return true;
+ NGRAPH_DEBUG << "End of recurrent fusion call back "
+ << "matched_node: " << m.get_match_root()->get_name();
+ return true;
- };
+ };
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
@@ -558,6 +616,37 @@ static std::shared_ptr<Node>
return std::make_shared<op::Concat>(node_labels, 0);
}
+static std::shared_ptr<Node>
+ compute_multi_layer_rnn_params(const std::shared_ptr<pattern::op::Label>& param_label,
+ pattern::RecurrentMatcher& m)
+{
+ auto param_nodes = m.get_bound_nodes_for_pattern(param_label);
+ std::reverse(param_nodes.begin(), param_nodes.end());
+
+ // iterate over params for each layer in order [layer 0, ... layer n]
+ NodeVector biases;
+ NodeVector layer_params;
+ for (auto& param : param_nodes)
+ {
+ // split and group layer weights and layer biases
+ auto const& args = param->get_arguments();
+ for (size_t i = 0; i < args.size(); i++)
+ {
+ // first half set of params are weights, second half are biases
+ if (i < (args.size() / 2))
+ {
+ layer_params.push_back(args[i]);
+ }
+ else
+ {
+ biases.push_back(args[i]);
+ }
+ }
+ }
+ layer_params.insert(layer_params.end(), biases.begin(), biases.end());
+ return std::make_shared<op::Concat>(layer_params, 0);
+}
+
void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
{
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
@@ -566,10 +655,8 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
std::make_shared<pattern::op::Skip>(src_layer_label, pattern::has_class<op::Slice>());
auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
- auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
- auto weights_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
- auto bias_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
- auto bias_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
+ auto params_label = std::make_shared<pattern::op::Label>(
+ element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
auto state_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
size_t ref_number_of_timesteps = 3;
@@ -582,10 +669,7 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto ref_rnn_node = std::make_shared<op::gpu::Rnn>(src_slice,
src_iter_label,
- weights_layer_label,
- weights_iter_label,
- bias_layer_label,
- bias_iter_label,
+ params_label,
state_iter_label,
ref_number_of_timesteps,
ref_number_of_gates_per_cell,
@@ -599,188 +683,172 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto rnn_ht_label =
std::make_shared<pattern::op::Label>(rnn_ht_out, nullptr, NodeVector{rnn_ht_out});
- pattern::recurrent_graph_rewrite_callback callback = [src_layer_label,
- src_iter_label,
- weights_layer_label,
- weights_iter_label,
- bias_layer_label,
- bias_iter_label,
- state_iter_label,
- rnn_ht_label](
- pattern::RecurrentMatcher& m) {
-
- if (m.get_number_of_recurrent_matches() <= 1)
- {
- return false;
- }
-
- auto src_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
- auto rnn_ht_out_nodes = m.get_bound_nodes_for_pattern(rnn_ht_label);
- auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
- NGRAPH_DEBUG << "In Recurrent multi layer RNN fusion callback ";
- NGRAPH_DEBUG << "Number of RNN's Matched: " << number_of_rnn_cell_matched;
- NGRAPH_DEBUG << "matched_root: " << m.get_match_root()->get_name();
- NGRAPH_DEBUG << "src_layer_node: " << src_nodes[0]->get_name();
+ pattern::recurrent_graph_rewrite_callback callback =
+ [src_layer_label, src_iter_label, params_label, state_iter_label, rnn_ht_label](
+ pattern::RecurrentMatcher& m) {
- // we can fuse across different RNN layers only if SLC == DLC
- for (size_t i = 0; i < number_of_rnn_cell_matched; i++)
- {
- if (src_nodes[i]->get_shape()[1] != rnn_ht_out_nodes[i]->get_shape()[1])
+ if (m.get_number_of_recurrent_matches() <= 1)
{
- NGRAPH_DEBUG << "Not fusing since the feature sizes for xt and ht_1 dont match";
return false;
}
- }
- // we just need to capture the input symbols {x0 | x1.....| xt} of the first lstm layer
- // the intermediate inputs for the next layer will be computed by the kernel
- auto src_layer_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
- auto src_layer = src_layer_nodes[src_layer_nodes.size() - 1];
-
- auto src_iter = compute_multi_layer_rnn_inputs(src_iter_label, m);
- auto state_iter = compute_multi_layer_rnn_inputs(state_iter_label, m);
- auto weights_layer = compute_multi_layer_rnn_inputs(weights_layer_label, m);
- auto weights_iter = compute_multi_layer_rnn_inputs(weights_iter_label, m);
- auto bias_layer = compute_multi_layer_rnn_inputs(bias_layer_label, m);
- auto bias_iter = compute_multi_layer_rnn_inputs(bias_iter_label, m);
-
- // collect list of rnn ops (layers)
- std::vector<std::shared_ptr<op::gpu::Rnn>> rnn_nodes;
- for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label))
- {
- auto rnn_op =
- std::dynamic_pointer_cast<op::gpu::Rnn>(rnn_goe_input->get_arguments()[0]);
- if (rnn_op)
+ auto src_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
+ auto rnn_ht_out_nodes = m.get_bound_nodes_for_pattern(rnn_ht_label);
+ auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
+ NGRAPH_DEBUG << "In Recurrent multi layer RNN fusion callback ";
+ NGRAPH_DEBUG << "Number of RNN's Matched: " << number_of_rnn_cell_matched;
+ NGRAPH_DEBUG << "matched_root: " << m.get_match_root()->get_name();
+ NGRAPH_DEBUG << "src_layer_node: " << src_nodes[0]->get_name();
+
+ // we can fuse across different RNN layers only if SLC == DLC (i.e. the input and output feature sizes match)
+ for (size_t i = 0; i < number_of_rnn_cell_matched; i++)
{
- rnn_nodes.push_back(rnn_op);
+ if (src_nodes[i]->get_shape()[1] != rnn_ht_out_nodes[i]->get_shape()[1])
+ {
+ NGRAPH_DEBUG << "Not fusing since the feature sizes for xt and ht_1 dont match";
+ return false;
+ }
}
- else
+
+ // we just need to capture the input symbols {x0, x1, ..., xt} of the first LSTM layer;
+ // the intermediate inputs for the next layer will be computed by the kernel
+ auto src_layer_nodes = m.get_bound_nodes_for_pattern(src_layer_label);
+ auto src_layer = src_layer_nodes[src_layer_nodes.size() - 1];
+
+ auto src_iter = compute_multi_layer_rnn_inputs(src_iter_label, m);
+ auto state_iter = compute_multi_layer_rnn_inputs(state_iter_label, m);
+ auto params = compute_multi_layer_rnn_params(params_label, m);
+
+ // collect list of rnn ops (layers)
+ std::vector<std::shared_ptr<op::gpu::Rnn>> rnn_nodes;
+ for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label))
{
- throw ngraph_error("Input for RNN output GetOuputElement Op should be RNN");
+ auto rnn_op =
+ std::dynamic_pointer_cast<op::gpu::Rnn>(rnn_goe_input->get_arguments()[0]);
+ if (rnn_op)
+ {
+ rnn_nodes.push_back(rnn_op);
+ }
+ else
+ {
+ throw ngraph_error("Input for RNN output GetOuputElement Op should be RNN");
+ }
}
- }
- size_t num_time_steps = rnn_nodes[0]->get_num_timesteps();
- size_t num_gates_in_lstm = rnn_nodes[0]->get_gates_per_cell();
- size_t batch_size = rnn_nodes[0]->get_batch_size();
- size_t sequence_len = rnn_nodes[0]->get_src_sequence_length();
- size_t src_layer_feature_size = rnn_nodes[0]->get_src_layer_feature_size();
- size_t feature_size = rnn_nodes[0]->get_src_iter_feature_size();
- size_t rnn_direction = rnn_nodes[0]->get_direction();
- size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
+ size_t num_time_steps = rnn_nodes[0]->get_num_timesteps();
+ size_t num_gates_in_lstm = rnn_nodes[0]->get_gates_per_cell();
+ size_t batch_size = rnn_nodes[0]->get_batch_size();
+ size_t sequence_len = rnn_nodes[0]->get_src_sequence_length();
+ size_t src_layer_feature_size = rnn_nodes[0]->get_src_layer_feature_size();
+ size_t feature_size = rnn_nodes[0]->get_src_iter_feature_size();
+ size_t rnn_direction = rnn_nodes[0]->get_direction();
+ size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
- NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
- NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
- NGRAPH_DEBUG << "state_iter: " << join(state_iter->get_shape());
- NGRAPH_DEBUG << "weights_layer: " << join(weights_layer->get_shape());
- NGRAPH_DEBUG << "weights_iter: " << join(weights_iter->get_shape());
- NGRAPH_DEBUG << "bias_layer: " << join(bias_layer->get_shape());
- NGRAPH_DEBUG << "bias_iter: " << join(bias_iter->get_shape());
- NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
- NGRAPH_DEBUG << "batch_size: " << batch_size;
- NGRAPH_DEBUG << "feature_size: " << feature_size;
+ NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
+ NGRAPH_DEBUG << "src_iter: " << join(src_iter->get_shape());
+ NGRAPH_DEBUG << "state_iter: " << join(state_iter->get_shape());
+ NGRAPH_DEBUG << "params size {wx|wh|bx|bh}: " << shape_size(params->get_shape());
+ NGRAPH_DEBUG << "src_seq_len: " << sequence_len;
+ NGRAPH_DEBUG << "batch_size: " << batch_size;
+ NGRAPH_DEBUG << "feature_size: " << feature_size;
- if (auto src_rnn = std::dynamic_pointer_cast<op::gpu::Rnn>(src_layer))
- {
- NGRAPH_ASSERT(src_rnn->get_num_timesteps() == num_time_steps)
- << "input symbols for the layer fused RNN op, should be captured only for the first layer";
- }
+ if (auto src_rnn = std::dynamic_pointer_cast<op::gpu::Rnn>(src_layer))
+ {
+ RETURN_IF_FALSE(
+ src_rnn->get_num_timesteps() == num_time_steps,
+ "input symbols for the layer fused RNN op, should be captured only for the "
+ "first layer");
+ }
- NGRAPH_ASSERT(!std::dynamic_pointer_cast<op::Parameter>(src_layer) ||
- rnn_nodes[0]->get_num_timesteps() == 1)
- << "input symbols for the layer fused RNN op, should be captured only for the first "
- "layer";
- NGRAPH_ASSERT((src_iter->get_arguments().size()) == num_fused_rnn_layers)
- << "number of hidden states for RNN op in the layer fusion is not equal to num of "
- "fused_rnn_layers";
- NGRAPH_ASSERT((state_iter->get_arguments().size()) == num_fused_rnn_layers)
- << "number of cell states for RNN op in the layer fusion is not equal to num of "
- "fused_rnn_layers";
- NGRAPH_ASSERT((weights_layer->get_arguments().size()) == num_fused_rnn_layers)
- << "weights w.r.to input symbols of RNN op in the layer fusion is not equal to num of "
- "fused_rnn_layers";
- NGRAPH_ASSERT((weights_iter->get_arguments().size()) == num_fused_rnn_layers)
- << "weights w.r.to cell states of RNN op in the layer fusion is not equal to num of "
- "fused_rnn_layers";
- NGRAPH_ASSERT((bias_layer->get_arguments().size()) == num_fused_rnn_layers)
- << "input bias of RNN op in the layer fusion is not equal to num of fused_rnn_layers";
- NGRAPH_ASSERT((bias_iter->get_arguments().size()) == num_fused_rnn_layers)
- << "recurrent bias of RNN op in the layer fusion is not equal to num of "
- "fused_rnn_layers";
+ RETURN_IF_FALSE(
+ !std::dynamic_pointer_cast<op::Parameter>(src_layer) ||
+ rnn_nodes[0]->get_num_timesteps() == 1,
+ "input symbols for the layer fused RNN op, should be captured only for the first "
+ "layer");
+ RETURN_IF_FALSE(
+ (src_iter->get_arguments().size()) == num_fused_rnn_layers,
+ "number of hidden states for RNN op in the layer fusion is not equal to num of "
+ "fused_rnn_layers");
+ RETURN_IF_FALSE(
+ (state_iter->get_arguments().size()) == num_fused_rnn_layers,
+ "number of cell states for RNN op in the layer fusion is not equal to num of "
+ "fused_rnn_layers");
+ RETURN_IF_FALSE(
+ (params->get_arguments().size()) == num_fused_rnn_layers * 4,
+ "RNN param tensor does not consist of normal and recurrent weight and bias tensor "
+ "for each layer");
- auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
- src_iter,
- weights_layer,
- weights_iter,
- bias_layer,
- bias_iter,
- state_iter,
- num_time_steps,
- num_gates_in_lstm,
- sequence_len,
- src_layer_feature_size,
- feature_size,
- rnn_direction,
- num_fused_rnn_layers);
+ auto rnn = std::make_shared<op::gpu::Rnn>(src_layer,
+ src_iter,
+ params,
+ state_iter,
+ num_time_steps,
+ num_gates_in_lstm,
+ sequence_len,
+ src_layer_feature_size,
+ feature_size,
+ rnn_direction,
+ num_fused_rnn_layers);
- auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
- // auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 1);
- auto layer_rnn_ct = std::make_shared<op::GetOutputElement>(rnn, 2);
+ auto output_layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
+ auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 1);
+ auto layer_rnn_ct = std::make_shared<op::GetOutputElement>(rnn, 2);
- // Replace all the users of RNN cell state {ct} across different user.
- auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node>& rnn_ct, size_t layer) {
- std::shared_ptr<Node> node_to_replace = rnn_ct;
- auto ct_slice = std::make_shared<op::Slice>(
- layer_rnn_ct,
- Coordinate{static_cast<unsigned long>(batch_size * (layer - 1)), 0},
- Coordinate{static_cast<unsigned long>(batch_size * rnn_direction * layer),
- static_cast<unsigned long>(feature_size)});
+ // Replace all the users of the RNN cell state {ct} for each fused layer.
+ auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node>& rnn_ct, size_t layer) {
+ std::shared_ptr<Node> node_to_replace = rnn_ct;
+ auto ct_slice = std::make_shared<op::Slice>(
+ layer_rnn_ct,
+ Coordinate{static_cast<unsigned long>(batch_size * (layer - 1)), 0},
+ Coordinate{static_cast<unsigned long>(batch_size * rnn_direction * layer),
+ static_cast<unsigned long>(feature_size)});
- if (rnn_ct->get_users().size() == 1)
- {
- if (std::dynamic_pointer_cast<op::Slice>(rnn_ct->get_users()[0]))
+ if (rnn_ct->get_users().size() == 1)
{
- node_to_replace = rnn_ct->get_users()[0];
+ if (std::dynamic_pointer_cast<op::Slice>(rnn_ct->get_users()[0]))
+ {
+ node_to_replace = rnn_ct->get_users()[0];
+ }
}
- }
- if (ngraph::is_used(node_to_replace.get()))
- {
- ngraph::replace_node(node_to_replace, ct_slice);
- }
- };
-
- for (size_t index = 0; index < rnn_nodes.size(); index++)
- {
- for (auto& rnn_goes : rnn_nodes[index]->get_users())
- {
- NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
- if (rnn_goes->get_users().empty())
+ if (ngraph::is_used(node_to_replace.get()))
{
- continue;
+ ngraph::replace_node(node_to_replace, ct_slice);
}
+ };
- if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
+ for (size_t index = 0; index < rnn_nodes.size(); index++)
+ {
+ for (auto& rnn_goes : rnn_nodes[index]->get_users())
{
- // we need to only replace the {ht} consumers of the last RNN layer,
- // since for other layers the intermediate outputs {ht} will be computed
- // within the kernel
- if (index == 0)
+ NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
+ if (rnn_goes->get_users().empty())
{
- if (rnn_goe_node->get_n() == 0)
- {
- ngraph::replace_node(rnn_goes, layer_rnn_ht);
- }
+ continue;
}
- if (rnn_goe_node->get_n() == 1)
+
+ if (auto rnn_goe_node =
+ std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
- replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index);
+ // we only need to replace the {ht} consumers of the last RNN layer,
+ // since for the other layers the intermediate outputs {ht} will be
+ // computed within the kernel
+ if (index == 0)
+ {
+ if (rnn_goe_node->get_n() == 0)
+ {
+ ngraph::replace_node(rnn_goes, output_layer_rnn_ht);
+ }
+ }
+ if (rnn_goe_node->get_n() == 2)
+ {
+ replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index);
+ }
}
}
}
- }
- return true;
- };
+ return true;
+ };
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
diff --git a/test/gpu_fusion.cpp b/test/gpu_fusion.cpp
index c0620e64..dcaa0a5b 100644
--- a/test/gpu_fusion.cpp
+++ b/test/gpu_fusion.cpp
@@ -13,9 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
-
#include <algorithm>
#include <cstdio>
+#include <cudnn.h>
#include <iostream>
#include <list>
#include <memory>
@@ -45,7 +45,6 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
-#include "ngraph/runtime/gpu/op/lstm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/serializer.hpp"
@@ -62,15 +61,14 @@
using namespace ngraph;
using namespace std;
+#if CUDNN_VERSION >= 7200
TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
{
auto src_layer = make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto src_iter = make_shared<op::Parameter>(element::f32, Shape{10, 100});
+ auto params =
+ make_shared<op::Parameter>(element::f32, Shape{400 * 100 + 400 * 100 + 400 + 400});
auto state_iter = make_shared<op::Parameter>(element::f32, Shape{10, 100});
- auto weights_layer = make_shared<op::Parameter>(element::f32, Shape{400, 100});
- auto weights_iter = make_shared<op::Parameter>(element::f32, Shape{400, 100});
- auto bias_layer = make_shared<op::Parameter>(element::f32, Shape{400});
- auto bias_iter = make_shared<op::Parameter>(element::f32, Shape{400});
const int number_of_timesteps = 1;
const int number_of_gates_per_cell = 4;
@@ -81,10 +79,7 @@ TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
const int num_of_rnn_fused_layer = 1;
auto rnn_node = make_shared<op::gpu::Rnn>(src_layer,
src_iter,
- weights_layer,
- weights_iter,
- bias_layer,
- bias_iter,
+ params,
state_iter,
number_of_timesteps,
number_of_gates_per_cell,
@@ -96,10 +91,8 @@ TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);
- auto func = make_shared<Function>(
- NodeVector{rnn_ht_output, rnn_ct_output},
- op::ParameterVector{
- src_layer, src_iter, weights_layer, weights_iter, bias_layer, bias_iter, state_iter});
+ auto func = make_shared<Function>(NodeVector{rnn_ht_output, rnn_ct_output},
+ op::ParameterVector{src_layer, src_iter, params, state_iter});
auto backend = runtime::Backend::create("GPU");
shared_ptr<runtime::TensorView> src_layer_t =
@@ -108,14 +101,9 @@ TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
backend->create_tensor(element::f32, src_iter->get_shape());
shared_ptr<runtime::TensorView> state_iter_t =
backend->create_tensor(element::f32, state_iter->get_shape());
- shared_ptr<runtime::TensorView> weights_layer_t =
- backend->create_tensor(element::f32, weights_layer->get_shape());
- shared_ptr<runtime::TensorView> weights_iter_t =
- backend->create_tensor(element::f32, weights_iter->get_shape());
- shared_ptr<runtime::TensorView> bias_layer_t =
- backend->create_tensor(element::f32, bias_layer->get_shape());
- shared_ptr<runtime::TensorView> bias_iter_t =
- backend->create_tensor(element::f32, bias_iter->get_shape());
+ shared_ptr<runtime::TensorView> params_t =
+ backend->create_tensor(element::f32, params->get_shape());
+
shared_ptr<runtime::TensorView> result_ht = backend->create_tensor(element::f32, {10, 100});
shared_ptr<runtime::TensorView> result_ct =
backend->create_tensor(element::f32, Shape{10, 100});
@@ -123,37 +111,21 @@ TEST(gpu_fusion, rnn_fprop_1_lstm_cell)
copy_data(src_layer_t, vector<float>(1000, 1));
copy_data(src_iter_t, vector<float>(1000, 1));
copy_data(state_iter_t, vector<float>(1000, 1));
- copy_data(weights_layer_t, vector<float>(400 * 100, 1));
- copy_data(weights_iter_t, vector<float>(400 * 100, 1));
- copy_data(bias_layer_t, vector<float>(400, 1));
- copy_data(bias_iter_t, vector<float>(400, 1));
+ copy_data(params_t, vector<float>(shape_size(params->get_shape()), 1));
- backend->call_with_validate(func,
- {result_ht, result_ct},
- {src_layer_t,
- src_iter_t,
- weights_layer_t,
- weights_iter_t,
- bias_layer_t,
- bias_iter_t,
- state_iter_t});
+ backend->call_with_validate(
+ func, {result_ht, result_ct}, {src_layer_t, src_iter_t, params_t, state_iter_t});
vector<float> expected_ht(10 * 100, 0.964028f);
vector<float> expected_ct;
for (size_t i = 0; i < 10 * 100; i++)
{
- if (i < 1000)
- {
- expected_ct.push_back(0.964028f);
- }
- else
- {
- expected_ct.push_back(2.0f);
- }
+ expected_ct.push_back(0.964028f);
}
EXPECT_TRUE(test::all_close(expected_ht, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(expected_ct, read_vector<float>(result_ct)));
}
+#endif
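A note on where the 0.964028 constant comes from, assuming the standard LSTM cell equations: with every input, weight, and bias set to 1 and a feature size of 100, each gate pre-activation is 100 + 100 + 1 + 1 = 202, so the gates saturate (sigmoid and tanh of 202 are effectively 1). That gives ct = f*c_prev + i*g ~ 1*1 + 1*1 = 2 and ht = o*tanh(ct) ~ tanh(2) ~ 0.964028. Both expected vectors now hold this value, which is consistent with output index 1 carrying the final hidden state after the output renumbering in the fusion pass. A standalone sanity check, not part of the test suite:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        double pre = 100.0 + 100.0 + 1.0 + 1.0;         // all-ones inputs, feature size 100
        double gate = 1.0 / (1.0 + std::exp(-pre));     // saturated sigmoid, ~1.0
        double ct = gate * 1.0 + gate * std::tanh(pre); // f*c_prev + i*g, ~2.0
        double ht = gate * std::tanh(ct);               // ~tanh(2) = 0.964028
        std::printf("ht = %.6f\n", ht);
        return 0;
    }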
TEST(gpu_fusion, fuse_lstm_cells)
{
@@ -165,7 +137,7 @@ TEST(gpu_fusion, fuse_lstm_cells)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
- auto lstm_ops = get_ops_of_type<op::gpu::Lstm>(func);
+ auto lstm_ops = get_ops_of_type<op::gpu::Rnn>(func);
EXPECT_EQ(lstm_ops.size(), 6);
}
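With op/lstm.cpp dropped from the build, a single fused LSTM cell is now represented as an op::gpu::Rnn with one timestep, which is why this test counts Rnn ops where it previously counted Lstm ops.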
@@ -219,7 +191,7 @@ static std::shared_ptr<Function> make_function(const std::string& file_name)
return func;
}
-TEST(gpu_fusion, lstm_result)
+TEST(gpu_fusion, lstm_analytic)
{
auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
auto weights_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
@@ -248,7 +220,7 @@ TEST(gpu_fusion, lstm_result)
auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);
//ct-1 -> cell state
- auto c_const = op::Constant::create(element::f32, Shape{}, {1.0});
+ auto c_const = op::Constant::create(element::f32, Shape{}, {-1.0});
auto ct_1 = std::make_shared<op::Broadcast>(c_const, Shape{1, 1}, AxisSet{0, 1});
//auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);
@@ -280,19 +252,19 @@ TEST(gpu_fusion, lstm_result)
std::shared_ptr<runtime::TensorView> weights_i2h_t =
backend->create_tensor(element::f32, weights_i2h->get_shape());
- copy_data(weights_i2h_t, std::vector<float>{1.0, 1.0, 1.0, 1.0});
+ copy_data(weights_i2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::TensorView> weights_h2h_t =
backend->create_tensor(element::f32, weights_h2h->get_shape());
- copy_data(weights_h2h_t, std::vector<float>{1.0, 1.0, 1.0, 1.0});
+ copy_data(weights_h2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::TensorView> bias_i2h_t =
backend->create_tensor(element::f32, bias_i2h->get_shape());
- copy_data(bias_i2h_t, std::vector<float>{1.0, 1.0, 1.0, 1.0});
+ copy_data(bias_i2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::TensorView> bias_h2h_t =
backend->create_tensor(element::f32, bias_h2h->get_shape());
- copy_data(bias_h2h_t, std::vector<float>{1.0, 1.0, 1.0, 1.0});
+ copy_data(bias_h2h_t, std::vector<float>{-1.0, -1.0, -1.0, -1.0});
std::shared_ptr<runtime::TensorView> result_ht =
backend->create_tensor(element::f32, ht->get_shape());
@@ -304,19 +276,162 @@ TEST(gpu_fusion, lstm_result)
{input_xt_t, weights_i2h_t, weights_h2h_t, bias_i2h_t, bias_h2h_t});
auto sig = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
- float ct_val = sig(4) + sig(4) * std::tanh(4);
- float ht_val = sig(4) * std::tanh(ct_val);
+ float ct_val = -sig(-4.0f) + sig(-4.0f) * std::tanh(-4.0f);
+ float ht_val = sig(-4.0f) * std::tanh(ct_val);
EXPECT_TRUE(test::all_close(std::vector<float>{ht_val}, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(std::vector<float>{ct_val}, read_vector<float>(result_ct)));
}
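The analytic values follow from the sign flips above, assuming the input tensor is filled with ones like the other operands: each gate pre-activation is (-1)(1) + (-1)(1) + (-1) + (-1) = -4, and with c_{t-1} = -1 the cell state is ct = sig(-4)*(-1) + sig(-4)*tanh(-4), so ht = sig(-4)*tanh(ct), exactly the expressions asserted here.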
+TEST(gpu_fusion, fuse_2_layer_rnn_1lstm_analytic)
+{
+ auto input_xt = std::make_shared<op::Parameter>(element::f32, Shape{1, 1});
+ auto weights_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
+ auto weights_i2h_reshape =
+ std::make_shared<op::Reshape>(weights_i2h, AxisVector{1, 0}, Shape{1, 4});
+ auto dot_1 = std::make_shared<op::Dot>(input_xt, weights_i2h_reshape);
+
+ auto bias_i2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
+ auto broadcast_bias_i2h = std::make_shared<op::Broadcast>(bias_i2h, Shape{1, 4}, AxisSet{0});
+ auto add_1 = std::make_shared<op::Add>(dot_1, broadcast_bias_i2h);
+
+ auto h_const = op::Constant::create(element::f32, Shape{}, {1.0});
+ auto hidden_ht = std::make_shared<op::Broadcast>(h_const, Shape{1, 1}, AxisSet{0, 1});
+ auto weights_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
+ auto param2_2_reshape =
+ std::make_shared<op::Reshape>(weights_h2h, AxisVector{1, 0}, Shape{1, 4});
+ auto dot_2 = std::make_shared<op::Dot>(hidden_ht, param2_2_reshape);
+
+ auto bias_h2h = std::make_shared<op::Parameter>(element::f32, Shape{4});
+ auto broadcast_bias_h2h = std::make_shared<op::Broadcast>(bias_h2h, Shape{1, 4}, AxisSet{0});
+ auto add_2 = std::make_shared<op::Add>(dot_2, broadcast_bias_h2h);
+
+ auto X = std::make_shared<op::Add>(add_2, add_1);
+ // construct forget gate
+ auto input_slice_0 = std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{1, 1});
+ auto forget_gate = std::make_shared<op::Sigmoid>(input_slice_0);
+
+ //ct-1 -> cell state
+ auto c_const = op::Constant::create(element::f32, Shape{}, {1.0});
+ auto ct_1 = std::make_shared<op::Broadcast>(c_const, Shape{1, 1}, AxisSet{0, 1});
+ //auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
+ auto multiply_forget_gate_ct_1 = std::make_shared<op::Multiply>(forget_gate, ct_1);
+
+ // construct input gate
+ auto input_slice_1 = std::make_shared<op::Slice>(X, Coordinate{0, 1}, Coordinate{1, 2});
+ auto input_gate = std::make_shared<op::Sigmoid>(input_slice_1);
+ auto input_slice_2 = std::make_shared<op::Slice>(X, Coordinate{0, 2}, Coordinate{1, 3});
+ auto tanh_1 = std::make_shared<op::Tanh>(input_slice_2);
+ auto multiply_input_gate_tanh_1 = std::make_shared<op::Multiply>(input_gate, tanh_1);
+
+ auto ct = std::make_shared<op::Add>(multiply_forget_gate_ct_1, multiply_input_gate_tanh_1);
+
+ // construct output gate
+ auto input_slice_3 = std::make_shared<op::Slice>(X, Coordinate{0, 3}, Coordinate{1, 4});
+ auto output_gate = std::make_shared<op::Sigmoid>(input_slice_3);
+ auto tanh_2 = std::make_shared<op::Tanh>(ct);
+ auto ht = std::make_shared<op::Multiply>(output_gate, tanh_2);
+
+ // next lstm layer
+ auto weights_i2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
+ auto weights_i2h_0_reshape_0 =
+ std::make_shared<op::Reshape>(weights_i2h_0, AxisVector{1, 0}, Shape{1, 4});
+ auto dot_1_0 = std::make_shared<op::Dot>(ht, weights_i2h_0_reshape_0);
+
+ auto bias_i2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4});
+ auto broadcast_bias_i2h_0_0 =
+ std::make_shared<op::Broadcast>(bias_i2h_0, Shape{1, 4}, AxisSet{0});
+ auto add_1_0 = std::make_shared<op::Add>(dot_1_0, broadcast_bias_i2h_0_0);
+
+ auto h_const_0 = op::Constant::create(element::f32, Shape{}, {1.0});
+ auto hidden_ht_0 = std::make_shared<op::Broadcast>(h_const_0, Shape{1, 1}, AxisSet{0, 1});
+ auto weights_h2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
+ auto param2_2_reshape_0 =
+ std::make_shared<op::Reshape>(weights_h2h_0, AxisVector{1, 0}, Shape{1, 4});
+ auto dot_2_0 = std::make_shared<op::Dot>(hidden_ht_0, param2_2_reshape_0);
+
+ auto bias_h2h_0 = std::make_shared<op::Parameter>(element::f32, Shape{4});
+ auto broadcast_bias_h2h_0_0 =
+ std::make_shared<op::Broadcast>(bias_h2h_0, Shape{1, 4}, AxisSet{0});
+ auto add_2_0 = std::make_shared<op::Add>(dot_2_0, broadcast_bias_h2h_0_0);
+
+ auto X_0 = std::make_shared<op::Add>(add_2_0, add_1_0);
+ // construct forget gate
+ auto input_slice_0_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 0}, Coordinate{1, 1});
+ auto forget_gate_0 = std::make_shared<op::Sigmoid>(input_slice_0_0);
+
+ //ct-1 -> cell state
+ auto c_const_0 = op::Constant::create(element::f32, Shape{}, {1.0});
+ auto ct_1_0 = std::make_shared<op::Broadcast>(c_const_0, Shape{1, 1}, AxisSet{0, 1});
+ //auto ct_1 = std::make_shared<op::>(element::f32, Shape{10, 100});
+ auto multiply_forget_gate_0_ct_1_0 = std::make_shared<op::Multiply>(forget_gate_0, ct_1_0);
+
+ // construct input gate
+ auto input_slice_1_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 1}, Coordinate{1, 2});
+ auto input_gate_0 = std::make_shared<op::Sigmoid>(input_slice_1_0);
+ auto input_slice_2_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 2}, Coordinate{1, 3});
+ auto tanh_1_0 = std::make_shared<op::Tanh>(input_slice_2_0);
+ auto multiply_input_gate_0_tanh_1_0 = std::make_shared<op::Multiply>(input_gate_0, tanh_1_0);
+
+ auto ct_0 =
+ std::make_shared<op::Add>(multiply_forget_gate_0_ct_1_0, multiply_input_gate_0_tanh_1_0);
+
+ // construct output gate
+ auto input_slice_3_0 = std::make_shared<op::Slice>(X_0, Coordinate{0, 3}, Coordinate{1, 4});
+ auto output_gate_0 = std::make_shared<op::Sigmoid>(input_slice_3_0);
+ auto tanh_2_0 = std::make_shared<op::Tanh>(ct_0);
+ auto ht_0 = std::make_shared<op::Multiply>(output_gate_0, tanh_2_0);
+
+ auto f = make_shared<Function>(NodeVector{ht_0, ct_0},
+ op::ParameterVector{input_xt,
+ weights_i2h,
+ weights_h2h,
+ bias_i2h,
+ bias_h2h,
+ weights_i2h_0,
+ weights_h2h_0,
+ bias_i2h_0,
+ bias_h2h_0});
+
+ auto backend = runtime::Backend::create("GPU");
+
+ auto params = f->get_parameters();
+ std::vector<std::shared_ptr<ngraph::runtime::TensorView>> arg_tensors;
+ for (shared_ptr<op::Parameter> param : params)
+ {
+ vector<float> tensor_vals(shape_size(param->get_shape()), 1.0f);
+ auto tensor = backend->create_tensor(element::f32, param->get_shape());
+ copy_data(tensor, tensor_vals);
+ arg_tensors.push_back(tensor);
+ }
+
+ std::shared_ptr<runtime::TensorView> result_ht =
+ backend->create_tensor(element::f32, ht->get_shape());
+ std::shared_ptr<runtime::TensorView> result_ct =
+ backend->create_tensor(element::f32, ct->get_shape());
+
+ backend->call_with_validate(f, {result_ht, result_ct}, arg_tensors);
+ //EXPECT_EQ(1, count_ops_of_type<op::gpu::Rnn>(f));
+
+ auto sig = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
+ float kernel = 4.0f;
+ float ct_val_first = sig(kernel) + sig(kernel) * std::tanh(kernel);
+ float ht_val_first = sig(kernel) * std::tanh(ct_val_first);
+
+ kernel = 3.0f + ht_val_first;
+ float ct_val_second = sig(kernel) + sig(kernel) * std::tanh(kernel);
+ float ht_val_second = sig(kernel) * std::tanh(ct_val_second);
+
+ EXPECT_TRUE(test::all_close(std::vector<float>{ht_val_second}, read_vector<float>(result_ht)));
+ EXPECT_TRUE(test::all_close(std::vector<float>{ct_val_second}, read_vector<float>(result_ct)));
+}
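The expected values for the two-layer graph follow the same cell equations: with all parameters set to 1, the first cell's gate pre-activation is 1 + 1 + 1 + 1 = 4 (the initial kernel value), while the second cell takes the first cell's ht as its input, giving a pre-activation of ht_val_first + 1 + 1 + 1 = 3.0f + ht_val_first; the ct/ht recurrences are then applied exactly as in the single-cell case.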
+
TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1lstm_cell)
{
const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
- test::Uniform<float> rng(0.0f, 1.0f);
+ test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
@@ -338,7 +453,7 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_1rnn_layer_3lstm_cell)
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
- test::Uniform<float> rng(0.0f, 1.0f);
+ test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
@@ -360,7 +475,7 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell)
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
- test::Uniform<float> rng(0.0f, 1.0f);
+ test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
@@ -373,7 +488,7 @@ TEST(gpu_fusion, rnn_fusion_inter_vs_gpu_2rnn_layer_3lstm_cell)
auto gpu_results = execute(gpu_f, args, "GPU");
for (size_t i = 0; i < gpu_results.size(); i++)
{
- EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+ EXPECT_TRUE(test::all_close(gpu_results.at(i), int_results.at(i), 1.0e-3f, 1.0e-3f));
}
}
@@ -400,7 +515,7 @@ TEST(gpu_fusion, fuse_rnn_across_2layer_1timestep)
const std::string file_name("mxnet/2rnn_layer_1timestep.json");
auto gpu_f = make_function(file_name);
auto int_f = make_function(file_name);
- test::Uniform<float> rng(0.0f, 1.0f);
+ test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())