forked from intel/caffe
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2966 from cdoersch/batch_reindex_layer
BatchReindexLayer to shuffle, subsample, and replicate examples in a batch
- Loading branch information
Showing
4 changed files
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#include <vector> | ||
|
||
#include "caffe/layer.hpp" | ||
#include "caffe/util/math_functions.hpp" | ||
#include "caffe/vision_layers.hpp" | ||
|
||
namespace caffe { | ||
|
||
template<typename Dtype> | ||
void BatchReindexLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
CHECK_EQ(1, bottom[1]->num_axes()); | ||
vector<int> newshape; | ||
newshape.push_back(bottom[1]->shape(0)); | ||
for (int i = 1; i < bottom[0]->shape().size(); ++i) { | ||
newshape.push_back(bottom[0]->shape()[i]); | ||
} | ||
top[0]->Reshape(newshape); | ||
} | ||
|
||
template<typename Dtype> | ||
void BatchReindexLayer<Dtype>::check_batch_reindex(int initial_num, | ||
int final_num, | ||
const Dtype* ridx_data) { | ||
for (int i = 0; i < final_num; ++i) { | ||
CHECK_GE(ridx_data[i], 0) | ||
<< "Index specified for reindex layer was negative."; | ||
CHECK_LT(ridx_data[i], initial_num) | ||
<< "Index specified for reindex layer was greater than batch size."; | ||
} | ||
} | ||
|
||
template<typename Dtype> | ||
void BatchReindexLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, | ||
const vector<Blob<Dtype>*>& top) { | ||
check_batch_reindex(bottom[0]->shape(0), bottom[1]->count(), | ||
bottom[1]->cpu_data()); | ||
if (top[0]->count() == 0) { | ||
return; | ||
} | ||
int inner_dim = bottom[0]->count() / bottom[0]->shape(0); | ||
const Dtype* in = bottom[0]->cpu_data(); | ||
const Dtype* permut = bottom[1]->cpu_data(); | ||
Dtype* out = top[0]->mutable_cpu_data(); | ||
for (int index = 0; index < top[0]->count(); ++index) { | ||
int n = index / (inner_dim); | ||
int in_n = static_cast<int>(permut[n]); | ||
out[index] = in[in_n * (inner_dim) + index % (inner_dim)]; | ||
} | ||
} | ||
|
||
template<typename Dtype> | ||
void BatchReindexLayer<Dtype>::Backward_cpu( | ||
const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, | ||
const vector<Blob<Dtype>*>& bottom) { | ||
CHECK(!propagate_down[1]) << "Cannot backprop to index."; | ||
if (!propagate_down[0]) { | ||
return; | ||
} | ||
int inner_dim = bottom[0]->count() / bottom[0]->shape(0); | ||
Dtype* bot_diff = bottom[0]->mutable_cpu_diff(); | ||
const Dtype* permut = bottom[1]->cpu_data(); | ||
const Dtype* top_diff = top[0]->cpu_diff(); | ||
caffe_set(bottom[0]->count(), Dtype(0), bot_diff); | ||
for (int index = 0; index < top[0]->count(); ++index) { | ||
int n = index / (inner_dim); | ||
int in_n = static_cast<int>(permut[n]); | ||
bot_diff[in_n * (inner_dim) + index % (inner_dim)] += top_diff[index]; | ||
} | ||
} | ||
|
||
#ifdef CPU_ONLY
// CPU-only build: provide stub implementations for the GPU forward/backward
// methods declared in the header.
STUB_GPU(BatchReindexLayer);
#endif

