Skip to content

Instantly share code, notes, and snippets.

@xmfbit
Created January 16, 2017 08:48
Show Gist options
  • Save xmfbit/4f43f6397d109e3272d127344076e70c to your computer and use it in GitHub Desktop.
Save xmfbit/4f43f6397d109e3272d127344076e70c to your computer and use it in GitHub Desktop.
Test file for reorg_layer.
#include <vector>
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/blob.hpp"
#include "caffe/layers/reorg_layer.hpp"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
template <typename TypeParam>
class ReorgLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
protected:
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
ReorgLayerTest() : blob_bottom_(new Blob<Dtype>(2, 8, 2, 2)),
blob_top_(new Blob<Dtype>()) {
Dtype* bottom_data = blob_bottom_->mutable_cpu_data();
int count = blob_bottom_->count();
// n = 0: 0, 1, ..., 31
// n = 1: 32, 33, ..., 63
for(int i = 0; i < count; ++i) {
bottom_data[i] = static_cast<Dtype>(i);
}
blob_bottom_vec_.push_back(blob_bottom_);
blob_top_vec_.push_back(blob_top_);
}
virtual ~ReorgLayerTest() {
delete blob_bottom_;
delete blob_top_;
}
};
TYPED_TEST_CASE(ReorgLayerTest, TestDtypesAndDevices);
TYPED_TEST(ReorgLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
ReorganizeParamter* reorg_param = layer_param.mutable_reorganize_param();
reorg_param->set_stride(2);
reorg_param->set_reverse(true);
ReorgLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Test
// shape check
EXPECT_EQ(this->blob_bottom_->count(), this->blob_top_->count());
EXPECT_EQ(this->blob_bottom_->num(), this->blob_top_->num());
EXPECT_EQ(this->blob_bottom_->channels(), this->blob_top_->channels() * 4);
EXPECT_EQ(this->blob_bottom_->height(), this->blob_top_->height() / 2);
EXPECT_EQ(this->blob_bottom_->width(), this->blob_top_->width() / 2);
// check number
int val = 0;
for(int n = 0; n < this->blob_top_->num(); ++n) {
for(int c = 0; c < this->blob_top_->channels(); ++c) {
for(int h = 0; h < this->blob_top_->height(); ++h) {
for(int w = 0; w < this->blob_top_->width(); ++w) {
//int offset = this->blob_top_->offset(n, c, h, w);
int tw = w / 2;
int th = h / 2;
int offset = (w - tw*2) + (h - th*2)*2;
int tc = c + offset * this->blob_top_->channels();
EXPECT_EQ(this->blob_bottom_->data_at(n, tc, th, tw),
this->blob_top_->data_at(n, c, h, w));
}
}
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment