Merge branch 'master' into patch_1
Noiredd authored Aug 21, 2018
2 parents f019d0d + 8e97b8a commit 828dd10
Showing 27 changed files with 388 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -42,6 +42,9 @@ caffe_option(USE_LMDB "Build with lmdb" ON)
caffe_option(ALLOW_LMDB_NOLOCK "Allow MDB_NOLOCK when reading LMDB files (only if necessary)" OFF)
caffe_option(USE_OPENMP "Link with OpenMP (when your BLAS wants OpenMP and you get linker errors)" OFF)

# This code is taken from https://github.com/sh1r0/caffe-android-lib
caffe_option(USE_HDF5 "Build with hdf5" ON)

# ---[ Dependencies
include(cmake/Dependencies.cmake)

12 changes: 11 additions & 1 deletion Makefile
@@ -178,11 +178,13 @@ ifneq ($(CPU_ONLY), 1)
LIBRARIES := cudart cublas curand
endif

- LIBRARIES += glog gflags protobuf boost_system boost_filesystem m hdf5_hl hdf5
+ LIBRARIES += glog gflags protobuf boost_system boost_filesystem m

# handle IO dependencies
USE_LEVELDB ?= 1
USE_LMDB ?= 1
# This code is taken from https://github.com/sh1r0/caffe-android-lib
USE_HDF5 ?= 1
USE_OPENCV ?= 1

ifeq ($(USE_LEVELDB), 1)
@@ -191,6 +193,10 @@ endif
ifeq ($(USE_LMDB), 1)
LIBRARIES += lmdb
endif
# This code is taken from https://github.com/sh1r0/caffe-android-lib
ifeq ($(USE_HDF5), 1)
LIBRARIES += hdf5_hl hdf5
endif
ifeq ($(USE_OPENCV), 1)
LIBRARIES += opencv_core opencv_highgui opencv_imgproc

@@ -347,6 +353,10 @@ ifeq ($(ALLOW_LMDB_NOLOCK), 1)
COMMON_FLAGS += -DALLOW_LMDB_NOLOCK
endif
endif
# This code is taken from https://github.com/sh1r0/caffe-android-lib
ifeq ($(USE_HDF5), 1)
COMMON_FLAGS += -DUSE_HDF5
endif

# CPU-only configuration
ifeq ($(CPU_ONLY), 1)
2 changes: 2 additions & 0 deletions Makefile.config.example
@@ -11,6 +11,8 @@
# USE_OPENCV := 0
# USE_LEVELDB := 0
# USE_LMDB := 0
# This code is taken from https://github.com/sh1r0/caffe-android-lib
# USE_HDF5 := 0

# uncomment to allow MDB_NOLOCK when reading LMDB files (only if necessary)
# You should not set this flag if you will be reading LMDBs with any
12 changes: 12 additions & 0 deletions cmake/ConfigGen.cmake
@@ -24,6 +24,18 @@ function(caffe_generate_export_configs)
set(HAVE_CUDA FALSE)
endif()

set(HDF5_IMPORTED OFF)
foreach(_lib ${HDF5_LIBRARIES} ${HDF5_HL_LIBRARIES})
if(TARGET ${_lib})
set(HDF5_IMPORTED ON)
endif()
endforeach()

# This code is taken from https://github.com/sh1r0/caffe-android-lib
if(USE_HDF5)
list(APPEND Caffe_DEFINITIONS -DUSE_HDF5)
endif()

if(NOT HAVE_CUDNN)
set(HAVE_CUDNN FALSE)
endif()
8 changes: 8 additions & 0 deletions cmake/Dependencies.cmake
@@ -47,6 +47,14 @@ find_package(HDF5 COMPONENTS HL REQUIRED)
list(APPEND Caffe_INCLUDE_DIRS PUBLIC ${HDF5_INCLUDE_DIRS})
list(APPEND Caffe_LINKER_LIBS PUBLIC ${HDF5_LIBRARIES} ${HDF5_HL_LIBRARIES})

# This code is taken from https://github.com/sh1r0/caffe-android-lib
if(USE_HDF5)
find_package(HDF5 COMPONENTS HL REQUIRED)
include_directories(SYSTEM ${HDF5_INCLUDE_DIRS} ${HDF5_HL_INCLUDE_DIR})
list(APPEND Caffe_LINKER_LIBS ${HDF5_LIBRARIES} ${HDF5_HL_LIBRARIES})
add_definitions(-DUSE_HDF5)
endif()

# ---[ LMDB
if(USE_LMDB)
find_package(LMDB REQUIRED)
2 changes: 2 additions & 0 deletions cmake/Summary.cmake
@@ -119,6 +119,8 @@ function(caffe_print_configuration_summary)
caffe_status(" USE_LMDB : ${USE_LMDB}")
caffe_status(" USE_NCCL : ${USE_NCCL}")
caffe_status(" ALLOW_LMDB_NOLOCK : ${ALLOW_LMDB_NOLOCK}")
# This code is taken from https://github.com/sh1r0/caffe-android-lib
caffe_status(" USE_HDF5 : ${USE_HDF5}")
caffe_status("")
caffe_status("Dependencies:")
caffe_status(" BLAS : " APPLE THEN "Yes (vecLib)" ELSE "Yes (${BLAS})")
1 change: 1 addition & 0 deletions docs/tutorial/layers.md
@@ -93,6 +93,7 @@ Layers:
* [Log](layers/log.html) - f(x) = log(x).
* [BNLL](layers/bnll.html) - f(x) = log(1 + exp(x)).
* [Threshold](layers/threshold.html) - performs step function at user defined threshold.
* [Clip](layers/clip.html) - clips a blob between a fixed minimum and maximum value.
* [Bias](layers/bias.html) - adds a bias to a blob that can either be learned or fixed.
* [Scale](layers/scale.html) - scales a blob by an amount that can either be learned or fixed.

20 changes: 20 additions & 0 deletions docs/tutorial/layers/clip.md
@@ -0,0 +1,20 @@
---
title: Clip Layer
---

# Clip Layer

* Layer type: `Clip`
* [Doxygen Documentation](http://caffe.berkeleyvision.org/doxygen/classcaffe_1_1ClipLayer.html)
* Header: [`./include/caffe/layers/clip_layer.hpp`](https://github.com/BVLC/caffe/blob/master/include/caffe/layers/clip_layer.hpp)
* CPU implementation: [`./src/caffe/layers/clip_layer.cpp`](https://github.com/BVLC/caffe/blob/master/src/caffe/layers/clip_layer.cpp)
* CUDA GPU implementation: [`./src/caffe/layers/clip_layer.cu`](https://github.com/BVLC/caffe/blob/master/src/caffe/layers/clip_layer.cu)

## Parameters

* Parameters (`ClipParameter clip_param`)
* From [`./src/caffe/proto/caffe.proto`](https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto):

{% highlight Protobuf %}
{% include proto/ClipParameter.txt %}
{% endhighlight %}
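
For reference, a minimal C++ sketch of driving the new layer through the usual Caffe Blob/Layer API. It assumes the protobuf accessors generated for the ClipParameter fields documented above (mutable_clip_param(), set_min(), set_max()); the blob shape and the clipping range [0, 6] are illustrative only.

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layers/clip_layer.hpp"
#include "caffe/proto/caffe.pb.h"

using namespace caffe;

int main() {
  // Configure the layer: clip every input into [0, 6] (values are illustrative).
  LayerParameter param;
  param.mutable_clip_param()->set_min(0);
  param.mutable_clip_param()->set_max(6);

  // One small bottom blob; as a NeuronLayer, Clip reshapes top to match bottom in SetUp.
  Blob<float> bottom_blob(1, 1, 1, 4);
  Blob<float> top_blob;
  float* in = bottom_blob.mutable_cpu_data();
  in[0] = -1.f; in[1] = 2.f; in[2] = 5.f; in[3] = 8.f;

  std::vector<Blob<float>*> bottom(1, &bottom_blob), top(1, &top_blob);
  ClipLayer<float> layer(param);
  layer.SetUp(bottom, top);
  layer.Forward(bottom, top);
  // top_blob now holds {0, 2, 5, 6}.
  return 0;
}
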
75 changes: 75 additions & 0 deletions include/caffe/layers/clip_layer.hpp
@@ -0,0 +1,75 @@
#ifndef CAFFE_CLIP_LAYER_HPP_
#define CAFFE_CLIP_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/neuron_layer.hpp"

namespace caffe {

/**
* @brief Clip: @f$ y = \max(min, \min(max, x)) @f$.
*/
template <typename Dtype>
class ClipLayer : public NeuronLayer<Dtype> {
public:
/**
* @param param provides ClipParameter clip_param,
* with ClipLayer options:
* - min
* - max
*/
explicit ClipLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {}

virtual inline const char* type() const { return "Clip"; }

protected:
/**
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$
* @param top output Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the computed outputs @f$
* y = \max(min, \min(max, x))
* @f$
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

/**
* @brief Computes the error gradient w.r.t. the clipped inputs.
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (N \times C \times H \times W) @f$
* containing error gradients @f$ \frac{\partial E}{\partial y} @f$
* with respect to computed outputs @f$ y @f$
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$; Backward fills their diff with
* gradients @f$
* \frac{\partial E}{\partial x} = \left\{
* \begin{array}{lr}
* 0 & \mathrm{if} \; x < min \vee x > max \\
* \frac{\partial E}{\partial y} & \mathrm{if} \; x \ge min \wedge x \le max
* \end{array} \right.
* @f$
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
};

} // namespace caffe

#endif // CAFFE_CLIP_LAYER_HPP_
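
The backward rule documented above is easy to check by hand: the upstream gradient passes through wherever the input lay inside [min, max] and is zeroed wherever the input was clipped. A standalone C++ illustration of both the forward clip and this gradient rule (no Caffe dependency; the numbers are arbitrary):

#include <algorithm>
#include <cstdio>

int main() {
  const float kMin = 0.f, kMax = 1.f;
  const float x[4]        = {-0.5f, 0.25f, 0.75f, 1.5f};
  const float top_diff[4] = { 1.0f, 1.0f,  1.0f,  1.0f};  // pretend upstream dE/dy

  for (int i = 0; i < 4; ++i) {
    float y  = std::max(kMin, std::min(x[i], kMax));          // forward: clip into [kMin, kMax]
    float dx = top_diff[i] * (x[i] >= kMin && x[i] <= kMax);  // backward: pass through only if not clipped
    std::printf("x = % .2f  y = % .2f  dE/dx = % .2f\n", x[i], y, dx);
  }
  return 0;
}
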
1 change: 1 addition & 0 deletions include/caffe/layers/pooling_layer.hpp
@@ -51,6 +51,7 @@ class PoolingLayer : public Layer<Dtype> {
int height_, width_;
int pooled_height_, pooled_width_;
bool global_pooling_;
PoolingParameter_RoundMode round_mode_;
Blob<Dtype> rand_idx_;
Blob<int> max_idx_;
};
2 changes: 2 additions & 0 deletions include/caffe/util/hdf5.hpp
@@ -1,3 +1,4 @@
#ifdef USE_HDF5
#ifndef CAFFE_UTIL_HDF5_H_
#define CAFFE_UTIL_HDF5_H_

@@ -37,3 +38,4 @@ string hdf5_get_name_by_idx(hid_t loc_id, int idx);
} // namespace caffe

#endif // CAFFE_UTIL_HDF5_H_
#endif // USE_HDF5
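
Because the Makefile and CMake changes above pass -DUSE_HDF5 to the compiler, the guard added here becomes a compile-time switch that client code can key off as well. A minimal sketch of the pattern, assuming only what the diff shows; the messages are placeholders and not part of Caffe:

#include <cstdio>

#ifdef USE_HDF5
#include "caffe/util/hdf5.hpp"  // only valid when Caffe was built with USE_HDF5
#endif

int main() {
#ifdef USE_HDF5
  // hdf5_data_layer.cpp and hdf5_output_layer.cpp are compiled in this configuration,
  // so the HDF5Data and HDF5Output layers are registered with the layer factory.
  std::printf("Built with HDF5 support.\n");
#else
  std::printf("Built without HDF5; the HDF5-backed layers are unavailable.\n");
#endif
  return 0;
}
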
1 change: 1 addition & 0 deletions src/caffe/layer_factory.cpp
@@ -7,6 +7,7 @@

#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
51 changes: 51 additions & 0 deletions src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,51 @@
#include <algorithm>
#include <vector>

#include "caffe/layers/clip_layer.hpp"

namespace caffe {

template <typename Dtype>
void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const int count = bottom[0]->count();

Dtype min = this->layer_param_.clip_param().min();
Dtype max = this->layer_param_.clip_param().max();

for (int i = 0; i < count; ++i) {
top_data[i] = std::max(min, std::min(bottom_data[i], max));
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
const int count = bottom[0]->count();

Dtype min = this->layer_param_.clip_param().min();
Dtype max = this->layer_param_.clip_param().max();

for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * (
bottom_data[i] >= min && bottom_data[i] <= max);
}
}
}


#ifdef CPU_ONLY
STUB_GPU(ClipLayer);
#endif

INSTANTIATE_CLASS(ClipLayer);
REGISTER_LAYER_CLASS(Clip);

} // namespace caffe
67 changes: 67 additions & 0 deletions src/caffe/layers/clip_layer.cu
@@ -0,0 +1,67 @@
#include <vector>

#include "caffe/layers/clip_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

__global__ void ClipForward(const int n, const float* in, float* out,
float p_min, float p_max) {
CUDA_KERNEL_LOOP(index, n) {
out[index] = fmaxf(p_min, fminf(in[index], p_max));
}
}

__global__ void ClipForward(const int n, const double* in, double* out,
double p_min, double p_max) {
CUDA_KERNEL_LOOP(index, n) {
out[index] = fmax(p_min, fmin(in[index], p_max));
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int count = bottom[0]->count();
Dtype p_min = this->layer_param_.clip_param().min();
Dtype p_max = this->layer_param_.clip_param().max();
// NOLINT_NEXT_LINE(whitespace/operators)
ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data, p_min, p_max);
CUDA_POST_KERNEL_CHECK;
}

template <typename Dtype>
__global__ void ClipBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
CUDA_KERNEL_LOOP(index, n) {
out_diff[index] = in_diff[index] * (
in_data[index] >= p_min && in_data[index] <= p_max);
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int count = bottom[0]->count();
Dtype p_min = this->layer_param_.clip_param().min();
Dtype p_max = this->layer_param_.clip_param().max();
// NOLINT_NEXT_LINE(whitespace/operators)
ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_data, bottom_diff, p_min, p_max);
CUDA_POST_KERNEL_CHECK;
}
}


INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);


} // namespace caffe
2 changes: 2 additions & 0 deletions src/caffe/layers/hdf5_data_layer.cpp
@@ -1,3 +1,4 @@
#ifdef USE_HDF5
/*
TODO:
- load file in a separate thread ("prefetch")
@@ -184,3 +185,4 @@ INSTANTIATE_CLASS(HDF5DataLayer);
REGISTER_LAYER_CLASS(HDF5Data);

} // namespace caffe
#endif // USE_HDF5
2 changes: 2 additions & 0 deletions src/caffe/layers/hdf5_data_layer.cu
@@ -1,3 +1,4 @@
#ifdef USE_HDF5
/*
TODO:
- only load parts of the file, in accordance with a prototxt param "max_mem"
@@ -34,3 +35,4 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5DataLayer);

} // namespace caffe
#endif // USE_HDF5
2 changes: 2 additions & 0 deletions src/caffe/layers/hdf5_output_layer.cpp
@@ -1,3 +1,4 @@
#ifdef USE_HDF5
#include <vector>

#include "hdf5.h"
@@ -72,3 +73,4 @@ INSTANTIATE_CLASS(HDF5OutputLayer);
REGISTER_LAYER_CLASS(HDF5Output);

} // namespace caffe
#endif // USE_HDF5
2 changes: 2 additions & 0 deletions src/caffe/layers/hdf5_output_layer.cu
@@ -1,3 +1,4 @@
#ifdef USE_HDF5
#include <vector>

#include "hdf5.h"
@@ -37,3 +38,4 @@ void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5OutputLayer);

} // namespace caffe
#endif // USE_HDF5
2 changes: 1 addition & 1 deletion src/caffe/layers/inner_product_layer.cpp
@@ -42,7 +42,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.inner_product_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
- // If necessary, intiialize and fill the bias term
+ // If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
(Diffs for the remaining changed files were not loaded in this view.)