// Instantiate the template for float/double and register the layer with the
// factory under the type name "BatchReindex".
INSTANTIATE_CLASS(BatchReindexLayer);
REGISTER_LAYER_CLASS(BatchReindex);
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#include <algorithm> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "caffe/layer.hpp" | ||
#include "caffe/util/math_functions.hpp" | ||
#include "caffe/vision_layers.hpp" | ||
|
||
namespace caffe { | ||
|
||
template<typename Dtype>
__global__ void BRForward(const int count, const int inner_dim, const Dtype* in,
                          const Dtype* permut, Dtype* out) {
  // One thread per output element: look up which input example this output
  // position draws from and copy the corresponding value.
  CUDA_KERNEL_LOOP(index, count) {
    const int batch_idx = index / inner_dim;
    const int offset = index % inner_dim;
    const int src_example = static_cast<int>(permut[batch_idx]);
    out[index] = in[src_example * inner_dim + offset];
  }
}
|
||
template<typename Dtype>
void BatchReindexLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
                                           const vector<Blob<Dtype>*>& top) {
  // Indices are validated on the host, so read them through the CPU pointer.
  check_batch_reindex(bottom[0]->shape(0), bottom[1]->count(),
                      bottom[1]->cpu_data());
  const int count = top[0]->count();
  if (count == 0) {
    return;
  }
  // Values per example (product of all non-batch dimensions).
  const int inner_dim = bottom[0]->count() / bottom[0]->shape(0);
  // NOLINT_NEXT_LINE(whitespace/operators)
  BRForward<Dtype> <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, inner_dim, bottom[0]->gpu_data(), bottom[1]->gpu_data(),
      top[0]->mutable_gpu_data());
  CUDA_POST_KERNEL_CHECK;
}
|
||
// Backward kernel: one thread per bottom-diff element. For element `index`
// belonging to bottom example n, sums the top diffs of every top slot that
// selected example n. `top_indexes` is a flattened list-of-lists; the list
// for example n occupies [begins[n], begins[n] + counts[n]).
template<typename Dtype>
__global__ void BRBackward(const int count, const int inner_dim,
                           const Dtype* in, const Dtype* top_indexes,
                           const Dtype* begins, const Dtype* counts,
                           Dtype* out) {
  CUDA_KERNEL_LOOP(index, count) {
    int n = index / (inner_dim);
    int offset = index % (inner_dim);
    int lower = static_cast<int>(begins[n]);
    int upper = lower + static_cast<int>(counts[n]);
    // Accumulate in a register instead of read-modify-writing global memory
    // on every summand; summation order is unchanged. An empty list
    // (counts[n] == 0) yields a zero gradient, as before.
    Dtype acc = 0;
    for (int i = lower; i < upper; ++i) {
      int in_n = static_cast<int>(top_indexes[i]);
      acc += in[in_n * (inner_dim) + offset];
    }
    out[index] = acc;
  }
}
|
||
template<typename Dtype>
void BatchReindexLayer<Dtype>::Backward_gpu(
    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  CHECK(!propagate_down[1]) << "Cannot backprop to index.";
  if (!propagate_down[0]) {
    return;
  }

  // Build (source example, top position) pairs and sort by source example,
  // so all top slots drawing from the same bottom example become contiguous.
  vector<std::pair<int, int> > mapping;
  const Dtype* perm = bottom[1]->cpu_data();
  for (int i = 0; i < bottom[1]->count(); ++i) {
    mapping.push_back(pair<int, int>(static_cast<int>(perm[i]), i));
  }
  std::sort(mapping.begin(), mapping.end(), pair_sort_first());

  // Each element of the bottom diff is potentially the sum of many top diffs.
  // However, we'd like each CUDA thread to handle exactly one output. Hence,
  // we first pre-compute a list of lists of indices that need to be summed for
  // each output. `top_indexes` holds the data of this list of lists. The
  // k'th element of `begins` points to the location in `top_indexes` where the
  // list for the k'th example begins, and the k'th element of `counts` is the
  // length of that list.
  vector<int> shape;
  shape.push_back(bottom[1]->count());
  Blob<Dtype> top_indexes(shape);
  shape[0] = bottom[0]->shape(0);
  Blob<Dtype> counts(shape);
  Blob<Dtype> begins(shape);
  Dtype* t_i_data = top_indexes.mutable_cpu_data();
  Dtype* c_data = counts.mutable_cpu_data();
  Dtype* b_data = begins.mutable_cpu_data();
  // -1 marks "no top slot selected this example"; with counts == 0 the
  // kernel sums an empty range and writes a zero gradient there.
  caffe_set(begins.count(), Dtype(-1), b_data);
  caffe_set(counts.count(), Dtype(0), c_data);
  // `mapping` is sorted by source example, so each example's top positions
  // land contiguously in t_i_data; record where each run starts and its size.
  for (int i = 0; i < mapping.size(); ++i) {
    t_i_data[i] = mapping[i].second;
    if (b_data[mapping[i].first] == -1) {
      b_data[mapping[i].first] = i;
    }
    c_data[mapping[i].first] += 1;
  }

  // One thread per bottom-diff element; each sums its own list of top diffs.
  int threads = bottom[0]->count();
  // NOLINT_NEXT_LINE(whitespace/operators)
  BRBackward<Dtype> <<<CAFFE_GET_BLOCKS(threads), CAFFE_CUDA_NUM_THREADS>>>(
      bottom[0]->count(), bottom[0]->count() / bottom[0]->shape(0),
      top[0]->gpu_diff(), top_indexes.gpu_data(), begins.gpu_data(),
      counts.gpu_data(), bottom[0]->mutable_gpu_diff());
  CUDA_POST_KERNEL_CHECK;
}
|
||
// Instantiate the GPU forward/backward methods for float and double.
INSTANTIATE_LAYER_GPU_FUNCS(BatchReindexLayer);
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#include <cstring> | ||
#include <vector> | ||
|
||
#include "gtest/gtest.h" | ||
|
||
#include "caffe/blob.hpp" | ||
#include "caffe/common.hpp" | ||
#include "caffe/filler.hpp" | ||
#include "caffe/vision_layers.hpp" | ||
|
||
#include "caffe/test/test_caffe_main.hpp" | ||
#include "caffe/test/test_gradient_check_util.hpp" | ||
|
||
namespace caffe { | ||
|
||
// Fixture for BatchReindexLayer: a 5-example bottom blob (each example
// 4 x 3 x 2) reindexed by a 6-entry index blob, so the output batch is
// larger than the input batch (examples are replicated as well as shuffled).
template<typename TypeParam>
class BatchReindexLayerTest : public MultiDeviceTest<TypeParam> {
  typedef typename TypeParam::Dtype Dtype;

