Switched multi-GPU to NCCL
cypof committed Jan 6, 2017
1 parent 2317fa1 commit 3ba2054
Showing 48 changed files with 813 additions and 873 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -28,6 +28,7 @@ include(cmake/ConfigGen.cmake)
# ---[ Options
caffe_option(CPU_ONLY "Build Caffe without CUDA support" OFF) # TODO: rename to USE_CUDA
caffe_option(USE_CUDNN "Build Caffe with cuDNN library support" ON IF NOT CPU_ONLY)
caffe_option(USE_NCCL "Build Caffe with NCCL library support" OFF)
caffe_option(BUILD_SHARED_LIBS "Build shared libraries" ON)
caffe_option(BUILD_python "Build Python wrapper" ON)
set(python_version "2" CACHE STRING "Specify which Python version to use")
6 changes: 6 additions & 0 deletions Makefile
@@ -328,6 +328,12 @@ ifeq ($(USE_CUDNN), 1)
COMMON_FLAGS += -DUSE_CUDNN
endif

# NCCL acceleration configuration
ifeq ($(USE_NCCL), 1)
LIBRARIES += nccl
COMMON_FLAGS += -DUSE_NCCL
endif

# configure IO libraries
ifeq ($(USE_OPENCV), 1)
COMMON_FLAGS += -DUSE_OPENCV
4 changes: 4 additions & 0 deletions Makefile.config.example
@@ -94,6 +94,10 @@ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib
# INCLUDE_DIRS += $(shell brew --prefix)/include
# LIBRARY_DIRS += $(shell brew --prefix)/lib

# NCCL acceleration switch (uncomment to build with NCCL)
# https://github.com/NVIDIA/nccl (last tested version: v1.2.3-1+cuda8.0)
# USE_NCCL := 1

# Uncomment to use `pkg-config` to specify OpenCV library paths.
# (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.)
# USE_PKG_CONFIG := 1
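Both build systems turn this switch into a -DUSE_NCCL compiler definition, so NCCL-specific code paths can be compiled conditionally. A minimal sketch of such a guard (the helper function below is illustrative, not part of this commit):

#ifdef USE_NCCL
#include <nccl.h>
#endif

// Reports at runtime whether this build was compiled with NCCL support.
bool built_with_nccl() {
#ifdef USE_NCCL
  return true;
#else
  return false;
#endif
}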
15 changes: 11 additions & 4 deletions cmake/Dependencies.cmake
@@ -67,6 +67,13 @@ if(NOT HAVE_CUDA)
add_definitions(-DCPU_ONLY)
endif()

if(USE_NCCL)
find_package(NCCL REQUIRED)
include_directories(SYSTEM ${NCCL_INCLUDE_DIR})
list(APPEND Caffe_LINKER_LIBS ${NCCL_LIBRARIES})
add_definitions(-DUSE_NCCL)
endif()

# ---[ OpenCV
if(USE_OPENCV)
find_package(OpenCV QUIET COMPONENTS core highgui imgproc imgcodecs)
@@ -119,18 +126,18 @@ if(BUILD_python)
find_package(NumPy 1.7.1)
# Find the matching boost python implementation
set(version ${PYTHONLIBS_VERSION_STRING})

STRING( REGEX REPLACE "[^0-9]" "" boost_py_version ${version} )
find_package(Boost 1.46 COMPONENTS "python-py${boost_py_version}")
set(Boost_PYTHON_FOUND ${Boost_PYTHON-PY${boost_py_version}_FOUND})

while(NOT "${version}" STREQUAL "" AND NOT Boost_PYTHON_FOUND)
STRING( REGEX REPLACE "([0-9.]+).[0-9]+" "\\1" version ${version} )

STRING( REGEX REPLACE "[^0-9]" "" boost_py_version ${version} )
find_package(Boost 1.46 COMPONENTS "python-py${boost_py_version}")
set(Boost_PYTHON_FOUND ${Boost_PYTHON-PY${boost_py_version}_FOUND})

STRING( REGEX MATCHALL "([0-9.]+).[0-9]+" has_more_version ${version} )
if("${has_more_version}" STREQUAL "")
break()
26 changes: 26 additions & 0 deletions cmake/Modules/FindNCCL.cmake
@@ -0,0 +1,26 @@
set(NCCL_INC_PATHS
/usr/include
/usr/local/include
$ENV{NCCL_DIR}/include
)

set(NCCL_LIB_PATHS
/lib
/lib64
/usr/lib
/usr/lib64
/usr/local/lib
/usr/local/lib64
$ENV{NCCL_DIR}/lib
)

find_path(NCCL_INCLUDE_DIR NAMES nccl.h PATHS ${NCCL_INC_PATHS})
find_library(NCCL_LIBRARIES NAMES nccl PATHS ${NCCL_LIB_PATHS})

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARIES)

if (NCCL_FOUND)
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARIES)
endif ()
1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -117,6 +117,7 @@ function(caffe_print_configuration_summary)
caffe_status(" USE_OPENCV : ${USE_OPENCV}")
caffe_status(" USE_LEVELDB : ${USE_LEVELDB}")
caffe_status(" USE_LMDB : ${USE_LMDB}")
caffe_status(" USE_NCCL : ${USE_NCCL}")
caffe_status(" ALLOW_LMDB_NOLOCK : ${ALLOW_LMDB_NOLOCK}")
caffe_status("")
caffe_status("Dependencies:")
1 change: 1 addition & 0 deletions include/caffe/blob.hpp
@@ -220,6 +220,7 @@ class Blob {
void set_cpu_data(Dtype* data);
const int* gpu_shape() const;
const Dtype* gpu_data() const;
void set_gpu_data(Dtype* data);
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
Dtype* mutable_cpu_data();
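The new set_gpu_data() mirrors the existing set_cpu_data() and lets a blob adopt an externally managed device buffer; the multi-GPU code uses this to map parameters into a contiguous GPU allocation. A hedged usage sketch (the wrapper function and buffer ownership are illustrative assumptions):

#include "caffe/blob.hpp"

// Point 'blob' at GPU memory owned by the caller. 'device_buffer' must hold
// at least blob->count() elements on the currently selected device.
void adopt_device_buffer(caffe::Blob<float>* blob, float* device_buffer) {
  blob->set_gpu_data(device_buffer);
}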
14 changes: 10 additions & 4 deletions include/caffe/common.hpp
@@ -158,11 +158,14 @@ class Caffe {
// Search from start_id to the highest possible device ordinal,
// return the ordinal of the first available device.
static int FindDevice(const int start_id = 0);
// Parallel training info
// Parallel training
inline static int solver_count() { return Get().solver_count_; }
inline static void set_solver_count(int val) { Get().solver_count_ = val; }
inline static bool root_solver() { return Get().root_solver_; }
inline static void set_root_solver(bool val) { Get().root_solver_ = val; }
inline static int solver_rank() { return Get().solver_rank_; }
inline static void set_solver_rank(int val) { Get().solver_rank_ = val; }
inline static bool multiprocess() { return Get().multiprocess_; }
inline static void set_multiprocess(bool val) { Get().multiprocess_ = val; }
inline static bool root_solver() { return Get().solver_rank_ == 0; }

protected:
#ifndef CPU_ONLY
@@ -172,8 +175,11 @@
shared_ptr<RNG> random_generator_;

Brew mode_;

// Parallel training
int solver_count_;
bool root_solver_;
int solver_rank_;
bool multiprocess_;

private:
// The private constructor to avoid duplicate instantiation.
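root_solver() is no longer a stored flag: it is derived from solver_rank(), so rank 0 is the root by definition, and multiprocess() records whether workers run as separate processes rather than threads. A hedged sketch of how a per-GPU worker might configure the singleton (the setup function itself is an illustrative assumption):

#include "caffe/common.hpp"

// Illustrative per-GPU worker setup for GPU 'gpu_id', rank 'rank' of 'count'.
void configure_worker(int gpu_id, int rank, int count) {
  caffe::Caffe::SetDevice(gpu_id);
  caffe::Caffe::set_solver_count(count);
  caffe::Caffe::set_solver_rank(rank);
  if (caffe::Caffe::root_solver()) {
    // Only rank 0 is the root solver now; e.g. it alone logs and snapshots.
  }
}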
82 changes: 0 additions & 82 deletions include/caffe/data_reader.hpp

This file was deleted.

4 changes: 2 additions & 2 deletions include/caffe/internal_thread.hpp
@@ -42,8 +42,8 @@ class InternalThread {
bool must_stop();

private:
void entry(int device, Caffe::Brew mode, int rand_seed, int solver_count,
bool root_solver);
void entry(int device, Caffe::Brew mode, int rand_seed,
int solver_count, int solver_rank, bool multiprocess);

shared_ptr<boost::thread> thread_;
};
43 changes: 1 addition & 42 deletions include/caffe/layer.hpp
@@ -38,7 +38,7 @@ class Layer {
* layer.
*/
explicit Layer(const LayerParameter& param)
: layer_param_(param), is_shared_(false) {
: layer_param_(param) {
// Set phase and copy blobs (if there are any).
phase_ = param.phase();
if (layer_param_.blobs_size() > 0) {
@@ -66,7 +66,6 @@
*/
void SetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
InitMutex();
CheckBlobCounts(bottom, top);
LayerSetUp(bottom, top);
Reshape(bottom, top);
@@ -92,30 +91,6 @@
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {}

/**
* @brief Whether a layer should be shared by multiple nets during data
* parallelism. By default, all layers except for data layers should
* not be shared. data layers should be shared to ensure each worker
* solver access data sequentially during data parallelism.
*/
virtual inline bool ShareInParallel() const { return false; }

/** @brief Return whether this layer is actually shared by other nets.
* If ShareInParallel() is true and using more than one GPU and the
* net has TRAIN phase, then this function is expected return true.
*/
inline bool IsShared() const { return is_shared_; }

/** @brief Set whether this layer is actually shared by other nets
* If ShareInParallel() is true and using more than one GPU and the
* net has TRAIN phase, then is_shared should be set true.
*/
inline void SetShared(bool is_shared) {
CHECK(ShareInParallel() || !is_shared)
<< type() << "Layer does not support sharing.";
is_shared_ = is_shared;
}

/**
* @brief Adjust the shapes of top blobs and internal buffers to accommodate
* the shapes of the bottom blobs.
@@ -428,19 +403,6 @@
}

private:
/** Whether this layer is actually shared by other nets*/
bool is_shared_;

/** The mutex for sequential forward if this layer is shared */
shared_ptr<boost::mutex> forward_mutex_;

/** Initialize forward_mutex_ */
void InitMutex();
/** Lock forward_mutex_ if this layer is shared */
void Lock();
/** Unlock forward_mutex_ if this layer is shared */
void Unlock();

DISABLE_COPY_AND_ASSIGN(Layer);
}; // class Layer

@@ -450,8 +412,6 @@ class Layer {
template <typename Dtype>
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Lock during forward to ensure sequential forward
Lock();
Dtype loss = 0;
Reshape(bottom, top);
switch (Caffe::mode()) {
@@ -482,7 +442,6 @@ inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
default:
LOG(FATAL) << "Unknown caffe mode.";
}
Unlock();
return loss;
}

5 changes: 3 additions & 2 deletions include/caffe/layers/base_data_layer.hpp
@@ -68,15 +68,16 @@ class BasePrefetchingDataLayer :
const vector<Blob<Dtype>*>& top);

// Prefetches batches (asynchronously if to GPU memory)
static const int PREFETCH_COUNT = 3;
static const int PREFETCH_COUNT = 4; // same as proto

protected:
virtual void InternalThreadEntry();
virtual void load_batch(Batch<Dtype>* batch) = 0;

Batch<Dtype> prefetch_[PREFETCH_COUNT];
vector<shared_ptr<Batch<Dtype> > > prefetch_;
BlockingQueue<Batch<Dtype>*> prefetch_free_;
BlockingQueue<Batch<Dtype>*> prefetch_full_;
Batch<Dtype>* prefetch_current_;

Blob<Dtype> transformed_data_;
};
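The fixed PREFETCH_COUNT array becomes a runtime-sized vector of shared_ptr<Batch>, but the producer/consumer scheme between the two blocking queues is unchanged: the prefetch thread pops an empty batch from prefetch_free_, fills it, and pushes it to prefetch_full_, while Forward consumes from the full queue. A simplified sketch of one cycle (not the literal layer code):

#include "caffe/layers/base_data_layer.hpp"
#include "caffe/util/blocking_queue.hpp"

// Simplified prefetch step; 'load' stands in for the layer's load_batch().
template <typename Dtype>
void prefetch_once(caffe::BlockingQueue<caffe::Batch<Dtype>*>* free_q,
                   caffe::BlockingQueue<caffe::Batch<Dtype>*>* full_q,
                   void (*load)(caffe::Batch<Dtype>*)) {
  caffe::Batch<Dtype>* batch = free_q->pop();  // blocks until a batch is free
  load(batch);                                 // fill data and labels
  full_q->push(batch);                         // hand it to the consumer
}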
7 changes: 5 additions & 2 deletions include/caffe/layers/data_layer.hpp
@@ -4,7 +4,6 @@
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
@@ -29,9 +28,13 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
virtual inline int MaxTopBlobs() const { return 2; }

protected:
void Next();
bool Skip();
virtual void load_batch(Batch<Dtype>* batch);

DataReader reader_;
shared_ptr<db::DB> db_;
shared_ptr<db::Cursor> cursor_;
uint64_t offset_;
};

} // namespace caffe
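DataLayer no longer reads through a shared DataReader; each solver owns its own db::DB handle and cursor and uses Next()/Skip() together with its solver rank so that the N solvers consume disjoint, interleaved slices of the database. A hedged sketch of that interleaving test (the real Skip() may differ, for example in how the TEST phase is treated):

#include "caffe/common.hpp"

// Keep only every solver_count-th record, offset by this solver's rank.
bool should_skip(uint64_t offset) {
  const int size = caffe::Caffe::solver_count();
  const int rank = caffe::Caffe::solver_rank();
  return (offset % size) != static_cast<uint64_t>(rank);
}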
6 changes: 5 additions & 1 deletion include/caffe/layers/hdf5_data_layer.hpp
@@ -23,7 +23,7 @@ template <typename Dtype>
class HDF5DataLayer : public Layer<Dtype> {
public:
explicit HDF5DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param), offset_() {}
virtual ~HDF5DataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
@@ -38,6 +38,9 @@ class HDF5DataLayer : public Layer<Dtype> {
virtual inline int MinTopBlobs() const { return 1; }

protected:
void Next();
bool Skip();

virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
@@ -55,6 +58,7 @@ class HDF5DataLayer : public Layer<Dtype> {
std::vector<shared_ptr<Blob<Dtype> > > hdf_blobs_;
std::vector<unsigned int> data_permutation_;
std::vector<unsigned int> file_permutation_;
uint64_t offset_;
};

} // namespace caffe