Merge pull request BVLC#1946 from nickcarlevaris/msra_init
  Add MSRAFiller, an Xavier-like filler designed for use with ReLUs
shelhamer committed May 26, 2015
2 parents 2c69258 + 59de6c7 commit c255709
Showing 3 changed files with 168 additions and 9 deletions.
71 changes: 62 additions & 9 deletions include/caffe/filler.hpp
@@ -126,17 +126,18 @@ class PositiveUnitballFiller : public Filler<Dtype> {
};

/**
- * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$
- *        is set inversely proportional to the number of incoming nodes.
+ * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$ is
+ *        set inversely proportional to the number of incoming nodes, outgoing
+ *        nodes, or their average.
*
 * A Filler based on the paper [Bengio and Glorot 2010]: Understanding
- * the difficulty of training deep feedforward neural networks, but does not
- * use the fan_out value.
+ * the difficulty of training deep feedforward neural networks.
*
- * It fills the incoming matrix by randomly sampling uniform data from
- * [-scale, scale] where scale = sqrt(3 / fan_in) and fan_in is the number
- * of input nodes. You should make sure the input blob has shape (num, a, b, c)
- * where a * b * c = fan_in.
+ * It fills the incoming matrix by randomly sampling uniform data from [-scale,
+ * scale] where scale = sqrt(3 / n) and n is the fan_in, fan_out, or their
+ * average, depending on the variance_norm option. You should make sure the
+ * input blob has shape (num, a, b, c) where a * b * c = fan_in and num * b * c
+ * = fan_out. Note that this is currently not the case for inner product layers.
*
* TODO(dox): make notation in above comment consistent with rest & use LaTeX.
*/
@@ -148,14 +149,64 @@ class XavierFiller : public Filler<Dtype> {
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
-    Dtype scale = sqrt(Dtype(3) / fan_in);
+    int fan_out = blob->count() / blob->channels();
+    Dtype n = fan_in;  // default to fan_in
+    if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_AVERAGE) {
+      n = (fan_in + fan_out) / Dtype(2);
+    } else if (this->filler_param_.variance_norm() ==
+        FillerParameter_VarianceNorm_FAN_OUT) {
+      n = fan_out;
+    }
+    Dtype scale = sqrt(Dtype(3) / n);
caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};

/**
* @brief Fills a Blob with values @f$ x \sim N(0, \sigma^2) @f$ where
* @f$ \sigma^2 @f$ is set inversely proportional to the number of incoming
* nodes, outgoing nodes, or their average.
*
* A Filler based on the paper [He, Zhang, Ren and Sun 2015]: Delving Deep into
* Rectifiers, which specifically accounts for ReLU nonlinearities.
*
* Aside: for another perspective on the scaling factor, see the derivation of
* [Saxe, McClelland, and Ganguli 2013 (v3)].
*
* It fills the incoming matrix by randomly sampling Gaussian data with std =
* sqrt(2 / n) where n is the fan_in, fan_out, or their average, depending on
* the variance_norm option. You should make sure the input blob has shape (num,
* a, b, c) where a * b * c = fan_in and num * b * c = fan_out. Note that this
* is currently not the case for inner product layers.
*/
template <typename Dtype>
class MSRAFiller : public Filler<Dtype> {
public:
explicit MSRAFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
int fan_out = blob->count() / blob->channels();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
n = (fan_in + fan_out) / Dtype(2);
} else if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_FAN_OUT) {
n = fan_out;
}
Dtype std = sqrt(Dtype(2) / n);
caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};

/**
* @brief Get a specific filler from the specification given in FillerParameter.
@@ -176,6 +227,8 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
return new UniformFiller<Dtype>(param);
} else if (type == "xavier") {
return new XavierFiller<Dtype>(param);
} else if (type == "msra") {
return new MSRAFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}
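To make the two scaling rules concrete, here is a minimal standalone sketch. It is not part of the commit, and the 3x3 conv shape (64 output channels, 32 input channels) is hypothetical; it just reproduces the fan_in/fan_out arithmetic from the filler code above.

#include <cmath>
#include <cstdio>

// Standalone sketch, not from the commit. Blob shape (num, channels, height,
// width) = (64, 32, 3, 3) models a hypothetical 3x3 conv layer. As in the
// fillers above: fan_in = count / num, fan_out = count / channels.
int main() {
  const int num = 64, channels = 32, height = 3, width = 3;
  const int count = num * channels * height * width;
  const int fan_in = count / num;        // 32 * 3 * 3 = 288
  const int fan_out = count / channels;  // 64 * 3 * 3 = 576
  const double n_avg = (fan_in + fan_out) / 2.0;  // the AVERAGE option

  // XavierFiller: uniform on [-scale, scale] with scale = sqrt(3 / n).
  std::printf("xavier scale, FAN_IN : %.4f\n", std::sqrt(3.0 / fan_in));
  // MSRAFiller: zero-mean Gaussian with std = sqrt(2 / n).
  std::printf("msra std,     FAN_IN : %.4f\n", std::sqrt(2.0 / fan_in));
  std::printf("msra std,     AVERAGE: %.4f\n", std::sqrt(2.0 / n_avg));
  return 0;
}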
8 changes: 8 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -49,6 +49,14 @@ message FillerParameter {
// The expected number of non-zero output weights for a given input in
// Gaussian filler -- the default -1 means don't perform sparsification.
optional int32 sparse = 7 [default = -1];
// Normalize the filler variance by fan_in, fan_out, or their average.
// Applies to 'xavier' and 'msra' fillers.
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
AVERAGE = 2;
}
optional VarianceNorm variance_norm = 8 [default = FAN_IN];
}

message NetParameter {
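As a usage illustration (a hedged sketch, not from the commit): protoc generates a set_variance_norm() setter for the new field, so the filler can be configured programmatically and fetched through the GetFiller factory shown above. In a prototxt, the equivalent is setting type: "msra" and variance_norm: FAN_OUT inside a layer's weight_filler block.

#include "caffe/blob.hpp"
#include "caffe/filler.hpp"

// Sketch only. Assumes the usual Caffe headers and the boost shared_ptr that
// caffe/common.hpp imports into the caffe namespace.
void FillWithMSRA(caffe::Blob<float>* blob) {
  caffe::FillerParameter param;
  param.set_type("msra");  // dispatches to MSRAFiller in GetFiller
  param.set_variance_norm(caffe::FillerParameter_VarianceNorm_FAN_OUT);
  caffe::shared_ptr<caffe::Filler<float> > filler(
      caffe::GetFiller<float>(param));
  filler->Fill(blob);
}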
98 changes: 98 additions & 0 deletions src/caffe/test/test_filler.cpp
@@ -142,4 +142,102 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
EXPECT_LE(var, target_var * 5.);
}

template <typename Dtype>
class XavierFillerTest : public ::testing::Test {
protected:
XavierFillerTest()
: blob_(new Blob<Dtype>(1000, 2, 4, 5)),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
this->filler_param_.set_variance_norm(variance_norm);
this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int count = this->blob_->count();
const Dtype* data = this->blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
ex2 += data[i] * data[i];
}
mean /= count;
ex2 /= count;
Dtype std = sqrt(ex2 - mean*mean);
// U(-sqrt(3/n), sqrt(3/n)) has standard deviation sqrt(1/n), not sqrt(2/n).
Dtype target_std = sqrt(1.0 / n);
EXPECT_NEAR(mean, 0.0, 0.1);
EXPECT_NEAR(std, target_std, 0.1);
}
virtual ~XavierFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
FillerParameter filler_param_;
shared_ptr<XavierFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(XavierFillerTest, TestDtypes);

TYPED_TEST(XavierFillerTest, TestFillFanIn) {
TypeParam n = 2*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(XavierFillerTest, TestFillFanOut) {
TypeParam n = 1000*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(XavierFillerTest, TestFillAverage) {
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}

template <typename Dtype>
class MSRAFillerTest : public ::testing::Test {
protected:
MSRAFillerTest()
: blob_(new Blob<Dtype>(1000, 2, 4, 5)),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
this->filler_param_.set_variance_norm(variance_norm);
this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int count = this->blob_->count();
const Dtype* data = this->blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
ex2 += data[i] * data[i];
}
mean /= count;
ex2 /= count;
Dtype std = sqrt(ex2 - mean*mean);
Dtype target_std = sqrt(2.0 / n);
EXPECT_NEAR(mean, 0.0, 0.1);
EXPECT_NEAR(std, target_std, 0.1);
}
virtual ~MSRAFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
FillerParameter filler_param_;
shared_ptr<MSRAFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);

TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
TypeParam n = 2*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
TypeParam n = 1000*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(MSRAFillerTest, TestFillAverage) {
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}

} // namespace caffe
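
As a quick check on the targets these tests encode: the (1000, 2, 4, 5) test blob gives fan_in = 2 * 4 * 5 = 40 and fan_out = 1000 * 4 * 5 = 20000, so the MSRA filler should produce std = sqrt(2 / 40) ≈ 0.224 for FAN_IN, sqrt(2 / 20000) ≈ 0.010 for FAN_OUT, and sqrt(2 / 10020) ≈ 0.014 for AVERAGE. Each test compares the sample estimate sqrt(E[x^2] - mean^2) against its target with an absolute tolerance of 0.1, so this is a coarse sanity check rather than a tight distributional test.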