 protected:
  BatchReindexLayerTest()
      : blob_bottom_(new Blob<Dtype>()),
        blob_bottom_permute_(new Blob<Dtype>()),
        blob_top_(new Blob<Dtype>()) {
  }
  virtual void SetUp() {
    // Fixed seed so the Gaussian fill below is reproducible across runs.
    Caffe::set_random_seed(1701);
    // Data bottom: batch of 5 examples, each of shape 4 x 3 x 2.
    vector<int> sz;
    sz.push_back(5);
    sz.push_back(4);
    sz.push_back(3);
    sz.push_back(2);
    blob_bottom_->Reshape(sz);
    // Index bottom: 6 entries -> output batch of 6.
    vector<int> permsz;
    permsz.push_back(6);
    blob_bottom_permute_->Reshape(permsz);

    // fill the values
    FillerParameter filler_param;
    GaussianFiller<Dtype> filler(filler_param);
    filler.Fill(this->blob_bottom_);
    // Repeats examples 4 and 0 and never selects example 3.
    int perm[] = { 4, 0, 4, 0, 1, 2 };
    for (int i = 0; i < blob_bottom_permute_->count(); ++i) {
      blob_bottom_permute_->mutable_cpu_data()[i] = perm[i];
    }

    // bottom[0] = data, bottom[1] = indices.
    blob_bottom_vec_.push_back(blob_bottom_);
    blob_bottom_vec_.push_back(blob_bottom_permute_);
    blob_top_vec_.push_back(blob_top_);
  }
  virtual ~BatchReindexLayerTest() {
    delete blob_bottom_permute_;
    delete blob_bottom_;
    delete blob_top_;
  }
  Blob<Dtype>* const blob_bottom_;          // data input (bottom[0])
  Blob<Dtype>* const blob_bottom_permute_;  // index input (bottom[1])
  Blob<Dtype>* const blob_top_;             // layer output
  vector<Blob<Dtype>*> blob_bottom_vec_;
  vector<Blob<Dtype>*> blob_top_vec_;

  // Fills bottom[0] with sequential values, runs Forward, and checks that
  // every output example is an exact copy of the input example named by
  // the corresponding entry of `perm`.
  void TestForward() {
    LayerParameter layer_param;

    vector<int> sz;
    sz.push_back(5);
    sz.push_back(4);
    sz.push_back(3);
    sz.push_back(2);
    blob_bottom_->Reshape(sz);
    // Sequential data makes the expected outputs exactly computable.
    for (int i = 0; i < blob_bottom_->count(); ++i) {
      blob_bottom_->mutable_cpu_data()[i] = i;
    }

    vector<int> permsz;
    permsz.push_back(6);
    blob_bottom_permute_->Reshape(permsz);
    int perm[] = { 4, 0, 4, 0, 1, 2 };
    for (int i = 0; i < blob_bottom_permute_->count(); ++i) {
      blob_bottom_permute_->mutable_cpu_data()[i] = perm[i];
    }
    BatchReindexLayer<Dtype> layer(layer_param);
    layer.SetUp(blob_bottom_vec_, blob_top_vec_);
    // Output batch size follows the index blob; other axes follow the data.
    EXPECT_EQ(blob_top_->num(), blob_bottom_permute_->num());
    EXPECT_EQ(blob_top_->channels(), blob_bottom_->channels());
    EXPECT_EQ(blob_top_->height(), blob_bottom_->height());
    EXPECT_EQ(blob_top_->width(), blob_bottom_->width());

    layer.Forward(blob_bottom_vec_, blob_top_vec_);
    int channels = blob_top_->channels();
    int height = blob_top_->height();
    int width = blob_top_->width();
    for (int i = 0; i < blob_top_->count(); ++i) {
      // Decompose the flat index into (output example n, within-example
      // offset) and compare against the selected source example.
      int n = i / (channels * width * height);
      int inner_idx = (i % (channels * width * height));
      EXPECT_EQ(
          blob_top_->cpu_data()[i],
          blob_bottom_->cpu_data()[perm[n] * channels * width * height
              + inner_idx]);
    }
  }
};
|
||
// Run each test for every dtype (float/double) and device (CPU/GPU) pairing.
TYPED_TEST_CASE(BatchReindexLayerTest, TestDtypesAndDevices);

TYPED_TEST(BatchReindexLayerTest, TestForward) {
  this->TestForward();
}
|
||
TYPED_TEST(BatchReindexLayerTest, TestGradient) {
  typedef typename TypeParam::Dtype Dtype;
  LayerParameter layer_param;
  BatchReindexLayer<Dtype> layer(layer_param);
  // Numerical gradient check with step 1e-4 and tolerance 1e-2. The final
  // argument (0) presumably restricts checking to bottom[0] — the index
  // input is not differentiable (Backward CHECKs against it).
  GradientChecker<Dtype> checker(1e-4, 1e-2);
  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
                                  this->blob_top_vec_, 0);
}
|
||
} // namespace caffe |