
Commit

Merge branch 'master' of github.svail.baidu.com:baidu-research/DeepBench-internal into volta_updates
Sharan Narang committed Nov 22, 2017
2 parents e80495c + 5abbfc6 commit c017457
Showing 18 changed files with 276 additions and 84 deletions.
66 changes: 52 additions & 14 deletions README.md
@@ -191,23 +191,25 @@ version of the benchmark we will not be attempting to test these
methods.

In order to evaluate All-Reduce, we use the following libraries and benchmarks:
* [NVIDIA's NCCL](https://github.com/NVIDIA/nccl)
* [NVIDIA's NCCL](https://developer.nvidia.com/nccl)
* [Ohio State University (OSU) Benchmarks](http://mvapich.cse.ohio-state.edu/benchmarks/)
* [Baidu's Allreduce](https://github.com/baidu-research/baidu-allreduce/)

The NCCL library contains a set of standard communication
routines. The library supports any number of GPUs in a single node and
can be run in single process or multi-process (MPI). The NCCL routines
don't support All-Reduce across multiple nodes. In order to evaluate All-Reduce
across multiple nodes, we use the benchmarks from OSU. We report the
shortest latency achieved from all three implementations (NCCL single
process, NCCL MPI, OpenMPI).
The NCCL library can be built without MPI (for single-node runs) and with MPI (for multi-node runs), as shown in https://github.com/NVIDIA/nccl-tests.
We therefore benchmark two NCCL variants (single process and MPI) in the single-node experiments. For multi-node experiments,
we use NCCL with MPI, the benchmark from OSU, and Baidu's Allreduce implementation.
We report the shortest latency achieved across all implementations for each configuration.
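
For illustration, the single-process NCCL variant amounts to issuing one `ncclAllReduce` per GPU inside a launch group. The sketch below is only a rough illustration under NCCL 2, not the benchmark source; the device enumeration, buffer size, and omitted error checking are simplifying assumptions.

```
// Minimal single-process, multi-GPU all-reduce sketch with NCCL 2 (illustrative only,
// not the DeepBench benchmark source). Error checking is omitted for brevity.
#include <nccl.h>
#include <cuda_runtime.h>
#include <vector>

int main() {
    int ndev = 0;
    cudaGetDeviceCount(&ndev);

    // One communicator per GPU, all owned by this single process.
    std::vector<int> devs(ndev);
    for (int i = 0; i < ndev; ++i) devs[i] = i;
    std::vector<ncclComm_t> comms(ndev);
    ncclCommInitAll(comms.data(), ndev, devs.data());

    const size_t count = 1 << 24;  // example problem size: 16M floats
    std::vector<float*> bufs(ndev);
    std::vector<cudaStream_t> streams(ndev);
    for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaMalloc(&bufs[i], count * sizeof(float));
        cudaStreamCreate(&streams[i]);
    }

    // Launch one in-place all-reduce per GPU; the group avoids ordering deadlocks.
    ncclGroupStart();
    for (int i = 0; i < ndev; ++i)
        ncclAllReduce(bufs[i], bufs[i], count, ncclFloat, ncclSum, comms[i], streams[i]);
    ncclGroupEnd();

    for (int i = 0; i < ndev; ++i) {
        cudaSetDevice(i);
        cudaStreamSynchronize(streams[i]);
        cudaFree(bufs[i]);
        ncclCommDestroy(comms[i]);
    }
    return 0;
}
```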

#### Topology for NVIDIA 8 GPU System
Each node has two CPU sockets, and each socket has a PCIe root complex. For each socket there are two PLX switches that are each connected to the CPU socket via 16 lanes of PCIe v3. There are two GPUs on each PLX switch. All pairs of GPUs communicate simultaneously over 16 lanes of PCIe v3. The two CPU sockets are connected via Intel QPI. The interconnect across nodes is InfiniBand FDR. The figure below shows a schematic diagram of one of our nodes, where all devices connected by the same PCIe
root complex are encapsulated in a dotted box.
Each node has two CPU sockets (dual root topology), and each socket has a PCIe root complex. For each socket there are two PLX switches that are each connected to the CPU socket via 16 lanes of PCIe v3. There are two GPUs on each PLX switch. All pairs of GPUs communicate simultaneously over 16 lanes of PCIe v3. The two CPU sockets are connected via Intel QPI. The interconnect across nodes is InfiniBand FDR. The figure below shows a schematic diagram of one of our nodes, where all devices connected by the same PCIe
root complex are encapsulated in a dotted box. In our experiments, the P100, TitanX Maxwell, and M40 systems had this topology.

![Topology of NVIDIA GPU system with 8 GPUs](/doc/topology-8gpu-system.png)
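
On such a node, the PCIe hierarchy can typically be verified with `nvidia-smi topo -m`, which prints the interconnect type between every pair of GPUs.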

#### Topology for NVIDIA 10 GPU System
Each node has one CPU socket (single root topology) with two PLX switches, each connected to 5 GPUs. Communication among GPUs on the same PLX switch traverses only that switch, whereas
communication with any GPU on the other PLX switch must traverse both PLX switches along with the connecting PCIe bridge. In our experiments, the TitanX Pascal and 1080Ti systems had this topology.

#### Topology for Intel Xeon Phi and Omni-Path System
The MPI_AllReduce time is measured on Intel Xeon Phi processor 7250 on Intel’s internal Endeavor cluster with Intel® Omni-Path Architecture (Intel® OPA) series 100 fabric with fat-tree topology, using Intel MPI 5.1.3.181.
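
As a rough sketch of what is being timed here, the measurement amounts to repeating a blocking `MPI_Allreduce` and averaging the elapsed time; the message size and repeat count below are placeholders, not the OSU benchmark parameters.

```
// Minimal MPI_Allreduce timing sketch (illustrative; not the OSU benchmark source).
#include <mpi.h>
#include <cstdio>
#include <vector>

int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);
    int rank, nranks;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &nranks);

    const int count = 16777216;   // example message size in floats
    const int repeats = 100;      // assumed repeat count
    std::vector<float> send(count, 1.0f), recv(count);

    MPI_Barrier(MPI_COMM_WORLD);
    double start = MPI_Wtime();
    for (int i = 0; i < repeats; ++i)
        MPI_Allreduce(send.data(), recv.data(), count, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
    double avg_ms = (MPI_Wtime() - start) * 1e3 / repeats;

    if (rank == 0)
        std::printf("%d ranks, %d floats: avg %.2f ms per all-reduce\n", nranks, count, avg_ms);

    MPI_Finalize();
    return 0;
}
```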

@@ -489,10 +491,10 @@ In the results below, inputs and outputs are 16 bit but still use 32 bit compute

| Size (# of floats) | Number of Processors | Application | Time (ms) | Bandwidth (GB/s) | Processor |
|--------------------|----------------------|--------------------|-------------|------------------|----------------|
| 16777216 | 8 | Speech Recognition | 22.06 | 24.34 | TitanX Maxwell with InfiniBand FDR |
| 16777216 | 16 | Speech Recognition | 53.76 | 19.97 | Xeon Phi 7250 with Intel® Omni-Path |
| 16777216 | 32 | Speech Recognition | 55.68 | 38.57 | Xeon Phi 7250 with Intel® Omni-Path |

| 16777216 | 8 | Speech Recognition | 13.42 | 39.99 | TitanX Pascal with InfiniBand FDR |
| 16777216 | 16 | Speech Recognition | 46.53 | 23.08 | TitanX Maxwell with InfiniBand FDR |
| 16777216 | 32 | Speech Recognition | 49.54 | 43.35 | TitanX Maxwell with InfiniBand FDR |
| 64500000 | 32 | Speech Recognition | 97.34 | 84.82 | TitanX Pascal with InfiniBand FDR |
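
The bandwidth column appears to be derived from the total payload moved across all ranks, i.e. roughly (4 bytes × # of floats × # of processors) / time; for example, 4 B × 16777216 × 8 / 13.42 ms ≈ 40 GB/s for the 8-processor TitanX Pascal row.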

## Inference Server Results

@@ -678,6 +680,42 @@ The `osu_allreduce` benchmark can be run with more processes than
GPUs. However, all our experiments were conducted with each process
running on a single GPU.

# Baidu Benchmarks
## Compiling

In order to build the benchmarks, you will need to specify the following paths:
```
MPI_PATH: Path to MPI library. The benchmarks have been tested with OpenMPI version 2.0.1.
CUDA_PATH: Path to CUDA library. The benchmarks have been tested with version 8.0.61.
BAIDU_ALLREDUCE_PATH: Path to Baidu's allreduce implementation, which is available at https://github.com/baidu-research/baidu-allreduce/.
```

To build all the benchmarks, please use the following command:
```
cd code/
make CUDA_PATH=<cuda_path> MPI_PATH=<mpi_path> BAIDU_ALLREDUCE_PATH=<baidu_allreduce_path>
```
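
For example, using the default install locations assumed in the Makefile, this might look like `make CUDA_PATH=/usr/local/cuda MPI_PATH=/usr/local/openmpi BAIDU_ALLREDUCE_PATH=/local/baidu-allreduce`; adjust the paths to your installation.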

Please set the ARCH parameter to the appropriate architecture, as discussed above in the NVIDIA Benchmarks section.
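For example, a P100 would use `ARCH=sm_60` and a TitanX Pascal or 1080Ti `ARCH=sm_61`; the Makefile defaults to `sm_52` (Maxwell).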

## Running the Benchmarks

Once compilation completes successfully, the executables will be
generated in the `bin` directory. Before executing the benchmarks, it
is important to set your `LD_LIBRARY_PATH` correctly. For bash shells,
please use:

```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<cuda_path>:<mpi_path>:<baidu_allreduce_path>
```
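
Depending on your installation layout, it may be the library subdirectories (e.g. `<cuda_path>/lib64` and `<mpi_path>/lib`) that need to be on `LD_LIBRARY_PATH` rather than the top-level paths.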

The Baidu All-Reduce benchmark can be run using `mpirun` as shown below:

```
mpirun -np <num_ranks> bin/ring_all_reduce
```
`num_ranks` should be set to the total number of GPUs in the system.
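For example, on a single node with 8 GPUs this would be `mpirun -np 8 bin/ring_all_reduce`, with one MPI rank per GPU.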

# Intel Benchmarks
## Compiling and Running the Benchmarks

2 changes: 1 addition & 1 deletion code/Makefile
@@ -1,4 +1,4 @@
SUBDIRS= nvidia osu_allreduce
SUBDIRS= nvidia osu_allreduce baidu_allreduce

subdirs: $(SUBDIRS)

31 changes: 31 additions & 0 deletions code/baidu_allreduce/Makefile
@@ -0,0 +1,31 @@

CC=mpic++
NVCC=nvcc
ARCH?=sm_52

CUDA_PATH?=/usr/local/cuda
CUDA_LIB64=$(CUDA_PATH)/lib64
MPI_PATH?=/usr/local/openmpi
BAIDU_ALLREDUCE_PATH?=/local/baidu-allreduce
BIN_DIR?=bin
MKDIR=mkdir -p
#BLAS
BLAS_LIBRARY?=cublas
BLAS_PATH?=$(CUDA_LIB64)
#CONV
KERNELS_DIR=../kernels/
COMMA=,
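# Expand a comma-separated ARCH list (e.g. ARCH=sm_52,sm_60) into the matching --generate-code flags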
NVCC_ARCH_ARGS=$(foreach a,$(subst $(COMMA), ,$(ARCH)),--generate-code arch=$(patsubst sm_%,compute_%,$(a)),code=$(a))

.PHONY: ring_all_reduce clean rebuild

ring_all_reduce:
$(MKDIR) $(BIN_DIR)
$(MPI_PATH)/bin/$(CC) -c -std=c++11 -I $(MPI_PATH)/include -I $(BAIDU_ALLREDUCE_PATH) -I $(CUDA_PATH)/include -I $(KERNELS_DIR) -DOMPI_SKIP_MPICXX= ring_all_reduce_mpi.cpp -o $(BIN_DIR)/ring_all_reduce_mpi.o
$(CUDA_PATH)/bin/$(NVCC) -c -std=c++11 -I $(MPI_PATH)/include -I $(BAIDU_ALLREDUCE_PATH) -I $(CUDA_PATH)/include -DOMPI_SKIP_MPICXX= $(BAIDU_ALLREDUCE_PATH)/collectives.cu -o $(BIN_DIR)/collectives.o
$(MPI_PATH)/bin/$(CC) -o $(BIN_DIR)/ring_all_reduce $(BIN_DIR)/ring_all_reduce_mpi.o $(BIN_DIR)/collectives.o -L$(CUDA_PATH)/lib64 -L$(MPI_PATH)/lib -lcudart -lmpi -DOMPI_SKIP_MPICXX=

clean:
rm -rf $(BIN_DIR)

rebuild: clean ring_all_reduce
100 changes: 100 additions & 0 deletions code/baidu_allreduce/ring_all_reduce_mpi.cpp
@@ -0,0 +1,100 @@
#include "collectives.h"

#include <iomanip>
#include <sstream>

#include <mpi.h>
#include <cuda_runtime.h>

#include <stdexcept>
#include <iostream>
#include <vector>
#include <algorithm>
#include <stdio.h>
#include <cstdlib>   // std::getenv
#include <string>    // std::string, std::stoi

#include "all_reduce_problems.h"
#include <chrono>


int main(int argc, char** argv, char** envp) {

int mpi_size, mpi_rank, mpi_local_rank;

char* env_str = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK");
if(env_str == NULL) {
env_str = std::getenv("SLURM_LOCALID");
}

mpi_local_rank = std::stoi(std::string(env_str));

InitCollectives(mpi_local_rank);

MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
MPI_Barrier(MPI_COMM_WORLD);

int64_t* sizes = all_reduce_kernels_size;
int64_t* numRepeats = all_reduce_kernels_repeat;

if (mpi_rank == 0) {
std::cout << " Ring AllReduce " << std::endl;
std::cout << " Num Ranks: " << mpi_size << std::endl;

std::cout << std::setfill('-') << std::setw(100) << "-" << std::endl;
std::cout << std::setfill(' ');
std::cout << " # of floats bytes transferred Avg Time (msec) Max Time (msec)" << std::endl;

std::cout << std::setfill('-') << std::setw(100) << "-" << std::endl;
std::cout << std::setfill(' ');
}

cudaError_t err;
for (int kernel_pos = 0; kernel_pos < _NUMBER_OF_KERNELS_; kernel_pos++) {
auto t_size = sizes[kernel_pos];

float* cpu_data = new float[t_size];
std::fill_n(cpu_data, t_size, 1.0f);

float* data;
err = cudaMalloc(&data, sizeof(float) * t_size);
if(err != cudaSuccess) { throw std::runtime_error("cudaMalloc failed!"); }

err = cudaMemcpy(data, cpu_data, sizeof(float) * t_size, cudaMemcpyHostToDevice);
if(err != cudaSuccess) { throw std::runtime_error("cudaMemcpy failed!"); }

float time_sum = 0;
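// Time each RingAllreduce call end-to-end on the host; per-rank averages are reduced to avg/max across ranks below.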
for (int i = 0; i < numRepeats[kernel_pos]; i++) {

float* output;

auto start = std::chrono::steady_clock::now();
RingAllreduce(data, t_size, &output);
auto end = std::chrono::steady_clock::now();
time_sum += std::chrono::duration<double, std::milli>(end - start).count();

err = cudaFree(output);
if(err != cudaSuccess) { throw std::runtime_error("cudaFree failed!"); }
}

float time = static_cast<float>(time_sum / numRepeats[kernel_pos]);

float max_time, avg_time;
MPI_Reduce(&time, &max_time, 1, MPI_FLOAT, MPI_MAX, 0, MPI_COMM_WORLD);
MPI_Reduce(&time, &avg_time, 1, MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);

if (mpi_rank == 0) {
avg_time = avg_time/mpi_size;
std::cout << std::setw(15) << t_size << std::setw(15) << t_size * 4 << std::setw(20) << avg_time << std::setw(20) << max_time << std::endl;
}

err = cudaFree(data);
if(err != cudaSuccess) { throw std::runtime_error("cudaFree failed!"); }

delete [] cpu_data;
}

MPI_Finalize();

return 0;
}
10 changes: 10 additions & 0 deletions code/kernels/all_reduce_problems.h
@@ -0,0 +1,10 @@
#ifndef _ALL_REDUCE_KERNELS_
#define _ALL_REDUCE_KERNELS_

#include <cstdint>  // int64_t


#define _NUMBER_OF_KERNELS_ 7

int64_t all_reduce_kernels_size[] = {100000, 3097600, 4194304, 6553600, 16777217, 38360000, 64500000};
int64_t all_reduce_kernels_repeat[] = {10000, 10000, 10000, 10000, 1000, 1000, 1000};

#endif
4 changes: 2 additions & 2 deletions code/nvidia/Makefile
@@ -41,11 +41,11 @@ all_reduce: nccl_single nccl_mpi

nccl_single:
$(MKDIR) $(BIN_DIR)
$(CUDA_PATH)/bin/$(NVCC) nccl_single_all_reduce.cu -o $(BIN_DIR)/nccl_single_all_reduce -I $(NCCL_PATH)/build/include/ -I $(CUDNN_PATH)/include/ -L $(NCCL_PATH)/build/lib/ -L $(CUDNN_PATH)/lib64 -lnccl -lcudart -lcurand $(NVCC_ARCH_ARGS) -std=c++11
$(CUDA_PATH)/bin/$(NVCC) nccl_single_all_reduce.cu -o $(BIN_DIR)/nccl_single_all_reduce -I $(KERNELS_DIR) -I $(NCCL_PATH)/include/ -I $(CUDNN_PATH)/include/ -L $(NCCL_PATH)/lib/ -L $(CUDNN_PATH)/lib64 -lnccl -lcudart -lcurand $(NVCC_ARCH_ARGS) -std=c++11

nccl_mpi:
$(MKDIR) $(BIN_DIR)
$(CUDA_PATH)/bin/$(NVCC) nccl_mpi_all_reduce.cu -o $(BIN_DIR)/nccl_mpi_all_reduce -I $(NCCL_PATH)/build/include/ -I $(CUDNN_PATH)/include/ -I $(MPI_PATH)/include -L $(NCCL_PATH)/build/lib/ -L $(CUDNN_PATH)/lib64 -L $(MPI_PATH)/lib -lnccl -lcurand -lcudart -lmpi $(NVCC_ARCH_ARGS) -std=c++11
$(CUDA_PATH)/bin/$(NVCC) nccl_mpi_all_reduce.cu -o $(BIN_DIR)/nccl_mpi_all_reduce -I $(KERNELS_DIR) -I $(NCCL_PATH)/include/ -I $(CUDNN_PATH)/include/ -I $(MPI_PATH)/include -L $(NCCL_PATH)/lib/ -L $(CUDNN_PATH)/lib64 -L $(MPI_PATH)/lib -lnccl -lcurand -lcudart -lmpi $(NVCC_ARCH_ARGS) -std=c++11

sparse:
$(MKDIR) $(BIN_DIR)
31 changes: 21 additions & 10 deletions code/nvidia/nccl_mpi_all_reduce.cu
@@ -8,14 +8,19 @@

#include "tensor.h"
#include "nccl_helper.h"
#include "all_reduce_problems.h"

int main(int argc, char *argv[]) {

int size, rank;
int numRepeats = 1000;

if (argc > 1)
numRepeats = atoi(argv[1]);
int mpi_local_rank;

char* env_str = std::getenv("OMPI_COMM_WORLD_LOCAL_RANK");
if(env_str == NULL) {
env_str = std::getenv("SLURM_LOCALID");
}
mpi_local_rank = std::stoi(std::string(env_str));

//Initialize MPI
MPI_Init(&argc, &argv);
@@ -25,7 +30,7 @@ int main(int argc, char *argv[]) {
MPI_Barrier(MPI_COMM_WORLD);

// Set cuda devices
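// Use the node-local rank so each process binds to a distinct GPU on its own node
// (the global rank can exceed the per-node GPU count in multi-node runs).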
if (cudaSetDevice(rank) != cudaSuccess) {
if (cudaSetDevice(mpi_local_rank) != cudaSuccess) {
std::stringstream ss;
ss << "Failed to set cuda device. Rank: " << rank;
throw std::runtime_error(ss.str());
@@ -46,7 +51,9 @@ int main(int argc, char *argv[]) {
cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);


std::vector<int> sizes = {100000, 3097600, 4194304, 6553600, 16777217};
int64_t* sizes = all_reduce_kernels_size;
int64_t* numRepeats = all_reduce_kernels_repeat;


if (rank == 0) {
std::cout << " NCCL MPI AllReduce " << std::endl;
@@ -60,13 +67,16 @@ std::cout << std::setfill(' ');
std::cout << std::setfill(' ');

}

for (int kernel_pos = 0; kernel_pos < _NUMBER_OF_KERNELS_; kernel_pos++) {

for (auto &t_size: sizes) {
auto data = fill<float>({t_size*size}, rank);
auto t_size = sizes[kernel_pos];

auto data = fill<float>({(int)t_size*size}, rank);

cudaStreamSynchronize(stream);
auto start = std::chrono::steady_clock::now();
for (int i = 0; i < numRepeats; i++)
for (int i = 0; i < numRepeats[kernel_pos]; i++) {
CHECK_NCCL_ERROR(ncclAllReduce((void *) data.begin(),
(void *) (data.begin() + t_size),
t_size,
Expand All @@ -75,10 +85,11 @@ int main(int argc, char *argv[]) {
comm,
stream), rank);

cudaStreamSynchronize(stream);
cudaStreamSynchronize(stream);
}

auto end = std::chrono::steady_clock::now();
float time = static_cast<float>(std::chrono::duration<double, std::milli>(end - start).count() / numRepeats);
float time = static_cast<float>(std::chrono::duration<double, std::milli>(end - start).count() / numRepeats[kernel_pos]);

float max_time, avg_time;
MPI_Reduce(&time, &max_time, 1, MPI_FLOAT, MPI_MAX, 0, MPI_COMM_WORLD);
