Created January 16, 2017 08:48
Save xmfbit/4f43f6397d109e3272d127344076e70c to your computer and use it in GitHub Desktop.
Test file for `reorg_layer`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <vector> | |
#include "gtest/gtest.h" | |
#include "caffe/common.hpp" | |
#include "caffe/blob.hpp" | |
#include "caffe/layers/reorg_layer.hpp" | |
#include "caffe/test/test_caffe_main.hpp" | |
namespace caffe { | |
template <typename TypeParam> | |
class ReorgLayerTest : public MultiDeviceTest<TypeParam> { | |
typedef typename TypeParam::Dtype Dtype; | |
protected: | |
Blob<Dtype>* const blob_bottom_; | |
Blob<Dtype>* const blob_top_; | |
vector<Blob<Dtype>*> blob_bottom_vec_; | |
vector<Blob<Dtype>*> blob_top_vec_; | |
ReorgLayerTest() : blob_bottom_(new Blob<Dtype>(2, 8, 2, 2)), | |
blob_top_(new Blob<Dtype>()) { | |
Dtype* bottom_data = blob_bottom_->mutable_cpu_data(); | |
int count = blob_bottom_->count(); | |
// n = 0: 0, 1, ..., 31 | |
// n = 1: 32, 33, ..., 63 | |
for(int i = 0; i < count; ++i) { | |
bottom_data[i] = static_cast<Dtype>(i); | |
} | |
blob_bottom_vec_.push_back(blob_bottom_); | |
blob_top_vec_.push_back(blob_top_); | |
} | |
virtual ~ReorgLayerTest() { | |
delete blob_bottom_; | |
delete blob_top_; | |
} | |
}; | |
TYPED_TEST_CASE(ReorgLayerTest, TestDtypesAndDevices); | |
TYPED_TEST(ReorgLayerTest, TestForward) { | |
typedef typename TypeParam::Dtype Dtype; | |
LayerParameter layer_param; | |
ReorganizeParamter* reorg_param = layer_param.mutable_reorganize_param(); | |
reorg_param->set_stride(2); | |
reorg_param->set_reverse(true); | |
ReorgLayer<Dtype> layer(layer_param); | |
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); | |
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); | |
// Test | |
// shape check | |
EXPECT_EQ(this->blob_bottom_->count(), this->blob_top_->count()); | |
EXPECT_EQ(this->blob_bottom_->num(), this->blob_top_->num()); | |
EXPECT_EQ(this->blob_bottom_->channels(), this->blob_top_->channels() * 4); | |
EXPECT_EQ(this->blob_bottom_->height(), this->blob_top_->height() / 2); | |
EXPECT_EQ(this->blob_bottom_->width(), this->blob_top_->width() / 2); | |
// check number | |
int val = 0; | |
for(int n = 0; n < this->blob_top_->num(); ++n) { | |
for(int c = 0; c < this->blob_top_->channels(); ++c) { | |
for(int h = 0; h < this->blob_top_->height(); ++h) { | |
for(int w = 0; w < this->blob_top_->width(); ++w) { | |
//int offset = this->blob_top_->offset(n, c, h, w); | |
int tw = w / 2; | |
int th = h / 2; | |
int offset = (w - tw*2) + (h - th*2)*2; | |
int tc = c + offset * this->blob_top_->channels(); | |
EXPECT_EQ(this->blob_bottom_->data_at(n, tc, th, tw), | |
this->blob_top_->data_at(n, c, h, w)); | |
} | |
} | |
} | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.