diff --git a/Makefile b/Makefile index e794a95..91fed01 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,9 @@ ARCH= -gencode arch=compute_30,code=sm_30 \ -gencode arch=compute_52,code=[sm_52,compute_52] \ -gencode arch=compute_61,code=[sm_61,compute_61] +# GeForce RTX 2080 Ti, RTX 2080, RTX 2070, Quadro RTX 8000, Quadro RTX 6000, Quadro RTX 5000, Tesla T4, XNOR Tensor Cores +# ARCH= -gencode arch=compute_75,code=[sm_75,compute_75] + # Tesla V100 # ARCH= -gencode arch=compute_70,code=[sm_70,compute_70] @@ -48,7 +51,7 @@ COMMON= CFLAGS=-Wall -Wfatal-errors ifeq ($(DEBUG), 1) -OPTS=-O0 -g +OPTS= -Og -g endif ifeq ($(AVX), 1) diff --git a/README.md b/README.md index 02a50ba..4613c2e 100644 --- a/README.md +++ b/README.md @@ -16,16 +16,16 @@ How to compile: How to start: * Download [`yolov3.weights`](https://pjreddie.com/media/files/yolov3.weights) to the `bin` directory and run `./yolo.sh` on Linux (or `yolo_cpu.cmd` / `yolo_gpu.cmd` on Windows) -* Download [`yolov3-tiny.cfg`](https://pjreddie.com/media/files/yolov3-tiny.weights) to the `bin` directory and run `./tiny-yolo.sh` +* Download [`yolov3-tiny.weights`](https://pjreddie.com/media/files/yolov3-tiny.weights) to the `bin` directory and run `./tiny-yolo.sh` How to use **INT8**-inference: * Use the flag `-quantized` at the end of the command, for example, [`tiny-yolo-int8.sh`](https://github.com/AlexeyAB/yolo2_light/blob/master/bin/tiny-yolo-int8.sh) or [`yolo_cpu_int8.cmd`](https://github.com/AlexeyAB/yolo2_light/blob/master/bin/yolo_cpu_int8.cmd) * For a custom dataset, you should use the `input_calibration=` parameter in your cfg-file, taking it from the corresponding cfg-file: [`yolov3-tiny.cfg`](https://github.com/AlexeyAB/yolo2_light/blob/29905072f194ee86fdeed6ff2d12fed818712411/bin/yolov3-tiny.cfg#L25) or [`yolov3.cfg`](https://github.com/AlexeyAB/yolo2_light/blob/29905072f194ee86fdeed6ff2d12fed818712411/bin/yolov3.cfg#L25), ...
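For orientation, a minimal sketch of the arithmetic behind `-quantized` INT8 inference (an illustration, not the repository's code; the helper names `quantize_int8` / `dequantize_int32` and the symmetric [-128, 127] clamp are assumptions): the calibration multiplier taken from `input_calibration=` scales float activations into int8, the convolution accumulates in int32, and the accumulator is rescaled back to float.

```c
#include <math.h>
#include <stdint.h>

/* Hypothetical helper: quantize one activation with a per-layer calibration
   multiplier (the kind of value listed after input_calibration= in the cfg-file). */
static inline int8_t quantize_int8(float x, float input_scale)
{
    int v = (int)roundf(x * input_scale);
    if (v >  127) v =  127;   /* saturate to the int8 range */
    if (v < -128) v = -128;
    return (int8_t)v;
}

/* Hypothetical helper: rescale the int32 accumulator of an int8 x int8
   convolution back to a float output value. */
static inline float dequantize_int32(int32_t acc, float input_scale, float weight_scale)
{
    return (float)acc / (input_scale * weight_scale);
}
```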
How to use **BIT1-XNOR**-inference - only for custom models (you should train it by yourself): -* You should base your cfg-file on [`tiny-yolo-obj_xnor.cfg`](https://github.com/AlexeyAB/yolo2_light/blob/master/bin/tiny-yolo-obj_xnor.cfg) and train it by using this repository as usual https://github.com/AlexeyAB/darknet +* You should base your cfg-file on [`yolov3-spp_xnor_obj.cfg`](https://github.com/AlexeyAB/darknet/files/2853459/yolov3-spp_xnor_obj.cfg.txt) and train it by using this repository as usual https://github.com/AlexeyAB/darknet by using pre-trained file [`darknet53_448_xnor.conv.74`](https://drive.google.com/open?id=1IT-vvyxRLlxY5g9rJp_G2U3TXYphjBv8) * Then use it for Detection-test or for getting Accuracy (mAP): - * `./darknet detector test data/obj.names tiny-yolo-obj_xnor.cfg data/tiny-yolo-obj_xnor_5000.weights -thresh 0.15 dog.jpg` - * `./darknet detector map data/obj.data tiny-yolo-obj_xnor.cfg data/tiny-yolo-obj_xnor_5000.weights -thresh 0.15` + * `./darknet detector test data/obj.names yolov3-spp_xnor_obj.cfg data/yolov3-spp_xnor_obj_5000.weights -thresh 0.15 dog.jpg` + * `./darknet detector map data/obj.data yolov3-spp_xnor_obj.cfg data/yolov3-spp_xnor_obj_5000.weights -thresh 0.15` Other models by the link: https://pjreddie.com/darknet/yolo/ diff --git a/bin/pthreadVC2.dll b/bin/pthreadVC2.dll new file mode 100644 index 0000000..165b4d2 Binary files /dev/null and b/bin/pthreadVC2.dll differ diff --git a/src/additionally.c b/src/additionally.c index c123a2c..3cb56f1 100644 --- a/src/additionally.c +++ b/src/additionally.c @@ -71,7 +71,7 @@ void yolov2_fuse_conv_batchnorm(network net) layer *l = &net.layers[j]; if (l->type == CONVOLUTIONAL) { - printf(" Fuse Convolutional layer \t\t l->size = %d \n", l->size); + //printf(" Fuse Convolutional layer \t\t l->size = %d \n", l->size); if (l->batch_normalize) { int f; @@ -103,7 +103,7 @@ void yolov2_fuse_conv_batchnorm(network net) } } else { - printf(" Skip layer: %d \n", l->type); + //printf(" Skip layer: %d \n", l->type); } } } @@ -214,33 +214,95 @@ void binary_align_weights(convolutional_layer *l) align_weights[i*new_lda + j] = l->binary_weights[i*k + j]; } } - float_to_bit(align_weights, l->align_bit_weights, align_weights_size); - l->mean_arr = calloc(l->n, sizeof(float)); - get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr); + + if (l->c % 32 == 0) + //if(gpu_index < 0 && l->stride == 1 && l->pad == 1 && l->c % 32 == 0) + //if (l->stride == 1 && l->pad == 1 && l->c % 32 == 0) + { + int fil, chan; + const int items_per_filter = l->c * l->size * l->size; + //const int dst_items_per_filter = new_lda; + for (fil = 0; fil < l->n; ++fil) + { + for (chan = 0; chan < l->c; chan += 32) + { + const int items_per_channel = l->size*l->size; + for (i = 0; i < items_per_channel; ++i) + { + uint32_t val = 0; + int c_pack; + for (c_pack = 0; c_pack < 32; ++c_pack) { + float src = l->binary_weights[fil*items_per_filter + (chan + c_pack)*items_per_channel + i]; + + //align_weights[fil*items_per_filter + chan*items_per_channel + i * 32 + c_pack] = src; + + align_weights[fil*new_lda + chan*items_per_channel + i * 32 + c_pack] = src; + //val |= (src << c); + } + + } + } + } + + //printf("\n l.index = %d \t aw[0] = %f, aw[1] = %f, aw[2] = %f, aw[3] = %f \n", l->index, align_weights[0], align_weights[1], align_weights[2], align_weights[3]); + //memcpy(l->binary_weights, align_weights, (l->size * l->size * l->c * l->n) * sizeof(float)); + + float_to_bit(align_weights, l->align_bit_weights, align_weights_size); + + //if 
(l->n >= 32) + if (gpu_index >= 0) + { + int M = l->n; + int N = l->out_w*l->out_h; + //printf("\n M = %d, N = %d, M %% 8 = %d, N %% 8 = %d - weights \n", M, N, M % 8, N % 8); + //printf("\n l.w = %d, l.c = %d, l.n = %d \n", l->w, l->c, l->n); + for (i = 0; i < align_weights_size / 8; ++i) l->align_bit_weights[i] = ~(l->align_bit_weights[i]); + } + + + + get_mean_array(l->binary_weights, m*k, l->n, l->mean_arr); + //get_mean_array(l->binary_weights, m*new_lda, l->n, l->mean_arr); + } + else { + float_to_bit(align_weights, l->align_bit_weights, align_weights_size); + + get_mean_array(l->binary_weights, m*k, l->n, l->mean_arr); + } + + //l->mean_arr = calloc(l->n, sizeof(float)); + + //get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr); + + + #ifdef GPU cudaError_t status; l->align_workspace_size = l->bit_align * l->size * l->size * l->c; status = cudaMalloc((void **)&l->align_workspace_gpu, l->align_workspace_size * sizeof(float)); status = cudaMalloc((void **)&l->transposed_align_workspace_gpu, l->align_workspace_size * sizeof(float)); - check_error(status); + CHECK_CUDA(status); //l->align_bit_weights_gpu = cuda_make_array(l->align_bit_weights, l->align_bit_weights_size * sizeof(char)/sizeof(float)); status = cudaMalloc((void **)&l->align_bit_weights_gpu, l->align_bit_weights_size); - check_error(status); + CHECK_CUDA(status); status = cudaMemcpy(l->align_bit_weights_gpu, l->align_bit_weights, l->align_bit_weights_size, cudaMemcpyHostToDevice); - check_error(status); + CHECK_CUDA(status); status = cudaMemcpy(l->binary_weights_gpu, l->binary_weights, m*k * sizeof(float), cudaMemcpyHostToDevice); - check_error(status); + CHECK_CUDA(status); - l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n); - cudaDeviceSynchronize(); + //l->mean_arr_gpu = cuda_make_array(l->mean_arr, l->n); + cuda_push_array(l->mean_arr_gpu, l->mean_arr, l->n); + CHECK_CUDA(cudaDeviceSynchronize()); #endif // GPU free(align_weights); } +void forward_blank_layer(layer l, network_state state) {} + void calculate_binary_weights(network net) { int j; @@ -252,13 +314,29 @@ void calculate_binary_weights(network net) if (l->xnor) { //printf("\n %d \n", j); - l->lda_align = 256; // 256bit for AVX2 + //l->lda_align = 256; // 256bit for AVX2 // set in make_convolutional_layer() + //if (l->size*l->size*l->c >= 2048) l->lda_align = 512; binary_align_weights(l); if (net.layers[j].use_bin_output) { - l->activation = LINEAR; + //l->activation = LINEAR; + } + +#ifdef GPU + // fuse conv_xnor + shortcut -> conv_xnor + if ((j + 1) < net.n && net.layers[j].type == CONVOLUTIONAL) { + layer *sc = &net.layers[j + 1]; + if (sc->type == SHORTCUT && sc->w == sc->out_w && sc->h == sc->out_h && sc->c == sc->out_c) + { + l->bin_conv_shortcut_in_gpu = net.layers[net.layers[j + 1].index].output_gpu; + l->bin_conv_shortcut_out_gpu = net.layers[j + 1].output_gpu; + + net.layers[j + 1].type = BLANK; + net.layers[j + 1].forward_gpu = forward_blank_layer; + } } +#endif // GPU } } } @@ -432,6 +510,148 @@ void transpose_bin(uint32_t *A, uint32_t *B, const int n, const int m, } } +// popcnt 32 bit +static inline int popcnt_32(uint32_t val32) { +#ifdef WIN32 // Windows MSVS + int tmp_count = __popcnt(val32); +#else // Linux GCC + int tmp_count = __builtin_popcount(val32); +#endif + return tmp_count; +} + +void gemm_nn_bin_transposed_32bit_packed(int M, int N, int K, float ALPHA, + uint32_t *A, int lda, + uint32_t *B, int ldb, + float *C, int ldc, float *mean_arr) +{ + int i; + #pragma omp parallel for + for (i = 0; i < M; ++i) { // l.n + 
int j, s; + float mean_val = mean_arr[i]; + for (j = 0; j < N; ++j) // out_h*out_w; + { + float val = 0; + for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) + { + register uint32_t A_PART = ((uint32_t*)A)[i*lda + s]; + register uint32_t B_PART = ((uint32_t*)B)[j*ldb + s]; + uint32_t xnor_result = ~(A_PART ^ B_PART); + int32_t count = popcnt_32(xnor_result); // must be Signed int + + val += (2 * count - 32) * mean_val; + } + C[i*ldc + j] += val; + } + } +} + +// 32 channels -> 1 channel (with 32 floats) +// 256 channels -> 8 channels (with 32 floats) +void repack_input(float *input, float *re_packed_input, int w, int h, int c) +{ + const int items_per_channel = w * h; + int chan, i; + for (chan = 0; chan < c; chan += 32) + { + for (i = 0; i < items_per_channel; ++i) + { + int c_pack; + for (c_pack = 0; c_pack < 32; ++c_pack) { + float src = input[(chan + c_pack)*items_per_channel + i]; + + re_packed_input[chan*items_per_channel + i * 32 + c_pack] = src; + } + } + } +} + +// transpose uint32_t matrix +void transpose_uint32(uint32_t *src, uint32_t *dst, int src_h, int src_w, int src_align, int dst_align) +{ + //l.bit_align - algined (n) by 32 + //new_ldb - aligned (k) by 256 + + int i; + //#pragma omp parallel for + for (i = 0; i < src_h; i += 1) // l.size*l.size*l.c; + { + int j; + for (j = 0; j < src_w; j += 1) // out_h*out_w; + { + ((uint32_t *)dst)[j*dst_align / 32 + i] = ((uint32_t *)src)[i*src_align + j]; + } + } +} + +// convolution repacked bit matrix (32 channels -> 1 uint32_t) XNOR-net +void convolution_repacked(uint32_t *packed_input, uint32_t *packed_weights, float *output, + int w, int h, int c, int n, int size, int pad, int new_lda, float *mean_arr) +{ + int fil; + // filter index + #pragma omp parallel for + for (fil = 0; fil < n; ++fil) { + float mean_val = mean_arr[fil]; + int chan, c_pack, y, x, f_y, f_x; + // channel index + for (chan = 0; chan < c / 32; ++chan) + //for (chan = 0; chan < l.c; chan += 32) + //for (c_pack = 0; c_pack < 32; ++c_pack) + // input - y + for (y = 0; y < h; ++y) + // input - x + for (x = 0; x < w; ++x) + { + int const output_index = fil*w*h + y*w + x; + float sum = 0; + + // filter - y + for (f_y = 0; f_y < size; ++f_y) + { + int input_y = y + f_y - pad; + // filter - x + for (f_x = 0; f_x < size; ++f_x) + { + int input_x = x + f_x - pad; + if (input_y < 0 || input_x < 0 || input_y >= h || input_x >= w) continue; + + // normal + //float input = state.input[(chan + c_pack)*l.w*l.h + input_y*l.w + input_x]; + //float weight = l.weights[fil*l.c*l.size*l.size + (chan + c_pack)*l.size*l.size + f_y*l.size + f_x]; + + // packed + //float input = re_packed_input[chan*l.w*l.h + (input_y*l.w + input_x) * 32 + c_pack]; + //float weight = l.weights[fil*l.c*l.size*l.size + chan*l.size*l.size + (f_y*l.size + f_x) * 32 + c_pack]; + //sum += input * weight; + + //float input = re_packed_input[chan*l.w*l.h + (input_y*l.w + input_x) * 32 + c_pack]; + //float weight = l.weights[fil*l.c*l.size*l.size + chan*l.size*l.size + (f_y*l.size + f_x) * 32 + c_pack]; + //uint32_t bit1 = input > 0; + //uint32_t bit2 = weight > 0; + //uint32_t count = (~(bit1 ^ bit2)) & 1; + //float result = (2 * (float)count - 1) * mean_val; + //printf("\n mul = %f, bit1 = %d, bit2 = %d, count = %d, mean = %f, result = %f ", input*weight, bit1, bit2, count, mean_val, result); + //sum += result; + + uint32_t input = ((uint32_t *)packed_input)[chan*w*h + input_y*w + input_x]; + //uint32_t weight = ((uint32_t *)l.align_bit_weights)[fil*l.c*l.size*l.size/32 + chan*l.size*l.size 
+ f_y*l.size + f_x]; + uint32_t weight = ((uint32_t *)packed_weights)[fil*new_lda / 32 + chan*size*size + f_y*size + f_x]; + + uint32_t xnor_result = ~(input ^ weight); + int32_t count = popcnt_32(xnor_result); // mandatory Signed int + sum += (2 * count - 32) * mean_val; + } + } + // l.output[filters][width][height] += + // state.input[channels][width][height] * + // l.weights[filters][channels][filter_width][filter_height]; + output[output_index] += sum; + } + } +} + // -------------- blas.c -------------- @@ -480,6 +700,64 @@ void gemm_nn(int M, int N, int K, float ALPHA, } } +void gemm_nn_bin_32bit_packed(int M, int N, int K, float ALPHA, + uint32_t *A, int lda, + uint32_t *B, int ldb, + float *C, int ldc, float *mean_arr) +{ + int i; + #pragma omp parallel for + for (i = 0; i < M; ++i) { // l.n + int j, s; + float mean_val = mean_arr[i]; + for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) + { + register uint32_t A_PART = A[i*lda + s]; + __m256i a256 = _mm256_set1_epi32(A_PART); + + for (j = 0; j < N - 8; j += 8) + { + __m256i b256 = *((__m256i*)&B[s*ldb + j]); + __m256i xor256 = _mm256_xor_si256(a256, b256); // xnor = xor(a,b) + __m256i all_1 = _mm256_set1_epi8(255); + __m256i xnor256 = _mm256_andnot_si256(xor256, all_1); // xnor = not(xor(a,b)) + + // waiting for - CPUID Flags: AVX512VPOPCNTDQ: __m512i _mm512_popcnt_epi32(__m512i a) + __m256 count = _mm256_setr_ps( + popcnt_32(_mm256_extract_epi32(xnor256, 0)), + popcnt_32(_mm256_extract_epi32(xnor256, 1)), + popcnt_32(_mm256_extract_epi32(xnor256, 2)), + popcnt_32(_mm256_extract_epi32(xnor256, 3)), + popcnt_32(_mm256_extract_epi32(xnor256, 4)), + popcnt_32(_mm256_extract_epi32(xnor256, 5)), + popcnt_32(_mm256_extract_epi32(xnor256, 6)), + popcnt_32(_mm256_extract_epi32(xnor256, 7))); + + __m256 val2 = _mm256_set1_ps(2); + count = _mm256_mul_ps(count, val2); // count * 2 + + __m256 val32 = _mm256_set1_ps(32); + count = _mm256_sub_ps(count, val32); // count - 32 + + __m256 mean256 = _mm256_set1_ps(mean_val); + count = _mm256_mul_ps(count, mean256); // count * mean_val + + __m256 c256 = *((__m256*)&C[i*ldc + j]); + count = _mm256_add_ps(count, c256); // c = c + count + *((__m256*)&C[i*ldc + j]) = count; + } + + for (; j < N; ++j) // out_h*out_w; + { + register uint32_t B_PART = B[s*ldb + j]; + uint32_t xnor_result = ~(A_PART ^ B_PART); + int32_t count = popcnt_32(xnor_result); // must be Signed int + + C[i*ldc + j] += (2 * count - 32) * mean_val; + } + } + } +} #if defined(_MSC_VER) && _MSC_VER <= 1900 static inline __int32 _mm256_extract_epi64(__m256i a, const int index) { @@ -1008,6 +1286,31 @@ void gemm_nn(int M, int N, int K, float ALPHA, } +void gemm_nn_bin_32bit_packed(int M, int N, int K, float ALPHA, + uint32_t *A, int lda, + uint32_t *B, int ldb, + float *C, int ldc, float *mean_arr) +{ + int i; + #pragma omp parallel for + for (i = 0; i < M; ++i) { // l.n + int j, s; + float mean_val = mean_arr[i]; + for (s = 0; s < K; ++s) // l.size*l.size*l.c/32 or (l.size*l.size*l.c) + { + register uint32_t A_PART = A[i*lda + s]; + for (j = 0; j < N; ++j) // out_h*out_w; + { + register uint32_t B_PART = B[s*ldb + j]; + uint32_t xnor_result = ~(A_PART ^ B_PART); + int32_t count = popcnt_32(xnor_result); // must be Signed int + + C[i*ldc + j] += (2 * count - 32) * mean_val; + } + } + } +} + //From Berkeley Vision's Caffe! 
//https://github.com/BVLC/caffe/blob/master/LICENSE void im2col_cpu_custom(float* data_im, @@ -1764,6 +2067,10 @@ void free_network(network net) if (gpu_index >= 0) cuda_free(net.workspace); else free(net.workspace); if (net.input_state_gpu) cuda_free(net.input_state_gpu); + if (net.input_pinned_cpu) { // CPU + if (net.input_pinned_cpu_flag) cudaFreeHost(net.input_pinned_cpu); + else free(net.input_pinned_cpu); + } if (*net.input_gpu) cuda_free(*net.input_gpu); if (*net.truth_gpu) cuda_free(*net.truth_gpu); if (net.input_gpu) free(net.input_gpu); @@ -1899,6 +2206,12 @@ void free_layer(layer l) if (l.mean_arr) free(l.mean_arr); //if (l.weight_updates) free(l.weight_updates); //if (l.delta) free(l.delta); +#ifdef GPU + if (l.output && l.output_pinned) { + cudaFreeHost(l.output); + l.output = NULL; + } +#endif if (l.output) free(l.output); if (l.squared) free(l.squared); if (l.norms) free(l.norms); @@ -2206,6 +2519,13 @@ layer make_yolo_layer(int batch, int w, int h, int n, int total, int *mask, int #ifdef GPU l.output_gpu = cuda_make_array(l.output, batch*l.outputs); + + free(l.output); + if (cudaSuccess == cudaHostAlloc(&l.output, batch*l.outputs * sizeof(float), cudaHostRegisterMapped)) l.output_pinned = 1; + else { + cudaGetLastError(); // reset CUDA-error + l.output = calloc(batch*l.outputs, sizeof(float)); + } #endif fprintf(stderr, "yolo\n"); @@ -2367,7 +2687,12 @@ size_t get_workspace_size(layer l) { return most; } #endif - if (l.xnor) return (size_t)l.bit_align*l.size*l.size*l.c * sizeof(float); + if (l.xnor) { + size_t re_packed_input_size = l.c * l.w * l.h * sizeof(float); + size_t workspace_size = (size_t)l.bit_align*l.size*l.size*l.c * sizeof(float); + if (workspace_size < re_packed_input_size) workspace_size = re_packed_input_size; + return workspace_size; + } return (size_t)l.out_h*l.out_w*l.size*l.size*l.c * sizeof(float); } @@ -2442,6 +2767,18 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int int align = 32;// 8; int src_align = l.out_h*l.out_w; l.bit_align = src_align + (align - src_align % align); + + l.mean_arr = calloc(l.n, sizeof(float)); + + const size_t new_c = l.c / 32; + size_t in_re_packed_input_size = new_c * l.w * l.h + 1; + l.bin_re_packed_input = calloc(in_re_packed_input_size, sizeof(uint32_t)); + + l.lda_align = 256; // AVX2 + int k = l.size*l.size*l.c; + size_t k_aligned = k + (l.lda_align - k%l.lda_align); + size_t t_bit_input_size = k_aligned * l.bit_align / 8; + l.t_bit_input = calloc(t_bit_input_size, sizeof(char)); } if (batch_normalize) { @@ -2497,6 +2834,7 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int //} if (xnor) { l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size); + l.mean_arr_gpu = cuda_make_array(0, l.n); l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch); } @@ -3709,6 +4047,12 @@ network parse_network_cfg(char *filename, int batch, int quantized) net.workspace = cuda_make_array(0, (workspace_size - 1) / sizeof(float) + 1); int size = net.layers[0].inputs * net.batch; //get_network_input_size(net) * net.batch; net.input_state_gpu = cuda_make_array(0, size); + + if (cudaSuccess == cudaHostAlloc(&net.input_pinned_cpu, size * sizeof(float), cudaHostRegisterMapped)) net.input_pinned_cpu_flag = 1; + else { + cudaGetLastError(); // reset CUDA-error + net.input_pinned_cpu = calloc(size, sizeof(float)); + } } else { net.workspace = calloc(1, workspace_size); diff --git a/src/additionally.h b/src/additionally.h index daf461f..88802e5 100644 --- 
a/src/additionally.h +++ b/src/additionally.h @@ -66,7 +66,7 @@ extern "C" { // -------------- activations.h -------------- typedef enum { - LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN + LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU }ACTIVATION; static inline float stair_activate(float x) @@ -186,44 +186,7 @@ extern "C" { // -------------- XNOR-net GPU ------------ -#ifdef GPU - void swap_binary(convolutional_layer *l); - - void binarize_weights_gpu(float *weights, int n, int size, float *binary); - - void binarize_gpu(float *x, int n, float *binary); - - void im2col_align_ongpu(float *im, - int channels, int height, int width, - int ksize, int stride, int pad, float *data_col, int bit_align); - - void im2col_align_bin_ongpu(float *im, - int channels, int height, int width, - int ksize, int stride, int pad, float *data_col, int bit_align); - - void float_to_bit_gpu(float *src, unsigned char *dst, size_t size); - - void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m, - const int lda, const int ldb, const int block_size); - - void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size); - - //void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, - // unsigned char *A, int lda, - // unsigned char *B, int ldb, - // float *C, int ldc, float *mean_arr); - - void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, - unsigned char *A, int lda, - unsigned char *B, int ldb, - float *C, int ldc, float *mean_arr, float *bias); - - void gemm_nn_custom_bin_mean_transposed_sequentially_gpu(int M, int N, int K, - unsigned char *A, int lda, - unsigned char *B, int ldb, - float *C, int ldc, float *mean_arr); - -#endif // GPU + // in gpu.h // -------------- blas.h -------------- @@ -237,6 +200,29 @@ extern "C" { void transpose_bin(uint32_t *A, uint32_t *B, const int n, const int m, const int lda, const int ldb, const int block_size); + // 32 channels -> 1 channel (with 32 floats) + // 256 channels -> 8 channels (with 32 floats) + void repack_input(float *input, float *re_packed_input, int w, int h, int c); + + // transpose uint32_t matrix + void transpose_uint32(uint32_t *src, uint32_t *dst, int src_h, int src_w, int src_align, int dst_align); + + // convolution repacked bit matrix (32 channels -> 1 uint32_t) XNOR-net + void convolution_repacked(uint32_t *packed_input, uint32_t *packed_weights, float *output, + int w, int h, int c, int n, int size, int pad, int new_lda, float *mean_arr); + + // AVX2 + void gemm_nn_bin_32bit_packed(int M, int N, int K, float ALPHA, + uint32_t *A, int lda, + uint32_t *B, int ldb, + float *C, int ldc, float *mean_arr); + + // AVX2 + void gemm_nn_bin_transposed_32bit_packed(int M, int N, int K, float ALPHA, + uint32_t *A, int lda, + uint32_t *B, int ldb, + float *C, int ldc, float *mean_arr); + // AVX2 void im2col_cpu_custom(float* data_im, int channels, int height, int width, @@ -565,6 +551,7 @@ extern "C" { int * input_sizes; //float * delta; float * output; + int output_pinned; //float *output_multipler; float output_multipler; int8_t * output_int8; @@ -609,6 +596,8 @@ extern "C" { float *h_cpu; float *binary_input; + uint32_t *bin_re_packed_input; + char *t_bit_input; size_t workspace_size; @@ -632,6 +621,8 @@ extern "C" { float *binary_input_gpu; float *binary_weights_gpu; + float *bin_conv_shortcut_in_gpu; + float *bin_conv_shortcut_out_gpu; float * mean_gpu; float * variance_gpu; @@ -759,6 +750,8 @@ extern "C" { 
#ifdef GPU float *input_state_gpu; + float *input_pinned_cpu; + int input_pinned_cpu_flag; float **input_gpu; float **truth_gpu; diff --git a/src/gpu.cu b/src/gpu.cu index ef1d812..1772612 100644 --- a/src/gpu.cu +++ b/src/gpu.cu @@ -13,6 +13,26 @@ extern int gpu_index; #define BLOCK 512 +#define FULL_MASK 0xffffffff +#define WARP_SIZE 32 + +template +__device__ inline T1 __shfl_custom(T1 val, T2 lane) { +#if CUDART_VERSION >= 9000 + return __shfl_sync(FULL_MASK, val, lane); +#else + return __shfl(val, lane); +#endif +} + +template +__device__ inline uint32_t __ballot_custom(T val) { +#if CUDART_VERSION >= 9000 + return __ballot_sync(FULL_MASK, val); +#else + return __ballot(val); +#endif +} void pull_batchnorm_layer(layer l) {} // not required now void push_batchnorm_layer(layer l) {} // not required now @@ -21,31 +41,50 @@ void push_local_layer(local_layer l) {} // not required now void pull_connected_layer(local_layer l) {} // not required now void push_connected_layer(local_layer l) {} // not required now +int get_number_of_blocks(int array_size, int block_size) +{ + return array_size / block_size + ((array_size % block_size > 0) ? 1 : 0); +} void check_error(cudaError_t status) { - //cudaDeviceSynchronize(); cudaError_t status2 = cudaGetLastError(); if (status != cudaSuccess) { const char *s = cudaGetErrorString(status); char buffer[256]; printf("CUDA Error: %s\n", s); - assert(0); snprintf(buffer, 256, "CUDA Error: %s", s); +#ifdef WIN32 + getchar(); +#endif error(buffer); } if (status2 != cudaSuccess) { - const char *s = cudaGetErrorString(status); + const char *s = cudaGetErrorString(status2); char buffer[256]; printf("CUDA Error Prev: %s\n", s); - assert(0); snprintf(buffer, 256, "CUDA Error Prev: %s", s); +#ifdef WIN32 + getchar(); +#endif error(buffer); } } +void check_error_extended(cudaError_t status, const char *file, int line, const char *date_time) +{ + if (status != cudaSuccess) + printf("CUDA status Error: file: %s() : line: %d : build time: %s \n", file, line, date_time); +#ifdef DEBUG + status = cudaDeviceSynchronize(); + if (status != cudaSuccess) + printf("CUDA status = cudaDeviceSynchronize() Error: file: %s() : line: %d : build time: %s \n", file, line, date_time); +#endif + check_error(status); +} + void cuda_set_device(int n) { gpu_index = n; @@ -73,7 +112,52 @@ cudnnHandle_t cudnn_handle() } return handle[i]; } + +void cudnn_check_error(cudnnStatus_t status) +{ +#ifdef DEBUG + cudaDeviceSynchronize(); +#endif + cudnnStatus_t status2 = CUDNN_STATUS_SUCCESS; +#ifdef CUDNN_ERRQUERY_RAWCODE + cudnnStatus_t status_tmp = cudnnQueryRuntimeError(cudnn_handle(), &status2, CUDNN_ERRQUERY_RAWCODE, NULL); +#endif + if (status != CUDNN_STATUS_SUCCESS) + { + const char *s = cudnnGetErrorString(status); + char buffer[256]; + printf("cuDNN Error: %s\n", s); + snprintf(buffer, 256, "cuDNN Error: %s", s); +#ifdef WIN32 + getchar(); +#endif + error(buffer); + } + if (status2 != CUDNN_STATUS_SUCCESS) + { + const char *s = cudnnGetErrorString(status2); + char buffer[256]; + printf("cuDNN Error Prev: %s\n", s); + snprintf(buffer, 256, "cuDNN Error Prev: %s", s); +#ifdef WIN32 + getchar(); +#endif + error(buffer); + } +} + +void cudnn_check_error_extended(cudnnStatus_t status, const char *file, int line, const char *date_time) +{ + if (status != CUDNN_STATUS_SUCCESS) + printf("\n cuDNN status Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time); +#ifdef DEBUG + status = cudaDeviceSynchronize(); + if (status != CUDNN_STATUS_SUCCESS) + printf("\n cuDNN status = 
cudaDeviceSynchronize() Error in: file: %s() : line: %d : build time: %s \n", file, line, date_time); #endif + cudnn_check_error(status); +} +#endif // CUDNN float *cuda_make_array(float *x, size_t n) { @@ -362,14 +446,15 @@ __device__ float hardtan_activate_kernel(float x) return x; } __device__ float linear_activate_kernel(float x) { return x; } -__device__ float logistic_activate_kernel(float x) { return 1. / (1. + exp(-x)); } -__device__ float loggy_activate_kernel(float x) { return 2. / (1. + exp(-x)) - 1; } +__device__ float logistic_activate_kernel(float x) { return 1.f / (1.f + expf(-x)); } +__device__ float loggy_activate_kernel(float x) { return 2.f / (1.f + expf(-x)) - 1; } __device__ float relu_activate_kernel(float x) { return x*(x>0); } -__device__ float elu_activate_kernel(float x) { return (x >= 0)*x + (x < 0)*(exp(x) - 1); } -__device__ float relie_activate_kernel(float x) { return (x>0) ? x : .01*x; } -__device__ float ramp_activate_kernel(float x) { return x*(x>0) + .1*x; } -__device__ float leaky_activate_kernel(float x) { return (x>0) ? x : .1*x; } -__device__ float tanh_activate_kernel(float x) { return (2 / (1 + exp(-2 * x)) - 1); } +__device__ float elu_activate_kernel(float x) { return (x >= 0)*x + (x < 0)*(expf(x) - 1); } +__device__ float selu_activate_kernel(float x) { return (x >= 0)*1.0507f*x + (x < 0)*1.0507f*1.6732f*(expf(x) - 1); } +__device__ float relie_activate_kernel(float x) { return (x>0) ? x : .01f*x; } +__device__ float ramp_activate_kernel(float x) { return x*(x>0) + .1f*x; } +__device__ float leaky_activate_kernel(float x) { return (x>0) ? x : .1f*x; } +__device__ float tanh_activate_kernel(float x) { return (2 / (1 + expf(-2 * x)) - 1); } __device__ float plse_activate_kernel(float x) { if (x < -4) return .01 * (x + 4); @@ -417,6 +502,7 @@ __device__ float activate_kernel(float x, ACTIVATION a) return 0; } + __global__ void activate_array_kernel(float *x, int n, ACTIVATION a) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -427,16 +513,36 @@ __global__ void activate_array_leaky_kernel(float *x, int n) { int index = blockIdx.x*blockDim.x + threadIdx.x; if (index < n) { - float val = x[index]; - x[index] = (val > 0) ? 
val : val / 10; + x[index] = leaky_activate_kernel(x[index]); + } +} + +__global__ void activate_array_selu_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = selu_activate_kernel(x[index]); + } +} + +__global__ void activate_array_logistic_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = logistic_activate_kernel(x[index]); } } extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a) { - if (a == LEAKY) activate_array_leaky_kernel << <(n / BLOCK + 1), BLOCK, 0, 0 >> >(x, n); - else activate_array_kernel << > >(x, n, a); - check_error(cudaPeekAtLastError()); + const int num_blocks = get_number_of_blocks(n, BLOCK); + if (a == LINEAR) return; + else if (a == LEAKY) activate_array_leaky_kernel << > >(x, n); + else if (a == LOGISTIC) activate_array_logistic_kernel << > >(x, n); + else if (a == SELU) activate_array_selu_kernel << > >(x, n); + else + activate_array_kernel << > >(x, n, a); + CHECK_CUDA(cudaPeekAtLastError()); } // softmax layer @@ -561,7 +667,15 @@ extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY) // shortcut layer -__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) +__global__ void simple_input_shortcut_kernel(float *in, int size, float *add, float *out) +{ + int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; + if (id >= size) return; + + out[id] = in[id] + add[id]; +} + +__global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) { int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; if (id >= size) return; @@ -575,11 +689,18 @@ __global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stri int out_index = i*sample + w2*(j*sample + h2*(k + c2*b)); int add_index = i*stride + w1*(j*stride + h1*(k + c1*b)); - out[out_index] += add[add_index]; + out[out_index] = in[out_index] + add[add_index]; } -extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) +extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out) { + if (w1 == w2 && h1 == h2 && c1 == c2) { + int size = batch * w1 * h1 * c1; + simple_input_shortcut_kernel << > >(in, size, add, out); + CHECK_CUDA(cudaPeekAtLastError()); + return; + } + int minw = (w1 < w2) ? w1 : w2; int minh = (h1 < h2) ? h1 : h2; int minc = (c1 < c2) ? 
c1 : c2; @@ -592,8 +713,8 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int if (sample < 1) sample = 1; int size = batch * minw * minh * minc; - shortcut_kernel << > >(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out); - check_error(cudaPeekAtLastError()); + input_shortcut_kernel << > >(in, size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out); + CHECK_CUDA(cudaPeekAtLastError()); } // ----------- Quantinization -------------- @@ -1185,101 +1306,262 @@ __device__ __host__ static inline ulonglong4 xnor_int256(ulonglong4 a, ulonglong return res; } -/* -// A (weights) in the shared_memory -__global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int K, + +__device__ __host__ static inline uint8_t xor_bit1(uint8_t a, uint8_t b) { + return (a^b) & 0b1; +} + +__device__ __host__ static inline uint32_t xor_int32(uint32_t a, uint32_t b) { + return (a^b); +} + +__device__ __host__ static inline uint64_t xor_int64(uint64_t a, uint64_t b) { + return (a^b); +} + +__device__ __host__ static inline uint4 xor_int128(uint4 a, uint4 b) { + uint4 res; + res.w = (a.w^b.w); + res.x = (a.x^b.x); + res.y = (a.y^b.y); + res.z = (a.z^b.z); + return res; +} + +__device__ __host__ static inline ulonglong4 xor_int256(ulonglong4 a, ulonglong4 b) { + ulonglong4 res; + res.w = (a.w^b.w); + res.x = (a.x^b.x); + res.y = (a.y^b.y); + res.z = (a.z^b.z); + return res; +} + + +__device__ static inline int popcnt_256(ulonglong4 a) { + return __popcll(a.w) + __popcll(a.x) + __popcll(a.y) + __popcll(a.z); +} + + + +// -------------------------------- +// -------------------------------- + +// -------------------------------- +// sequentially - B (input) in the shared_memory - BAD +// -------------------------------- +__global__ void gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel(int M, int N, int K, unsigned char *A, int lda, unsigned char *B, int ldb, float *C, int ldc, float *mean_arr) { - int index = blockIdx.x*blockDim.x + threadIdx.x; + //__shared__ float mean_shared[32]; + //__shared__ uint32_t B_s[8192]; // 32 KB // [ldb x N`] // max = 262 144 bits + //__shared__ uint32_t B_s[4096]; // 16 KB // [ldb x N`] // max = 131 072 bits + __shared__ uint8_t B_s[4096 * 4]; // 16 KB // [ldb x N`] // max = 131 072 bits - __shared__ uint64_t A_s[6144]; // 48 KB // [lda x M`] - //__shared__ uint8_t A_s[6144*8]; // 48 KB // [lda x M`] - int start_i = blockIdx.x*blockDim.x / N; - int end_i = (blockIdx.x*blockDim.x + blockDim.x) / N + 1; + const int K_items = WARP_SIZE; + int start_j = blockIdx.x*blockDim.x / (K_items * M); - size_t shared_size = lda * (end_i - start_i); + { + int end_j = (blockIdx.x*blockDim.x + blockDim.x) / (K_items * M) + 1; + if (end_j > N) end_j = N; + size_t shared_size = ldb * (end_j - start_j); - int i_cur = index / N; - int local_i = i_cur - start_i; + if (shared_size != 0) { + //if(threadIdx.x == 0) printf(" start_j = %d, end_j = %d, shared_size = %d \n", start_j, end_j, shared_size); - for (int k = threadIdx.x * 64; k < shared_size; k += blockDim.x * 64) { - int x = start_i*lda + k; - if (x < (M*lda)) *((uint64_t *)(A_s + k / 8)) = *((uint64_t *)(A + x / 8)); + int k; + for (int k = threadIdx.x * 32; k < shared_size; k += blockDim.x * 32) { + int x = start_j*ldb + k; + if (x < (N*ldb)) *((uint32_t *)(B_s + k / 8)) = *((uint32_t *)(B + x / 8)); + } + } } + __syncthreads(); - //if (i_cur < M && (index % N == 0 || threadIdx.x == 0)) { - //for (int k = 0; k < K; k += 64) { // l.size*l.size*l.c 
- one filter size [27 - 9216] - //*((uint64_t *)(A_s + (local_i*lda + k) / 8)) = *((uint64_t *)(A + (i_cur*lda + k) / 8)); // weights - // } - //} + int index = blockIdx.x*blockDim.x + threadIdx.x; - __syncthreads(); + { + int i; // l.n + int j; // out_h*out_w + int k; // l.size * l.size * l.c - int i, j, k, h; + const int index2 = index / K_items; + i = index2 % M; // max M + j = index2 / M; // max N + int local_j = j - start_j; - j = index % N; - { // out_h*out_w - one channel output size [169 - 173056] - i = index / N; - if (i < M) // l.n - filters [16 - 55 - 1024] - { - float mean_val = mean_arr[i]; - int count = 0; + //if (i <= 1 && j <= 1 ) printf(" k = %d, K = %d, K_items = %d, i = %d, j = %d, lda = %d, ldb = %d, ldc = %d \n", + // k, K, K_items, i, j, lda, ldb, ldc); + { // l.n - filters [16 - 55 - 1024] + // further improvements: for (l.n == 1024) iterate several (j) - for (k = 0; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216] - //uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8)); // weights - uint64_t a_bit64 = *((uint64_t *)(A_s + (local_i*lda + k) / 8)); // weights - uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8)); // input - uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64); - int tmp_count = __popcll(c_bit64); + if (j < N) + { // out_h*out_w - one channel output size [169 - 173056] - if (K - k < 64) tmp_count = tmp_count - (64 - (K - k)); // remove extra bits - count += tmp_count; - } + int count = 0; + + + const int bit_step = 32; + for (k = (threadIdx.x % WARP_SIZE) * bit_step; k < K; k += bit_step*WARP_SIZE) + { // l.size*l.size*l.c - one filter size [27 - 144 - 9216] + uint32_t a_bit32 = *((uint32_t *)(A + (i*lda + k) / 8)); // weights + //uint32_t b_bit32 = *((uint32_t *)(B + (j*ldb + k) / 8)); // input + uint32_t b_bit32 = *((uint32_t *)(B_s + (local_j*ldb + k) / 8)); // input + uint32_t c_bit32 = xnor_int32(a_bit32, b_bit32); + + count += __popc(c_bit32); + } - C[i*ldc + j] = (2 * count - K) * mean_val; + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) + count += __shfl_down(count, offset); + + + if (threadIdx.x % WARP_SIZE == 0) { + int f1 = (K % bit_step == 0) ? 
0 : (bit_step - (K % bit_step)); + count = count - f1; + float mean_val = mean_arr[i]; + C[i*ldc + j] = (2 * count - K) * mean_val; + //B_s[threadIdx.x / WARP_SIZE] = (2 * count - K) * mean_val; + } + } } } } -#include - -void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, +// sequentially - BAD +void gemm_nn_custom_bin_mean_transposed_sequentially_gpu(int M, int N, int K, unsigned char *A, int lda, unsigned char *B, int ldb, float *C, int ldc, float *mean_arr) { - size_t size = M*N; + //size_t size = M*N; + size_t size = M*N * 32; + const int num_blocks = size / BLOCK + 1; - gemm_nn_custom_bin_mean_transposed_gpu_kernel << > >( + //printf(" K = %d \n", K); + + /* + printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n", + size, num_blocks, M*K / 1024, N*K / 1024, M*lda / num_blocks / 1024, N*ldb / num_blocks / 1024); + printf(" M / 512 = %d, N / 512 = %d, M*lda / 512 = %d, N*ldb / 512 = %d \n", M / 512, N / 512, M*lda/512, N*ldb/512); + */ + //printf(" shared_memory: (w) lda*BLOCK/N = %d, (i) ldb*BLOCK/M = %d, \t lda = %d \n\n", lda*BLOCK / N, ldb*BLOCK / M, lda); + + gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel << > >( M, N, K, A, lda, B, ldb, C, ldc, mean_arr); } -*/ // -------------------------------- + +// 32 channels -> 1 channel (with 32 floats) +// 256 channels -> 8 channels (with 32 floats) +__global__ void repack_input_kernel_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c) +{ + //__shared__ uint32_t tmp[32]; + const int index = blockIdx.x*blockDim.x + threadIdx.x; + + const int global_warp_id = index / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + const int items_per_channel = w * h; + const int items_per_channel_aligned = items_per_channel + WARP_SIZE - (items_per_channel % WARP_SIZE); + + int i = 32 * (global_warp_id % (items_per_channel_aligned / WARP_SIZE)); + int chan = 32 * (global_warp_id / (items_per_channel_aligned / WARP_SIZE)); + + if (chan < c) + { + uint32_t result_bits = 0; + + for (int c_pack = 0; c_pack < 32; ++c_pack) + { + float src = 0; + if ((i + lane_id) < items_per_channel) { + src = input[(chan + c_pack)*items_per_channel + (i + lane_id)]; + } + uint32_t bit_mask = __ballot_custom(src > 0); + + uint32_t cur_bit = (bit_mask >> lane_id) & uint32_t(1); + + result_bits |= (cur_bit << c_pack); + } + if ((i + lane_id) < items_per_channel) { + re_packed_input_bin[chan*items_per_channel / 32 + (i + lane_id)] = result_bits; + } + } +} + +void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c) +{ + int size = (w * h * c) / 32 + 1; + const int block_size = BLOCK; + const int num_blocks = get_number_of_blocks(size, block_size); + //printf("\n num_blocks = %d, num_blocks/32 = %d, block_size = %d \n", num_blocks, num_blocks / 32, block_size); + repack_input_kernel_bin << > >(input, re_packed_input_bin, w, h, c); + CHECK_CUDA(cudaPeekAtLastError()); +} +// -------------------------------- + +__global__ void transpose_uint32_kernel(uint32_t *src, uint32_t *dst, int src_h, int src_w, int src_align, int dst_align) +{ + //l.bit_align - algined (n) by 32 + //new_ldb - aligned (k) by 256 + int index = blockIdx.x*blockDim.x + threadIdx.x; + + //for (i = 0; i < src_h; i += 1) + int i = index % src_h; // l.size*l.size*l.c; + { + //for (j = 0; j < src_w; j += 1) + int j = index / src_h; // out_h*out_w; + if (j < src_w) + { + ((uint32_t *)dst)[j*dst_align / 32 + i] = ((uint32_t *)src)[i*src_align + j]; + 
} + } +} + +void transpose_uint32_gpu(uint32_t *src, uint32_t *dst, int src_h, int src_w, int src_align, int dst_align) +{ + int size = src_w * src_h; + const int num_blocks = size / BLOCK + 1; + transpose_uint32_kernel << > >(src, dst, src_h, src_w, src_align, dst_align); + CHECK_CUDA(cudaPeekAtLastError()); +} +// -------------------------------- + + __inline__ __device__ int warpAllReduceSum(int val) { for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) +#if CUDART_VERSION >= 9000 + val += __shfl_xor_sync(FULL_MASK, val, mask); +#else val += __shfl_xor(val, mask); +#endif + return val; } +// -------------------------------- -// Coalesced memory access +// Coalescing // A (weights) in the shared_memory - GOOD __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int K, unsigned char *A, int lda, unsigned char *B, int ldb, - float *C, int ldc, float *mean_arr, float *bias_arr) + float *C, int ldc, float *mean_arr, float *bias_arr, int leaky_activation, + float *shortcut_in_gpu, float *shortcut_out_gpu) { + // total 57% int index = blockIdx.x*blockDim.x + threadIdx.x; __shared__ uint8_t A_s[6144 * 8 / 4]; @@ -1293,7 +1575,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int int i_cur = index / N; int local_i = i_cur - start_i; - + // ~10% for (int k = threadIdx.x * 64; k < shared_size; k += blockDim.x * 64) { int x = start_i*lda + k; if (x < (M*lda)) *((uint64_t *)(A_s + k / 8)) = *((uint64_t *)(A + x / 8)); @@ -1301,7 +1583,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int __syncthreads(); int i, j, k, h; - + // 47% = 29 + 10 + 8 j = index % N; { // out_h*out_w - one channel output size [169 - 173056] i = index / N; @@ -1320,18 +1602,18 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int int64_t B_cur_index = (j*ldb + k) / 8; if (i >= M) A_cur_index = 0; - #pragma unroll +#pragma unroll for (int t = 0; t < WARP_SIZE; ++t) { const int lane_id = threadIdx.x % WARP_SIZE; - const int64_t A_i = __shfl(A_cur_index, t) + 32 * lane_id; - const int64_t B_i = __shfl(B_cur_index, t) + 32 * lane_id; + const int64_t A_i = __shfl_custom(A_cur_index, t) + 32 * lane_id; + const int64_t B_i = __shfl_custom(B_cur_index, t) + 32 * lane_id; { //ulonglong4 a_bit256 = *((ulonglong4 *)(A + A_i)); // weights ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + A_i)); // weights ulonglong4 b_bit256 = *((ulonglong4 *)(B + B_i)); // input - c_bit256 = xnor_int256(a_bit256, b_bit256); + c_bit256 = xor_int256(a_bit256, b_bit256); int tmp_count = __popcll(c_bit256.w) + __popcll(c_bit256.x) + __popcll(c_bit256.y) + __popcll(c_bit256.z); @@ -1342,8 +1624,9 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int } #endif + //#ifdef NOT_USED - // 32 thread X 64 bit = 2048 bit + // 32 thread X 64 bit = 2048 bit // 29% for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216] uint64_t c_bit64; @@ -1352,18 +1635,18 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int int64_t B_cur_index = (j*ldb + k) / 8; if (i >= M) A_cur_index = 0; - #pragma unroll +#pragma unroll for (int t = 0; t < WARP_SIZE; ++t) { const int lane_id = threadIdx.x % WARP_SIZE; - const int64_t A_i = __shfl(A_cur_index, t) + 8 * lane_id; - const int64_t B_i = __shfl(B_cur_index, t) + 8 * lane_id; + const int64_t A_i = __shfl_custom(A_cur_index, t) + 8 * lane_id; + const int64_t B_i = __shfl_custom(B_cur_index, t) + 8 * lane_id; { //uint64_t a_bit64 = *((uint64_t 
*)(A + A_i)); // weights uint64_t a_bit64 = *((uint64_t *)(A_s + A_i)); // weights uint64_t b_bit64 = *((uint64_t *)(B + B_i)); // input - c_bit64 = xnor_int64(a_bit64, b_bit64); + c_bit64 = xor_int64(a_bit64, b_bit64); int tmp_count = __popcll(c_bit64); int sum_count = warpAllReduceSum(tmp_count); @@ -1374,7 +1657,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int //#endif //#ifdef NOT_USED - // 32 thread X 32 bit = 1024 bit + // 32 thread X 32 bit = 1024 bit // 10% for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216] //int64_t A_cur_index = (i*lda + k) / 8; @@ -1382,18 +1665,18 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int int64_t B_cur_index = (j*ldb + k) / 8; if (i >= M) A_cur_index = 0; - #pragma unroll +#pragma unroll for (int t = 0; t < WARP_SIZE; ++t) { const int lane_id = threadIdx.x % WARP_SIZE; - const int64_t A_i = __shfl(A_cur_index, t) + 4 * lane_id; - const int64_t B_i = __shfl(B_cur_index, t) + 4 * lane_id; + const int64_t A_i = __shfl_custom(A_cur_index, t) + 4 * lane_id; + const int64_t B_i = __shfl_custom(B_cur_index, t) + 4 * lane_id; { //uint64_t a_bit64 = *((uint64_t *)(A + A_i)); // weights uint32_t a_bit32 = *((uint32_t *)(A_s + A_i)); // weights uint32_t b_bit32 = *((uint32_t *)(B + B_i)); // input - uint32_t c_bit32 = xnor_int32(a_bit32, b_bit32); + uint32_t c_bit32 = xor_int32(a_bit32, b_bit32); int tmp_count = __popc(c_bit32); int sum_count = warpAllReduceSum(tmp_count); @@ -1409,11 +1692,12 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int float bias_val = bias_arr[i]; //#ifdef NOT_USED + // 8% for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216] //ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights ulonglong4 a_bit256 = *((ulonglong4 *)(A_s + (local_i*lda + k) / 8)); // weights ulonglong4 b_bit256 = *((ulonglong4 *)(B + (j*ldb + k) / 8)); // input - ulonglong4 c_bit256 = xnor_int256(a_bit256, b_bit256); + ulonglong4 c_bit256 = xor_int256(a_bit256, b_bit256); count += __popcll(c_bit256.w) + __popcll(c_bit256.x) + __popcll(c_bit256.y) + __popcll(c_bit256.z); @@ -1425,7 +1709,7 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int //uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8)); // weights uint64_t a_bit64 = *((uint64_t *)(A_s + (local_i*lda + k) / 8)); // weights uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8)); // input - uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64); + uint64_t c_bit64 = xor_int64(a_bit64, b_bit64); count += __popcll(c_bit64); } @@ -1434,287 +1718,321 @@ __global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int const int bit_step = 256; int f1 = (K % bit_step == 0) ? 0 : (bit_step - (K % bit_step)); count = count - f1; // remove extra bits (from empty space for align only) - - C[i*ldc + j] = (2 * count - K) *mean_val + bias_val; + float dst_val = (2 * count - K) *mean_val + bias_val; + if (leaky_activation) + dst_val = (dst_val >= 0) ? 
(dst_val) : (0.1f*dst_val); // Leaky activation + size_t out_index = i*ldc + j; + C[out_index] = dst_val; + + if (shortcut_out_gpu) { + shortcut_out_gpu[out_index] = shortcut_in_gpu[out_index] + dst_val; + } } } } } +// -------------------------------- +// src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu +// You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE -/* -// Coalescing -// B (input) in the shared_memory - GOOD -__global__ void gemm_nn_custom_bin_mean_transposed_gpu_kernel(int M, int N, int K, -unsigned char *A, int lda, -unsigned char *B, int ldb, -float *C, int ldc, float *mean_arr, float *bias_arr) -{ -int index = blockIdx.x*blockDim.x + threadIdx.x; - -__shared__ uint8_t B_s[4096*8]; // 32 KB // [ldb x N`] // max = 262 144 bits -//__shared__ uint64_t B_s[4096]; // 32 KB // [ldb x N`] // max = 262 144 bits - -int start_j = blockIdx.x*blockDim.x / M; -int end_j = (blockIdx.x*blockDim.x + blockDim.x) / M + 1; +__global__ void im2col_gpu_kernel(const int n, const float* data_im, + const int height, const int width, const int ksize, + const int pad, + const int stride, + const int height_col, const int width_col, + float *data_col) { + int index = blockIdx.x*blockDim.x + threadIdx.x; + for (; index < n; index += blockDim.x*gridDim.x) { + int w_out = index % width_col; + int h_index = index / width_col; + int h_out = h_index % height_col; + int channel_in = h_index / height_col; + int channel_out = channel_in * ksize * ksize; + int h_in = h_out * stride - pad; + int w_in = w_out * stride - pad; + float* data_col_ptr = data_col; + data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; + const float* data_im_ptr = data_im; + data_im_ptr += (channel_in * height + h_in) * width + w_in; + for (int i = 0; i < ksize; ++i) { + for (int j = 0; j < ksize; ++j) { + int h = h_in + i; + int w = w_in + j; -size_t shared_size = ldb * (end_j - start_j); + *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? 
+ data_im_ptr[i * width + j] : 0; -int j_cur = index / M; -int local_j = j_cur - start_j; + //data_im[(channel_in * height + h_in) * width + w_in + i * width + j]; + //*data_col_ptr = data_im_ptr[ii * width + jj]; -for (int k = threadIdx.x * 256; k < shared_size; k += blockDim.x * 256) { -int x = start_j*ldb + k; -if (x < (N*ldb)) *((ulonglong4 *)(B_s + k / 8)) = *((ulonglong4 *)(B + x / 8)); + data_col_ptr += height_col * width_col; + } + } + } } -__syncthreads(); - -int i, j, k; - -i = index % M; // l.n - filters [16 - 55 - 1024] -{ -j = index / M; // out_h*out_w - one channel output size [169 - 173056] -if (j < N) -{ -int count = 0; -k = 0; - -//#ifdef NOT_USED -// 32 thread X 64 bit = 2048 bit -for (; k < (K - 2048); k += 2048) { // l.size*l.size*l.c - one filter size [27 - 9216] -uint64_t c_bit64; - -int64_t A_cur_index = (i*lda + k) / 8; -//int64_t B_cur_index = (j*ldb + k) / 8; -int64_t B_cur_index = (local_j*ldb + k) / 8; -if (i >= M) A_cur_index = 0; - -#pragma unroll -for (int t = 0; t < WARP_SIZE; ++t) { -const int lane_id = threadIdx.x % WARP_SIZE; - -const int64_t A_i = __shfl(A_cur_index, t) + 8 * lane_id; -const int64_t B_i = __shfl(B_cur_index, t) + 8 * lane_id; +void im2col_ongpu(float *im, + int channels, int height, int width, + int ksize, int stride, int pad, float *data_col) { -uint64_t a_bit64 = *((uint64_t *)(A + A_i)); // weights -//uint64_t b_bit64 = *((uint64_t *)(B + B_i)); // input -uint64_t b_bit64 = *((uint64_t *)(B_s + B_i)); // input -c_bit64 = xnor_int64(a_bit64, b_bit64); -int tmp_count = __popcll(c_bit64); + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; + int num_kernels = channels * height_col * width_col; + im2col_gpu_kernel << <(num_kernels + BLOCK - 1) / BLOCK, + BLOCK >> >( + num_kernels, im, height, width, ksize, pad, + stride, height_col, + width_col, data_col); -int sum_count = warpAllReduceSum(tmp_count); -if (lane_id == t) count += sum_count; -} -} + CHECK_CUDA(cudaPeekAtLastError()); } -//#endif +// -------------------------------- -//#ifdef NOT_USED -// 32 thread X 32 bit = 1024 bit -for (; k < (K - 1024); k += 1024) { // l.size*l.size*l.c - one filter size [27 - 9216] -int64_t A_cur_index = (i*lda + k) / 8; -//int64_t B_cur_index = (j*ldb + k) / 8; -int64_t B_cur_index = (local_j*ldb + k) / 8; -if (i >= M) A_cur_index = 0; +// Tensor Cores binary (CC >= 7.3 && CUDA >= 10.0) - __CUDA_SUBBYTE_IMMA__ +#if CUDART_VERSION >= 10000 +#include -#pragma unroll -for (int t = 0; t < WARP_SIZE; ++t) { -const int lane_id = threadIdx.x % WARP_SIZE; +#define WMMA_M 8 +#define WMMA_N 8 +#define WMMA_K 128 +#define WMMA_K32 (WMMA_K/32) -const int64_t A_i = __shfl(A_cur_index, t) + 4 * lane_id; -const int64_t B_i = __shfl(B_cur_index, t) + 4 * lane_id; +#define WMMA_Nx2 (WMMA_N*2) +// Tensor Cores are used for XOR-GEMM +__global__ void gemm_nn_custom_bin_mean_transposed_tensor_kernel(int M, int N, int K, + unsigned char *A, int lda, + unsigned char *B, int ldb, + float *C, int ldc, float *mean_arr, float *bias_arr, int leaky_activation, + float *shortcut_in_gpu, float *shortcut_out_gpu) { -uint32_t a_bit32 = *((uint32_t *)(A + A_i)); // weights -//uint32_t b_bit32 = *((uint32_t *)(B + B_i)); // input -uint32_t b_bit32 = *((uint32_t *)(B_s + B_i)); // input -uint32_t c_bit32 = xnor_int32(a_bit32, b_bit32); -int tmp_count = __popc(c_bit32); - -int sum_count = 
warpAllReduceSum(tmp_count); -if (lane_id == t) count += sum_count; -} -} -} -//#endif + // total 57% + int index = blockIdx.x*blockDim.x + threadIdx.x; -if (i < M) -{ -float mean_val = mean_arr[i]; -float bias_val = bias_arr[i]; + __shared__ int C_s[WMMA_N * WMMA_M * 32 * 2]; // 2 * 8 KB - Temprorary result of GEMM WMMA for 32 warps -//#ifdef NOT_USED -for (; k < K; k += 256) { // l.size*l.size*l.c - one filter size [27 - 144 - 9216] -ulonglong4 a_bit256 = *((ulonglong4 *)(A + (i*lda + k) / 8)); // weights -//ulonglong4 b_bit256 = *((ulonglong4 *)(B + (j*ldb + k) / 8)); // input -ulonglong4 b_bit256 = *((ulonglong4 *)(B_s + (local_j*ldb + k) / 8)); // input -ulonglong4 c_bit256 = xnor_int256(a_bit256, b_bit256); + const int lane_id = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + const int global_warp_id = index / 32; -count += __popcll(c_bit256.w) + __popcll(c_bit256.x) + -__popcll(c_bit256.y) + __popcll(c_bit256.z); -} -//#endif + const int N_aligned = N + WMMA_Nx2 - (N % WMMA_Nx2); -#ifdef NOT_USED -for (; k < K; k += 64) { // l.size*l.size*l.c - one filter size [27 - 9216] -uint64_t a_bit64 = *((uint64_t *)(A + (i*lda + k) / 8)); // weights -//uint64_t b_bit64 = *((uint64_t *)(B + (j*ldb + k) / 8)); // input -uint64_t b_bit64 = *((uint64_t *)(B_s + (local_j*ldb + k) / 8)); // input -uint64_t c_bit64 = xnor_int64(a_bit64, b_bit64); + /* + __syncthreads(); + __shared__ uint32_t A_s[8 * 512]; // 8x512 = 8 x 16384 bits, instead of 8x4 + const int start_global_warp_id = blockIdx.x*blockDim.x / 32; + int start_i = start_global_warp_id / (N_aligned / WMMA_N); + start_i = start_i * WMMA_M; + if (start_i + WMMA_M > M) start_i = M - WMMA_M; // must be: i+7 < M + for (int tmp_index = threadIdx.x; tmp_index < (8 * 512); tmp_index += blockDim.x) + { + int k_tmp = tmp_index % 512; + int local_i = tmp_index / 512; -count += __popcll(c_bit64); -} -#endif + uint32_t a_val = ((uint32_t *)(A))[(start_i + local_i)*lda/32 + k_tmp]; + A_s[local_i * 512 + k_tmp] = a_val; + } + __syncthreads(); + */ -const int bit_step = 256; -int f1 = (K % bit_step == 0) ? 
0 : (bit_step - (K % bit_step)); -count = count - f1; // remove extra bits (from empty space for align only) -C[i*ldc + j] = (2 * count - K) * mean_val + bias_val; -} -} -} -} -*/ + int i, j, k, h; + // 47% = 29 + 10 + 8 + j = global_warp_id % (N_aligned / WMMA_Nx2); + j = j * WMMA_Nx2; + { // out_h*out_w - one channel output size [169 - 173056] + i = global_warp_id / (N_aligned / WMMA_Nx2); + i = i * WMMA_M; -// Coalesced memory access - GOOD -void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K, - unsigned char *A, int lda, - unsigned char *B, int ldb, - float *C, int ldc, float *mean_arr, float *bias) -{ - size_t size = M*N; - const int num_blocks = size / BLOCK + 1; + int count = 0; + k = 0; - /* - printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n", - size, num_blocks, M*K / 1024, N*K / 1024, M*lda / num_blocks / 1024, N*ldb / num_blocks / 1024); - printf(" M / 512 = %d, N / 512 = %d, M*lda / 512 = %d, N*ldb / 512 = %d \n", M / 512, N / 512, M*lda/512, N*ldb/512); - */ - //printf(" shared_memory: (w) lda*BLOCK/N = %d, (i) ldb*BLOCK/M = %d, \t lda = %d \n\n", lda*BLOCK / N, ldb*BLOCK / M, lda); + if (i < M) //if (i < M) // l.n - filters [16 - 55 - 1024] + { + if (j + WMMA_Nx2 > N) j = N - WMMA_Nx2; // must be: j+7 < N + if (i + WMMA_M > M) i = M - WMMA_M; // must be: i+7 < M - gemm_nn_custom_bin_mean_transposed_gpu_kernel << > >( - M, N, K, - A, lda, - B, ldb, - C, ldc, - mean_arr, bias); -} -// -------------------------------- -// -------------------------------- +#if __CUDA_ARCH__ >= 730 + // Tensor Cores + using namespace nvcuda; -// -------------------------------- -// sequentially - B (input) in the shared_memory - BAD -// -------------------------------- -__global__ void gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel(int M, int N, int K, - unsigned char *A, int lda, - unsigned char *B, int ldb, - float *C, int ldc, float *mean_arr) -{ - //__shared__ float mean_shared[32]; - //__shared__ uint32_t B_s[8192]; // 32 KB // [ldb x N`] // max = 262 144 bits - //__shared__ uint32_t B_s[4096]; // 16 KB // [ldb x N`] // max = 131 072 bits - __shared__ uint8_t B_s[4096 * 4]; // 16 KB // [ldb x N`] // max = 131 072 bits + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c1_frag, c2_frag; + wmma::fill_fragment(c1_frag, 0); // !!!! XOR isn't XNOR !!!!!!!!!! + wmma::fill_fragment(c2_frag, 0); // !!!! XOR isn't XNOR !!!!!!!!!! 
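+            // note: using XOR here appears to be intentional - binary_align_weights() pre-inverts the
+            // packed weight bits on the host (align_bit_weights[i] = ~align_bit_weights[i] when gpu_index >= 0),
+            // and xor(~a, b) == xnor(a, b), so the popcount accumulated by the XOR-GEMM still equals the
+            // XNOR popcount and (2 * count - K) further down remains the +/-1 dot product.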
+            // 8 x 8 x 4 (uint32_t, 4 * 32 = 128 bit)
+            for (; k < K; k += 128)   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
+            {
+                int64_t A_cur_index = (i*lda + k) / 8;         // index in bits
+                int64_t B1_cur_index = (j*ldb + k) / 8;        // index in bits
+                int64_t B2_cur_index = ((j + 8)*ldb + k) / 8;  // index in bits

-    const int K_items = WARP_SIZE;
-    int start_j = blockIdx.x*blockDim.x / (K_items * M);
+                // try to use A that is cached in shared memory - poor performance
+                //if (i == start_i) wmma::load_matrix_sync(a_frag, &A_s[k / 32], (512 * 32));   // lda = (128*32) bits
+                //else wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);   // lda = M

-    {
-        int end_j = (blockIdx.x*blockDim.x + blockDim.x) / (K_items * M) + 1;
-        if (end_j > N) end_j = N;
-        size_t shared_size = ldb * (end_j - start_j);
+                // lda, ldb - are in bits
+                wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);   // lda = M

-        if (shared_size != 0) {
-            //if(threadIdx.x == 0) printf(" start_j = %d, end_j = %d, shared_size = %d \n", start_j, end_j, shared_size);
+                wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B1_cur_index), ldb);   // ldb = K
+                wmma::bmma_sync(c1_frag, a_frag, b_frag, c1_frag);   // XOR-GEMM

-            int k;
-            for (int k = threadIdx.x * 32; k < shared_size; k += blockDim.x * 32) {
-                int x = start_j*ldb + k;
-                if (x < (N*ldb)) *((uint32_t *)(B_s + k / 8)) = *((uint32_t *)(B + x / 8));
+                wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B2_cur_index), ldb);   // ldb = K
+                wmma::bmma_sync(c2_frag, a_frag, b_frag, c2_frag);   // XOR-GEMM
+            }
+            // C[i*ldc + j]
+            wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N], c1_frag, WMMA_N, wmma::mem_row_major);
+            wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N + WMMA_M*WMMA_N * 32], c2_frag, WMMA_N, wmma::mem_row_major);
+#else // __CUDA_ARCH__ >= 730
+
+            // Custom XOR-GEMM
+            int k_d = lane_id % 4;
+            int i_d = lane_id / 4;
+            int j_d = lane_id / 4;
+
+            int32_t accum_c_val[8 * 2];   // wmma::fill_fragment(c_frag, 0);
+            for (int local_j = 0; local_j < 8 * 2; ++local_j) {
+                accum_c_val[local_j] = 0;
            }
-        }
-    }
-    __syncthreads();
-    int index = blockIdx.x*blockDim.x + threadIdx.x;
+            // 8 x 8 x 4 (uint32_t, 4 * 32 = 128 bit)
+            for (; k < K; k += 128)   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
+            {
+                int64_t A_cur_index = (i*lda + k) / 8;
+                //int64_t A_cur_index = (local_i*lda + k) / 8;
+                int64_t B_cur_index = (j*ldb + k) / 8;

-    {
-        int i;   // l.n
-        int j;   // out_h*out_w
-        int k;   // l.size * l.size * l.c
+                // lda, ldb - are in bits
+                // 8*4 = 32
+                // 8*8 = 64
+                int k_d = lane_id % 4;
+                int i_d = lane_id / 4;
+                int j_d = lane_id / 4;
+                uint32_t a_val = *(uint32_t *)(A + ((i + i_d)*lda + (k + k_d * 32)) / 8);   // wmma::load_matrix_sync(a_frag, (uint32_t *)(A + A_cur_index), lda);

-        const int index2 = index / K_items;
-        i = index2 % M;   // max M
-        j = index2 / M;   // max N
-        int local_j = j - start_j;
+                for (int c_x = 0; c_x < 2; c_x++)
+                {
+                    uint32_t b_val = *(uint32_t *)(B + ((c_x * 8 + j + j_d)*ldb + (k + k_d * 32)) / 8);   // wmma::load_matrix_sync(b_frag, (uint32_t *)(B + B_cur_index), ldb);

-        //if (i <= 1 && j <= 1 ) printf(" k = %d, K = %d, K_items = %d, i = %d, j = %d, lda = %d, ldb = %d, ldc = %d \n",
-        //    k, K, K_items, i, j, lda, ldb, ldc);
-        {   // l.n - filters [16 - 55 - 1024]
-            // further improvements: for (l.n == 1024) iterate several (j)
+                    // wmma::bmma_sync(c_frag, a_frag, b_frag, c_frag);
+                    int32_t c_val[8];   // 8 x 32 threads = 256
+#pragma UNROLL
+                    for (int local_j = 0; local_j < 8; ++local_j)
+                    {
+                        uint32_t b_val_cur = __shfl_custom(b_val, local_j * 4 + k_d);
+                        c_val[local_j] = __popc(xor_int32(a_val, b_val_cur));
+                    }
+#pragma UNROLL
+                    for (int local_j = 0; local_j < 8; ++local_j)
+                    {
+#pragma UNROLL
+                        for (int local_k = 0; local_k < 4; ++local_k) {
+                            accum_c_val[local_j + c_x * 8] += __shfl_custom(c_val[local_j], i_d * 4 + local_k);
+                        }
+                    }
+                }
+            }

-        if (j < N)
-        {   // out_h*out_w - one channel output size [169 - 173056]
+            // only the first 8 threads (i) contain 8 good values each, in c_val[8] (j) = 8 x 8 =64
+            // wmma::store_matrix_sync(&C_s[warp_id*WMMA_M*WMMA_N], c_frag, WMMA_N, wmma::mem_row_major);
+            if (k_d == 0) {
+                for (int c_x = 0; c_x < 2; c_x++)
+                {
+                    for (int local_j = 0; local_j < 8; ++local_j)
+                    {
+                        C_s[warp_id*WMMA_M*WMMA_N + i_d*WMMA_N + local_j + WMMA_M*WMMA_N * 32 * c_x] = accum_c_val[local_j + c_x * 8];
+                    }
+                }
+            }
+#endif // __CUDA_ARCH__ >= 730

-            int count = 0;
+            for (int c_x = 0; c_x < 2; c_x++)
+            {
+                int j_d = lane_id % WMMA_N;
+                {
+#pragma UNROLL
+                    for (int i_d = lane_id / WMMA_N; i_d < WMMA_M; i_d += WMMA_M / 2)
+                    {
+                        int count = C_s[warp_id*WMMA_M*WMMA_N + i_d*WMMA_N + j_d + WMMA_M*WMMA_N * 32 * c_x];
+                        const int bit_step = 128;
+                        int f1 = (K % bit_step == 0) ? 0 : (bit_step - (K % bit_step));
+                        count = count - f1;    // remove extra bits (from empty space for align only)

-            const int bit_step = 32;
-            for (k = (threadIdx.x % WARP_SIZE) * bit_step; k < K; k += bit_step*WARP_SIZE)
-            {   // l.size*l.size*l.c - one filter size [27 - 144 - 9216]
-                uint32_t a_bit32 = *((uint32_t *)(A + (i*lda + k) / 8));    // weights
-                //uint32_t b_bit32 = *((uint32_t *)(B + (j*ldb + k) / 8));    // input
-                uint32_t b_bit32 = *((uint32_t *)(B_s + (local_j*ldb + k) / 8));    // input
-                uint32_t c_bit32 = xnor_int32(a_bit32, b_bit32);
+                        count = (2 * count - K);

-                count += __popc(c_bit32);
-            }
+                        float mean_val = mean_arr[i + i_d];
+                        float bias_val = bias_arr[i + i_d];
+                        float dst_val = count *mean_val + bias_val;
+                        if (leaky_activation)
+                            dst_val = (dst_val >= 0) ? (dst_val) : (0.1f*dst_val);    // Leaky activation

-            for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-                count += __shfl_down(count, offset);
+                        size_t out_index = (i + i_d)*ldc + (c_x * 8 + j + j_d);
+                        C[out_index] = dst_val;
+                        if (shortcut_out_gpu) {
+                            shortcut_out_gpu[out_index] = shortcut_in_gpu[out_index] + dst_val;
+                        }
+                    }

-            if (threadIdx.x % WARP_SIZE == 0) {
-                int f1 = (K % bit_step == 0) ? 0 : (bit_step - (K % bit_step));
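Both the removed kernel (`__shfl_down`) and the new fallback path (`__shfl_custom`) use warp shuffles to combine per-lane partial popcounts without going through shared memory; `__shfl_custom` is presumably a thin wrapper that picks the right intrinsic for the CUDA version in use. The classic warp-sum idiom with the modern synchronizing intrinsic looks like this (illustrative helper, not the repository's):

```c
// Sums `val` across the 32 lanes of a warp; lane 0 ends up holding the warp total.
__device__ static inline int warp_reduce_sum(int val)
{
    for (int offset = 32 / 2; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xffffffff, val, offset);   // fetch from lane (lane_id + offset)
    return val;
}
```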
-                count = count - f1;
-                float mean_val = mean_arr[i];
-                C[i*ldc + j] = (2 * count - K) * mean_val;
-                //B_s[threadIdx.x / WARP_SIZE] = (2 * count - K) * mean_val;
            }
        }
    }
}
}
+#endif // CUDART_VERSION >= 10000
+// --------------------------------

-// sequentially - BAD
-void gemm_nn_custom_bin_mean_transposed_sequentially_gpu(int M, int N, int K,
+
+
+// GOOD
+void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
    unsigned char *A, int lda,
    unsigned char *B, int ldb,
-    float *C, int ldc, float *mean_arr)
+    float *C, int ldc, float *mean_arr, float *bias, int leaky_activation,
+    float *shortcut_in_gpu, float *shortcut_out_gpu)
{
-    //size_t size = M*N;
-    size_t size = M*N * 32;
-
-    const int num_blocks = size / BLOCK + 1;
+    int size = M*N;
+    const int num_blocks = get_number_of_blocks(size, BLOCK);
+    //printf("\n M = %d, N = %d, M %% 8 = %d, N %% 8 = %d \n", M, N, M % 8, N % 8);

-    //printf(" K = %d \n", K);
-
-    /*
-    printf("\n gemm_bin size = %d, num_blocks = %d, M*K = %d KB, N*K = %d KB \n (w) M*K/num_blocks = %d KB, (i) N*K/num_blocks = %d KB \n",
-        size, num_blocks, M*K / 1024, N*K / 1024, M*lda / num_blocks / 1024, N*ldb / num_blocks / 1024);
-    printf(" M / 512 = %d, N / 512 = %d, M*lda / 512 = %d, N*ldb / 512 = %d \n", M / 512, N / 512, M*lda/512, N*ldb/512);
-    */
-    //printf(" shared_memory: (w) lda*BLOCK/N = %d, (i) ldb*BLOCK/M = %d, \t lda = %d \n\n", lda*BLOCK / N, ldb*BLOCK / M, lda);
-
-    gemm_nn_custom_bin_mean_transposed_sequentially_gpu_kernel << <num_blocks, BLOCK >> >(
-        M, N, K,
-        A, lda,
-        B, ldb,
-        C, ldc,
-        mean_arr);
+    //if (M >= 32)    // l.n >= 32
+#if CUDART_VERSION >= 10000
+    if (1)
+    {
+        const int M_aligned = M + (8 - (M % 8));
+        const int N_aligned = N + (16 - (N % 16));
+        int size = (M_aligned / 8)*(N_aligned / 16)*WARP_SIZE;
+        const int num_blocks = get_number_of_blocks(size, BLOCK);
+
+        //printf(" lda = %d, ldb = %d, ldc = %d, lda/32 = %d, ldb/32 = %d, ldc/32 = %d \n", lda, ldb, ldc, lda / 32, ldb / 32, ldc / 32);
+        //printf(" l.c (K/9) = %d, M (l.n) = %d \n", (K%9 == 0)? K / 9: K, M);
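`get_number_of_blocks()` is declared elsewhere in the repository; from its use here it is simply a ceiling division of the number of work items by the block size. A sketch of a helper consistent with that use (the real definition may differ in form), together with the assumption behind the tensor-path grid sizing above: every warp writes an 8 x 16 tile of C, so the grid covers `(M_aligned / 8) * (N_aligned / 16)` warps of `WARP_SIZE` threads.

```c
// Likely shape of the helper used above: how many blocks of `block_size`
// threads are needed to cover `size` work items (ceiling division).
static inline int get_number_of_blocks(int size, int block_size)
{
    return size / block_size + ((size % block_size > 0) ? 1 : 0);
}
```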
+        gemm_nn_custom_bin_mean_transposed_tensor_kernel << <num_blocks, BLOCK >> > (
+            M, N, K,
+            A, lda,
+            B, ldb,
+            C, ldc,
+            mean_arr, bias, leaky_activation,
+            shortcut_in_gpu, shortcut_out_gpu);
+    }
+    else
+#endif  //# CUDART_VERSION >= 10000
+    {
+        gemm_nn_custom_bin_mean_transposed_gpu_kernel << <num_blocks, BLOCK >> > (
+            M, N, K,
+            A, lda,
+            B, ldb,
+            C, ldc,
+            mean_arr, bias, leaky_activation,
+            shortcut_in_gpu, shortcut_out_gpu);
+    }
+    CHECK_CUDA(cudaPeekAtLastError());
 }

-// --------------------------------
+// --------------------------------
\ No newline at end of file
diff --git a/src/gpu.h b/src/gpu.h
index 424ef32..ee7cf3b 100644
--- a/src/gpu.h
+++ b/src/gpu.h
@@ -2,6 +2,27 @@
 #ifndef GPU_H
 #define GPU_H

+#ifndef __DATE__
+#define __DATE__
+#endif
+
+#ifndef __TIME__
+#define __TIME__
+#endif
+
+#ifndef __FUNCTION__
+#define __FUNCTION__
+#endif
+
+#ifndef __LINE__
+#define __LINE__ 0
+#endif
+
+#ifndef __FILE__
+#define __FILE__
+#endif
+
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -9,6 +30,10 @@ extern "C" {
 #ifdef GPU
+    void check_error(cudaError_t status);
+    void check_error_extended(cudaError_t status, const char *file, int line, const char *date_time);
+#define CHECK_CUDA(X) check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
+
     struct layer;
     typedef struct layer layer;
     typedef struct layer local_layer;
@@ -27,6 +52,9 @@ extern "C" {
 #ifdef CUDNN
     cudnnHandle_t cudnn_handle();
+
+    void cudnn_check_error_extended(cudnnStatus_t status, const char *file, int line, const char *date_time);
+#define CHECK_CUDNN(X) cudnn_check_error_extended(X, __FILE__ " : " __FUNCTION__, __LINE__, __DATE__ " - " __TIME__ );
 #endif

     float *cuda_make_array(float *x, size_t n);
@@ -71,11 +99,11 @@ extern "C" {
     // reorg layer
     void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
-    
+
     // upsample layer
     void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out);

-    void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);
+    void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out);

     // -------------------- Quantinization -------------------
     void cuda_convert_f32_to_int8(float* input_f32, size_t size, int8_t *output_int8, float multipler, int max_val);
@@ -86,6 +114,43 @@ extern "C" {
     void cuda_do_multiply_f32(float *input_output, size_t size, float multipler);

+    // -------------------- XNOR -------------------
+
+    void swap_binary(convolutional_layer *l);
+
+    void binarize_weights_gpu(float *weights, int n, int size, float *binary);
+
+    void binarize_gpu(float *x, int n, float *binary);
+
+    void repack_input_gpu_bin(float *input, uint32_t *re_packed_input_bin, int w, int h, int c);
+
+    void transpose_uint32_gpu(uint32_t *src, uint32_t *dst, int src_h, int src_w, int src_align, int dst_align);
+
+    void im2col_ongpu(float *im,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float *data_col);
+
+    void im2col_align_ongpu(float *im,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float *data_col, int bit_align);
+
+    void im2col_align_bin_ongpu(float *im,
+        int channels, int height, int width,
+        int ksize, int stride, int pad, float *data_col, int bit_align);
+
+    void float_to_bit_gpu(float *src, unsigned char *dst, size_t size);
+
+    void transpose_bin_gpu(unsigned char *A, unsigned char *B, const int n, const int m,
+        const int lda, const int ldb, const int block_size);
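The new `CHECK_CUDA()` / `CHECK_CUDNN()` macros route every runtime call through `check_error_extended()`, which this header only declares; its definition lives in the CUDA sources. One plausible minimal implementation consistent with the declaration, shown purely for illustration:

```c
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// Abort with a readable message when a CUDA call fails; file/line/date_time come from the macro.
void check_error_extended(cudaError_t status, const char *file, int line, const char *date_time)
{
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error %d: %s\n  location: %s, line %d, build: %s\n",
                (int)status, cudaGetErrorString(status), file, line, date_time);
        exit(EXIT_FAILURE);
    }
}
```

It is used as `CHECK_CUDA(cudaPeekAtLastError());` right after a kernel launch, as in the GEMM wrapper above.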
+
+    void fill_int8_gpu(unsigned char *src, unsigned char val, size_t size);
+
+    void gemm_nn_custom_bin_mean_transposed_gpu(int M, int N, int K,
+        unsigned char *A, int lda,
+        unsigned char *B, int ldb,
+        float *C, int ldc, float *mean_arr, float *bias, int leaky_activation,
+        float *shortcut_in_gpu, float *shortcut_out_gpu);
+
 #endif // GPU

 #ifdef __cplusplus
diff --git a/src/main.c b/src/main.c
index d67016e..3980ea0 100644
--- a/src/main.c
+++ b/src/main.c
@@ -90,9 +90,9 @@ void draw_detections_v3(image im, detection *dets, int num, float thresh, char *
         printf("%s: %.0f%%", names[best_class], selected_detections[i].det.prob[best_class] * 100);
         if (ext_output)
             printf("\t(left_x: %4.0f top_y: %4.0f width: %4.0f height: %4.0f)\n",
-                (selected_detections[i].det.bbox.x - selected_detections[i].det.bbox.w / 2)*im.w,
-                (selected_detections[i].det.bbox.y - selected_detections[i].det.bbox.h / 2)*im.h,
-                selected_detections[i].det.bbox.w*im.w, selected_detections[i].det.bbox.h*im.h);
+                round((selected_detections[i].det.bbox.x - selected_detections[i].det.bbox.w / 2)*im.w),
+                round((selected_detections[i].det.bbox.y - selected_detections[i].det.bbox.h / 2)*im.h),
+                round(selected_detections[i].det.bbox.w*im.w), round(selected_detections[i].det.bbox.h*im.h));
         else
             printf("\n");
         int j;
@@ -198,7 +198,7 @@ void test_detector_cpu(char **names, char *cfgfile, char *weightfile, char *file
         //network_predict(net, X);
#ifdef GPU
         if (quantized) {
-            network_predict_gpu_cudnn_quantized(net, X);    // quantized works only with Yolo v2
+            network_predict_gpu_cudnn_quantized(net, X);    // quantized
             //nms = 0.2;
         }
         else {
@@ -209,7 +209,7 @@ void test_detector_cpu(char **names, char *cfgfile, char *weightfile, char *file
         network_predict_opencl(net, X);
#else
         if (quantized) {
-            network_predict_quantized(net, X);    // quantized works only with Yolo v2
+            network_predict_quantized(net, X);    // quantized
             nms = 0.2;
         }
         else {
@@ -393,7 +393,7 @@ static void *detect_in_thread(void *ptr)
     //float *prediction = network_predict(net, X);
#ifdef GPU
     if (demo_quantized) {
-        network_predict_gpu_cudnn_quantized(net, X);    // quantized works only with Yolo v2
+        network_predict_gpu_cudnn_quantized(net, X);    // quantized
         //nms = 0.2;
     }
     else {
@@ -404,7 +404,7 @@
     network_predict_opencl(net, X);
#else
     if (demo_quantized) {
-        network_predict_quantized(net, X);    // quantized works only with Yolo v2
+        network_predict_quantized(net, X);    // quantized
         nms = 0.2;
     }
     else {
diff --git a/src/yolov2_forward_network.c b/src/yolov2_forward_network.c
index 3414c41..39b2887 100644
--- a/src/yolov2_forward_network.c
+++ b/src/yolov2_forward_network.c
@@ -113,23 +113,93 @@ void forward_convolutional_layer_cpu(layer l, network_state state)
     //im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);    // AVX2

     // XNOR-net - bit-1: weights, input, calculation
-    if (l.xnor && (l.stride == 1 && l.pad == 1)) {
+    if (l.xnor && l.align_bit_weights && (l.stride == 1 && l.pad == 1))
+    {
         memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));
+
+        if (l.c % 32 == 0)
+        {
+            //printf(" l.index = %d - new XNOR \n", l.index);
+
+            int ldb_align = l.lda_align;
+            size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+            size_t t_intput_size = new_ldb * l.bit_align;// n;
+            size_t t_bit_input_size = t_intput_size / 8;// +1;
+
+            const int new_c = l.c / 32;
+
+            float *re_packed_input = calloc(l.c * l.w * l.h, sizeof(float));
+            uint32_t *bin_re_packed_input = calloc(new_c * l.w * l.h + 1, sizeof(uint32_t));
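The two buffers allocated just above feed the new CPU fast path: `repack_input()` groups the activations so that 32 consecutive channels of one spatial position sit next to each other, and `float_to_bit()` then collapses every 32 floats into one `uint32_t` of sign bits. The repository's `float_to_bit()` is vectorized; the underlying idea is simply this (reference sketch, not the repo's implementation):

```c
#include <stddef.h>
#include <string.h>

// Pack the signs of `size` floats into a bit array: bit i is 1 when src[i] > 0.
static void float_to_bit_ref(const float *src, unsigned char *dst, size_t size)
{
    memset(dst, 0, (size + 7) / 8);
    for (size_t i = 0; i < size; ++i)
        if (src[i] > 0) dst[i / 8] |= (unsigned char)(1u << (i % 8));
}
```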
+
+            // float32x4 by channel (as in cuDNN)
+            repack_input(state.input, re_packed_input, l.w, l.h, l.c);
+
+            // 32 x floats -> 1 x uint32_t
+            float_to_bit(re_packed_input, (char *)bin_re_packed_input, l.c * l.w * l.h);
+
+            free(re_packed_input);
+
+            // slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN)
+            //convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output,
+            //    l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr);
+
+            // // then exit from if()
+
+
+            im2col_cpu_custom((float *)bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, b);
+            //im2col_cpu((float *)bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, b);
+
+            free(bin_re_packed_input);
+
+            int new_k = l.size*l.size*l.c / 32;
+
+            // good for (l.c == 64)
+            //gemm_nn_bin_32bit_packed(m, n, new_k, 1,
+            //    l.align_bit_weights, l.new_lda/32,
+            //    b, n,
+            //    c, n, l.mean_arr);
+
+            // // then exit from if()
+
+
+            //size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+            //size_t t_intput_size = new_ldb * l.bit_align;// n;
+            //size_t t_bit_input_size = t_intput_size / 8;// +1;
+
+            char *t_bit_input = calloc(t_bit_input_size, sizeof(char));
+
+            transpose_uint32((uint32_t *)b, t_bit_input, new_k, n, n, new_ldb);
+
+            // the main GEMM function
+            gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
+
+            // // alternative GEMM
+            //gemm_nn_bin_transposed_32bit_packed(m, n, new_k, 1,
+            //    l.align_bit_weights, l.new_lda/32,
+            //    t_bit_input, new_ldb / 32,
+            //    c, n, l.mean_arr);
+
+            free(t_bit_input);
+
+        }
+        else {  // else (l.c % 32 != 0)
+            //im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
-        im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
+            im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);

-        int ldb_align = l.lda_align;
-        size_t new_ldb = k + (ldb_align - k%ldb_align);
-        char *t_bit_input = NULL;
-        size_t t_intput_size = binary_transpose_align_input(k, n, b, &t_bit_input, ldb_align, l.bit_align);
+            int ldb_align = l.lda_align;
+            size_t new_ldb = k + (ldb_align - k%ldb_align);
+            char *t_bit_input = NULL;
+            size_t t_intput_size = binary_transpose_align_input(k, n, b, &t_bit_input, ldb_align, l.bit_align);

-        // 5x times faster than gemm()-float32
-        gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
+            // 5x times faster than gemm()-float32
+            gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);

-        //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
+            //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);

-        //free(t_input);
-        free(t_bit_input);
+            //free(t_input);
+            free(t_bit_input);
+        }
     }
     else {
         im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);    // AVX2
@@ -170,12 +240,15 @@ void forward_convolutional_layer_cpu(layer l, network_state state)
     // 3. Add BIAS
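`transpose_uint32()` in the block above reorders the bit-packed im2col output so that the binary GEMM can read its B operand with a row stride of `new_ldb` bits; in essence it is a plain 32-bit-word transpose with independent source and destination strides. A reference version consistent with how it is called here (the repository's own routine may be blocked or vectorized):

```c
#include <stdint.h>

// Transpose a src_h x src_w matrix of 32-bit words.
// Source rows are src_align words apart; destination rows are dst_align bits (dst_align/32 words) apart.
static void transpose_uint32_ref(const uint32_t *src, uint32_t *dst,
                                 int src_h, int src_w, int src_align, int dst_align)
{
    for (int i = 0; i < src_h; ++i)
        for (int j = 0; j < src_w; ++j)
            dst[j * (dst_align / 32) + i] = src[i * src_align + j];
}
```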
     //if (l.batch_normalize)
-    for (int b=0; b= 256 && l.size > 1)
+    if (l.align_bit_weights_gpu && l.c >= 32)
     {
-        cudaError_t status = cudaSuccess;
-        int input_size = l.c*l.h*l.w*l.batch;
-
-        int m = l.n;
-        int k = l.size*l.size*l.c;
-        int n = l.out_w*l.out_h;
-        //float * a = l.binary_weights_gpu;
-
-        int ldb_align = l.lda_align;
-        size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
-        size_t t_intput_size = new_ldb * n;
-        size_t t_bit_input_size = t_intput_size / 8;// +1;
-
-        {
-            int i = 0;
-            if (l.stride == 1 && l.c >= 256 && l.w >= 13 && l.size > 1 && 0)    // disable
+        //return;
+        cudaError_t status = cudaSuccess;
+        int input_size = l.c*l.h*l.w*l.batch;
+
+        int m = l.n;
+        int k = l.size*l.size*l.c;
+        int n = l.out_w*l.out_h;
+        //float * a = l.weights_gpu;
+
+        int ldb_align = l.lda_align;
+        size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+        size_t t_intput_size = new_ldb * n;
+        size_t t_bit_input_size = t_intput_size / 8;// +1;
+
+        if (l.c % 32 == 0)
         {
-            // stride=1 only
-            im2col_align_bin_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);
-            //cudaDeviceSynchronize();
+            //printf("\n\n l.index = %d, l.w = %d, l.c = %d, l.n = %d, l.stride = %d, l.pad = %d - new XNOR \n", l.index, l.w, l.c, l.n, l.stride, l.pad);
+            //printf("l.align_workspace_size = %d, (l.c * l.w * l.h) = %d \n", l.align_workspace_size, (l.c * l.w * l.h));
+
+            int ldb_align = l.lda_align;
+            size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+            size_t t_intput_size = new_ldb * l.bit_align;// n;
+            size_t t_bit_input_size = t_intput_size / 8;// +1;
+
+            const int new_c = l.c / 32;
+
+            repack_input_gpu_bin(state.input, (uint32_t *)l.align_workspace_gpu, l.w, l.h, l.c);
+
+            im2col_ongpu(l.align_workspace_gpu, new_c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
+
+            int new_k = l.size*l.size*l.c / 32;
+
+            transpose_uint32_gpu((uint32_t *)state.workspace, (uint32_t *)l.transposed_align_workspace_gpu, new_k, n, n, new_ldb);
+
+            gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
+                (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu,
+                new_ldb, l.output_gpu, n, l.mean_arr_gpu, l.biases_gpu, l.activation == LEAKY,
+                l.bin_conv_shortcut_in_gpu, l.bin_conv_shortcut_out_gpu);
+        }
+        else {
+            //printf("\n\n l.index = %d, l.w = %d, l.c = %d, l.n = %d, l.stride = %d, l.pad = %d - old XNOR \n", l.index, l.w, l.c, l.n, l.stride, l.pad);
+
+            int i = 0;
+            im2col_align_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.align_workspace_gpu, l.bit_align);
-            //cudaDeviceSynchronize();
-
-            // should be optimized
             float_to_bit_gpu(l.align_workspace_gpu, (unsigned char *)state.workspace, l.align_workspace_size);
-            //cudaDeviceSynchronize();
-        }
-        transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);
-        //cudaDeviceSynchronize();
+            transpose_bin_gpu((unsigned char *)state.workspace, (unsigned char *)l.transposed_align_workspace_gpu, k, n, l.bit_align, new_ldb, 8);

-        // should be optimized
-        gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
-            (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu, l.biases_gpu);
+            gemm_nn_custom_bin_mean_transposed_gpu(m, n, k,
+                (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu,
+                new_ldb, l.output_gpu, n, l.mean_arr_gpu, l.biases_gpu, l.activation == LEAKY,
+                l.bin_conv_shortcut_in_gpu, l.bin_conv_shortcut_out_gpu);
-            //gemm_nn_custom_bin_mean_transposed_sequentially_gpu(m, n, k,
-            //    (unsigned char *)l.align_bit_weights_gpu, new_ldb, (unsigned char *)l.transposed_align_workspace_gpu, new_ldb, l.output_gpu, n, l.mean_arr_gpu);
+        }

-        //cudaDeviceSynchronize();
-        //check_error(status);
+        //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
+        if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+        return;
     }

-    //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
-    if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
-    //cudaDeviceSynchronize();
-    return;
-    }
-
     if (!l.align_bit_weights_gpu) {
         binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
     }
@@ -353,8 +366,10 @@ void forward_upsample_layer_cuda(const layer l, network_state state)
 // shortcut_layer.c
 void forward_shortcut_layer_cuda(const layer l, network_state state)
 {
-    copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
-    shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
+    //copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
+    //shortcut_gpu(l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
+    //activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+    input_shortcut_gpu(state.input, l.batch, l.w, l.h, l.c, state.net.layers[l.index].output_gpu, l.out_w, l.out_h, l.out_c, l.output_gpu);
     activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
 }
@@ -465,6 +480,9 @@ void forward_network_gpu_cudnn(network net, network_state state)
             forward_region_layer_gpu_cuda(l, state);
             //printf("\n REGION \n");
         }
+        else if (l.type == BLANK) {
+            //printf("\n layer: BLANK - %d \n", i);
+        }
         else {
             printf("\n layer: %d \n", l.type);
         }
@@ -536,7 +554,8 @@ float *network_predict_gpu_cudnn(network net, float *input)
     state.net = net;
     //status = cudaMalloc((void **)&(state.input), sizeof(float)*size);
     state.input = net.input_state_gpu;
-    status = cudaMemcpy(state.input, input, sizeof(float)*size, cudaMemcpyHostToDevice);
+    memcpy(net.input_pinned_cpu, input, size * sizeof(float));
+    status = cudaMemcpy(state.input, net.input_pinned_cpu, sizeof(float)*size, cudaMemcpyHostToDevice);
     state.truth = 0;
     state.train = 0;
     state.delta = 0;
@@ -563,7 +582,8 @@ float *network_predict_gpu_cudnn_quantized(network net, float *input)
     state.index = 0;
     state.net = net;
     status = cudaMalloc((void **)&(state.input), sizeof(float)*size);
-    status = cudaMemcpy(state.input, input, sizeof(float)*size, cudaMemcpyHostToDevice);
+    memcpy(net.input_pinned_cpu, input, size * sizeof(float));
+    status = cudaMemcpy(state.input, net.input_pinned_cpu, sizeof(float)*size, cudaMemcpyHostToDevice);
     state.truth = 0;
     state.train = 0;
     state.delta = 0;
diff --git a/yolo_gpu.vcxproj b/yolo_gpu.vcxproj
index fd75fe8..4a7afa4 100644
--- a/yolo_gpu.vcxproj
+++ b/yolo_gpu.vcxproj
@@ -123,7 +123,7 @@
       3rdparty\lib\x64\pthreadVC2.lib;cublas.lib;curand.lib;cudart.lib;%(AdditionalDependencies)
-      compute_35,sm_35
+      compute_35,sm_35;compute_75,sm_75
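`forward_shortcut_layer_cuda()` now calls the new `input_shortcut_gpu()` instead of a `copy_ongpu()` followed by `shortcut_gpu()`, fusing the copy and the residual add into a single pass over the output. For the common case where both inputs have identical dimensions this reduces to one elementwise kernel; a sketch under that assumption (the real routine also has to handle mismatched shapes by striding/sampling):

```c
// out[i] = in[i] + add[i]: replaces a separate copy kernel plus a shortcut-add kernel.
__global__ void input_shortcut_equal_dims_kernel(const float *in, const float *add,
                                                 float *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] + add[i];
}
```

The same fusion idea appears in the binary GEMM epilogue above, where `shortcut_out_gpu` is written directly from the convolution kernel so that a following shortcut layer costs no extra pass over memory.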