From 1f20eb4e8c87b409992aabee303c5bfcafd80f67 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 12 Aug 2018 20:30:12 +0800 Subject: [PATCH] pack weight and more unroll makes improvement, ~20% faster for conv3x3s2 --- src/layer/arm/convolution_3x3.h | 967 ++++++++++++++++++++++++------ src/layer/arm/convolution_arm.cpp | 10 + src/layer/arm/convolution_arm.h | 1 + 3 files changed, 782 insertions(+), 196 deletions(-) diff --git a/src/layer/arm/convolution_3x3.h b/src/layer/arm/convolution_3x3.h index a0582dc6c92..c8ba516553e 100644 --- a/src/layer/arm/convolution_3x3.h +++ b/src/layer/arm/convolution_3x3.h @@ -11518,6 +11518,70 @@ static void conv3x3s1_winograd64_neon5(const Mat& bottom_blob, Mat& top_blob, co copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt.blob_allocator, opt.num_threads); } +static void conv3x3s2_transform_kernel_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch) +{ + kernel_tm.create(8*9, inch, outch/8 + outch%8); + + const float* kernel = _kernel; + + int p=0; + for (; p+7> 1; - int remain_outch_start = nn_outch << 1; + int nn_outch = outch >> 3; + int remain_outch_start = nn_outch << 3; #pragma omp parallel for num_threads(opt.num_threads) for (int pp=0; pp 0) { asm volatile( - "prfm pldl1keep, [%3, #256] \n" - "ld2 {v8.4s, v9.4s}, [%3], #32 \n"// v8 v9 = r0 - "0: \n" "prfm pldl1keep, [%1, #128] \n" - "ld1 {v6.4s}, [%1] \n"// v6 = _sum0 + "ld1 {v8.4s}, [%1] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v9.4s}, [%2] \n" - "fmul v12.4s, v8.4s, %12.s[0] \n" + "prfm pldl1keep, [%3, #128] \n" + "ld1 {v10.4s}, [%3] \n" + "prfm pldl1keep, [%4, #128] \n" + "ld1 {v11.4s}, [%4] \n" - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v7.4s}, [%2] \n"// v7 = _sum1 + /// + "prfm pldl1keep, [%9, #256] \n" + "ld2 {v4.4s, v5.4s}, [%9], #32 \n"// v4=00 v5=01 - "fmul v13.4s, v8.4s, %15.s[0] \n" + "ld1 {v0.4s, v1.4s}, [%12], #32 \n" - "prfm pldl1keep, [%3, #128] \n" - "ld2 {v10.4s, v11.4s}, [%3] \n"// v10 + "fmla v8.4s, v4.4s, v0.s[0] \n" + "fmla v9.4s, v4.4s, v0.s[1] \n" - "fmla v6.4s, v9.4s, %12.s[1] \n" + "prfm pldl1keep, [%5, #128] \n" + "ld1 {v12.4s}, [%5] \n" + "prfm pldl1keep, [%6, #128] \n" + "ld1 {v13.4s}, [%6] \n" - "ext v14.16b, v8.16b, v10.16b, #4\n" + "fmla v10.4s, v4.4s, v0.s[2] \n" + "fmla v11.4s, v4.4s, v0.s[3] \n" - "fmla v7.4s, v9.4s, %15.s[1] \n" + "prfm pldl1keep, [%7, #128] \n" + "ld1 {v14.4s}, [%7] \n" + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v15.4s}, [%8] \n" - "prfm pldl1keep, [%4, #256] \n" - "ld2 {v8.4s, v9.4s}, [%4], #32 \n"// r1 + "ld1 {v2.4s, v3.4s}, [%12], #32 \n" - "fmla v12.4s, v14.4s, %12.s[2] \n" - "fmla v13.4s, v14.4s, %15.s[2] \n" + "fmla v12.4s, v4.4s, v1.s[0] \n" + "fmla v13.4s, v4.4s, v1.s[1] \n" + "fmla v14.4s, v4.4s, v1.s[2] \n" + "fmla v15.4s, v4.4s, v1.s[3] \n" - "prfm pldl1keep, [%4, #128] \n" - "ld2 {v10.4s, v11.4s}, [%4] \n" + "prfm pldl1keep, [%9, #256] \n" + "ld2 {v6.4s, v7.4s}, [%9] \n"// v6 - "fmla v6.4s, v8.4s, %13.s[0] \n" - "fmla v7.4s, v8.4s, %16.s[0] \n" + "fmla v8.4s, v5.4s, v2.s[0] \n" + "fmla v9.4s, v5.4s, v2.s[1] \n" + "fmla v10.4s, v5.4s, v2.s[2] \n" + "fmla v11.4s, v5.4s, v2.s[3] \n" - "ext v14.16b, v8.16b, v10.16b, #4\n" + "ext v6.16b, v4.16b, v6.16b, #4 \n"// v6=02 - "fmla v12.4s, v9.4s, %13.s[1] \n" - "fmla v13.4s, v9.4s, %16.s[1] \n" + "ld1 {v0.4s, v1.4s}, [%12], #32 \n" - "prfm pldl1keep, [%5, #256] \n" - "ld2 {v8.4s, v9.4s}, [%5], #32 \n"// r2 + "fmla v12.4s, v5.4s, v3.s[0] \n" + "fmla v13.4s, v5.4s, v3.s[1] \n" + "fmla v14.4s, v5.4s, v3.s[2] \n" + 
"fmla v15.4s, v5.4s, v3.s[3] \n" - "fmla v6.4s, v14.4s, %13.s[2] \n" - "fmla v7.4s, v14.4s, %16.s[2] \n" + /// + "prfm pldl1keep, [%10, #256] \n" + "ld2 {v4.4s, v5.4s}, [%10], #32 \n"// v4=10 v5=11 - "prfm pldl1keep, [%5, #128] \n" - "ld2 {v10.4s, v11.4s}, [%5] \n" + "fmla v8.4s, v6.4s, v0.s[0] \n" + "fmla v9.4s, v6.4s, v0.s[1] \n" + "fmla v10.4s, v6.4s, v0.s[2] \n" + "fmla v11.4s, v6.4s, v0.s[3] \n" + + "ld1 {v2.4s, v3.4s}, [%12], #32 \n" + + "fmla v12.4s, v6.4s, v1.s[0] \n" + "fmla v13.4s, v6.4s, v1.s[1] \n" + "fmla v14.4s, v6.4s, v1.s[2] \n" + "fmla v15.4s, v6.4s, v1.s[3] \n" + + "fmla v8.4s, v4.4s, v2.s[0] \n" + "fmla v9.4s, v4.4s, v2.s[1] \n" + "fmla v10.4s, v4.4s, v2.s[2] \n" + "fmla v11.4s, v4.4s, v2.s[3] \n" + + "ld1 {v0.4s, v1.4s}, [%12], #32 \n" + + "fmla v12.4s, v4.4s, v3.s[0] \n" + "fmla v13.4s, v4.4s, v3.s[1] \n" + "fmla v14.4s, v4.4s, v3.s[2] \n" + "fmla v15.4s, v4.4s, v3.s[3] \n" + + "prfm pldl1keep, [%10, #256] \n" + "ld2 {v6.4s, v7.4s}, [%10] \n"// v6 + + "fmla v8.4s, v5.4s, v0.s[0] \n" + "fmla v9.4s, v5.4s, v0.s[1] \n" + "fmla v10.4s, v5.4s, v0.s[2] \n" + "fmla v11.4s, v5.4s, v0.s[3] \n" + + "ld1 {v2.4s, v3.4s}, [%12], #32 \n" + + "ext v6.16b, v4.16b, v6.16b, #4 \n"// v6=12 + + "fmla v12.4s, v5.4s, v1.s[0] \n" + "fmla v13.4s, v5.4s, v1.s[1] \n" + "fmla v14.4s, v5.4s, v1.s[2] \n" + "fmla v15.4s, v5.4s, v1.s[3] \n" - "fmla v12.4s, v8.4s, %14.s[0] \n" - "fmla v13.4s, v8.4s, %17.s[0] \n" + /// + "prfm pldl1keep, [%11, #256] \n" + "ld2 {v4.4s, v5.4s}, [%11], #32 \n"// v4=20 v5=21 - "ext v14.16b, v8.16b, v10.16b, #4\n" + "fmla v8.4s, v6.4s, v2.s[0] \n" + "fmla v9.4s, v6.4s, v2.s[1] \n" + "fmla v10.4s, v6.4s, v2.s[2] \n" + "fmla v11.4s, v6.4s, v2.s[3] \n" - "fmla v6.4s, v9.4s, %14.s[1] \n" - "fmla v7.4s, v9.4s, %17.s[1] \n" + "ld1 {v0.4s, v1.4s}, [%12], #32 \n" - "fmla v12.4s, v14.4s, %14.s[2] \n" - "fmla v13.4s, v14.4s, %17.s[2] \n" + "fmla v12.4s, v6.4s, v3.s[0] \n" + "fmla v13.4s, v6.4s, v3.s[1] \n" + "fmla v14.4s, v6.4s, v3.s[2] \n" + "fmla v15.4s, v6.4s, v3.s[3] \n" - "prfm pldl1keep, [%3, #256] \n" - "ld2 {v8.4s, v9.4s}, [%3], #32 \n"// v8 v9 = r0 + "fmla v8.4s, v4.4s, v0.s[0] \n" + "fmla v9.4s, v4.4s, v0.s[1] \n" + "fmla v10.4s, v4.4s, v0.s[2] \n" + "fmla v11.4s, v4.4s, v0.s[3] \n" - "fadd v6.4s, v6.4s, v12.4s \n" - "fadd v7.4s, v7.4s, v13.4s \n" + "ld1 {v2.4s, v3.4s}, [%12], #32 \n" + + "fmla v12.4s, v4.4s, v1.s[0] \n" + "fmla v13.4s, v4.4s, v1.s[1] \n" + "fmla v14.4s, v4.4s, v1.s[2] \n" + "fmla v15.4s, v4.4s, v1.s[3] \n" + + "prfm pldl1keep, [%11, #256] \n" + "ld2 {v6.4s, v7.4s}, [%11] \n"// v6 + + "fmla v8.4s, v5.4s, v2.s[0] \n" + "fmla v9.4s, v5.4s, v2.s[1] \n" + "fmla v10.4s, v5.4s, v2.s[2] \n" + "fmla v11.4s, v5.4s, v2.s[3] \n" + + "ext v6.16b, v4.16b, v6.16b, #4 \n"// v6=22 + + "ld1 {v0.4s, v1.4s}, [%12], #32 \n" + + "fmla v12.4s, v5.4s, v3.s[0] \n" + "fmla v13.4s, v5.4s, v3.s[1] \n" + "fmla v14.4s, v5.4s, v3.s[2] \n" + "fmla v15.4s, v5.4s, v3.s[3] \n" + + "fmla v8.4s, v6.4s, v0.s[0] \n" + "fmla v9.4s, v6.4s, v0.s[1] \n" + "fmla v10.4s, v6.4s, v0.s[2] \n" + "fmla v11.4s, v6.4s, v0.s[3] \n" + + "fmla v12.4s, v6.4s, v1.s[0] \n" + "fmla v13.4s, v6.4s, v1.s[1] \n" + + "st1 {v8.4s}, [%1], #16 \n" + "st1 {v9.4s}, [%2], #16 \n" + + "fmla v14.4s, v6.4s, v1.s[2] \n" + "fmla v15.4s, v6.4s, v1.s[3] \n" + + "st1 {v10.4s}, [%3], #16 \n" + "st1 {v11.4s}, [%4], #16 \n" + + "sub %12, %12, #288 \n" + + "st1 {v12.4s}, [%5], #16 \n" + "st1 {v13.4s}, [%6], #16 \n" "subs %w0, %w0, #1 \n" - "st1 {v6.4s}, [%1], #16 \n" - "st1 {v7.4s}, [%2], #16 \n" + "st1 {v14.4s}, [%7], #16 \n" + "st1 
{v15.4s}, [%8], #16 \n" "bne 0b \n" - "sub %3, %3, #32 \n" - : "=r"(nn), // %0 "=r"(outptr0), // %1 "=r"(outptr1), // %2 - "=r"(r0), // %3 - "=r"(r1), // %4 - "=r"(r2) // %5 + "=r"(outptr2), // %3 + "=r"(outptr3), // %4 + "=r"(outptr4), // %5 + "=r"(outptr5), // %6 + "=r"(outptr6), // %7 + "=r"(outptr7), // %8 + "=r"(r0), // %9 + "=r"(r1), // %10 + "=r"(r2), // %11 + "=r"(ktmp) // %12 : "0"(nn), "1"(outptr0), "2"(outptr1), - "3"(r0), - "4"(r1), - "5"(r2), - "w"(_k00), // %12 - "w"(_k03), // %13 - "w"(_k06), // %14 - "w"(_k10), // %15 - "w"(_k13), // %16 - "w"(_k16) // %17 - : "cc", "memory", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + "3"(outptr2), + "4"(outptr3), + "5"(outptr4), + "6"(outptr5), + "7"(outptr6), + "8"(outptr7), + "9"(r0), + "10"(r1), + "11"(r2), + "12"(ktmp) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" ); } -#else +#else // __aarch64__ if (nn > 0) { asm volatile( - "pld [%3, #256] \n" - "vld2.f32 {d16-d19}, [%3]! \n"// q8 q9 = r0 - "0: \n" "pld [%1, #128] \n" - "vld1.f32 {d12-d13}, [%1] \n"// q6 = _sum0 + "vld1.f32 {d16-d17}, [%1] \n" + "pld [%2, #128] \n" + "vld1.f32 {d18-d19}, [%2] \n" - "vmul.f32 q12, q8, %e12[0] \n" + "pld [%3, #128] \n" + "vld1.f32 {d20-d21}, [%3] \n" + "pld [%4, #128] \n" + "vld1.f32 {d22-d23}, [%4] \n" - "pld [%2, #128] \n" - "vld1.f32 {d14-d15}, [%2] \n"// q7 = _sum1 + /// + "pld [%9, #256] \n" + "vld2.f32 {d8-d11}, [%9]! \n"// q4=00 q5=01 - "vmul.f32 q13, q8, %e15[0] \n" + "vld1.f32 {d0-d3}, [%12 :128]! \n" - "pld [%3, #128] \n" - "vld2.f32 {d20-d21}, [%3] \n"// q10 + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d0[1] \n" - "vmla.f32 q6, q9, %e12[1] \n" + "pld [%5, #128] \n" + "vld1.f32 {d24-d25}, [%5] \n" + "pld [%6, #128] \n" + "vld1.f32 {d26-d27}, [%6] \n" - "vext.32 q11, q8, q10, #1 \n" + "vmla.f32 q10, q4, d1[0] \n" + "vmla.f32 q11, q4, d1[1] \n" - "vmla.f32 q7, q9, %e15[1] \n" + "pld [%7, #128] \n" + "vld1.f32 {d28-d29}, [%7] \n" + "pld [%8, #128] \n" + "vld1.f32 {d30-d31}, [%8] \n" - "pld [%4, #256] \n" - "vld2.f32 {d16-d19}, [%4]! \n"// r1 + "vld1.f32 {d4-d7}, [%12 :128]! \n" - "vmla.f32 q12, q11, %f12[0] \n" - "vmla.f32 q13, q11, %f15[0] \n" + "vmla.f32 q12, q4, d2[0] \n" + "vmla.f32 q13, q4, d2[1] \n" + "vmla.f32 q14, q4, d3[0] \n" + "vmla.f32 q15, q4, d3[1] \n" - "pld [%4, #128] \n" - "vld2.f32 {d20-d21}, [%4] \n" + "pld [%9, #128] \n" + "vld2.f32 {d12-d13}, [%9] \n"// q6 - "vmla.f32 q6, q8, %e13[0] \n" - "vmla.f32 q7, q8, %e16[0] \n" + "vmla.f32 q8, q5, d4[0] \n" + "vmla.f32 q9, q5, d4[1] \n" + "vmla.f32 q10, q5, d5[0] \n" + "vmla.f32 q11, q5, d5[1] \n" - "vext.32 q11, q8, q10, #1 \n" + "vext.f32 q6, q4, q6, #1 \n"// q6=02 - "vmla.f32 q12, q9, %e13[1] \n" - "vmla.f32 q13, q9, %e16[1] \n" + "vld1.f32 {d0-d3}, [%12 :128]! \n" - "pld [%5, #256] \n" - "vld2.f32 {d16-d19}, [%5]! \n"// r2 + "vmla.f32 q12, q5, d6[0] \n" + "vmla.f32 q13, q5, d6[1] \n" + "vmla.f32 q14, q5, d7[0] \n" + "vmla.f32 q15, q5, d7[1] \n" - "vmla.f32 q6, q11, %f13[0] \n" - "vmla.f32 q7, q11, %f16[0] \n" + /// + "pld [%10, #256] \n" + "vld2.f32 {d8-d11}, [%10]! \n"// q4=10 q5=11 - "pld [%5, #128] \n" - "vld2.f32 {d20-d21}, [%5] \n" + "vmla.f32 q8, q6, d0[0] \n" + "vmla.f32 q9, q6, d0[1] \n" + "vmla.f32 q10, q6, d1[0] \n" + "vmla.f32 q11, q6, d1[1] \n" - "vmla.f32 q12, q8, %e14[0] \n" - "vmla.f32 q13, q8, %e17[0] \n" + "vld1.f32 {d4-d7}, [%12 :128]! 
\n" - "vext.32 q11, q8, q10, #1 \n" + "vmla.f32 q12, q6, d2[0] \n" + "vmla.f32 q13, q6, d2[1] \n" + "vmla.f32 q14, q6, d3[0] \n" + "vmla.f32 q15, q6, d3[1] \n" - "vmla.f32 q6, q9, %e14[1] \n" - "vmla.f32 q7, q9, %e17[1] \n" + "vmla.f32 q8, q4, d4[0] \n" + "vmla.f32 q9, q4, d4[1] \n" + "vmla.f32 q10, q4, d5[0] \n" + "vmla.f32 q11, q4, d5[1] \n" - "vmla.f32 q12, q11, %f14[0] \n" - "vmla.f32 q13, q11, %f17[0] \n" + "vld1.f32 {d0-d3}, [%12 :128]! \n" - "pld [%3, #256] \n" - "vld2.f32 {d16-d19}, [%3]! \n"// q8 q9 = r0 + "vmla.f32 q12, q4, d6[0] \n" + "vmla.f32 q13, q4, d6[1] \n" + "vmla.f32 q14, q4, d7[0] \n" + "vmla.f32 q15, q4, d7[1] \n" - "vadd.f32 q6, q6, q12 \n" - "vadd.f32 q7, q7, q13 \n" + "pld [%10, #128] \n" + "vld2.f32 {d12-d13}, [%10] \n"// q6 + + "vmla.f32 q8, q5, d0[0] \n" + "vmla.f32 q9, q5, d0[1] \n" + "vmla.f32 q10, q5, d1[0] \n" + "vmla.f32 q11, q5, d1[1] \n" + + "vld1.f32 {d4-d7}, [%12 :128]! \n" + + "vext.f32 q6, q4, q6, #1 \n"// q6=12 + + "vmla.f32 q12, q5, d2[0] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d3[0] \n" + "vmla.f32 q15, q5, d3[1] \n" + + /// + "pld [%11, #256] \n" + "vld2.f32 {d8-d11}, [%11]! \n"// q4=20 q5=21 + + "vmla.f32 q8, q6, d4[0] \n" + "vmla.f32 q9, q6, d4[1] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d5[1] \n" + + "vld1.f32 {d0-d3}, [%12 :128]! \n" + + "vmla.f32 q12, q6, d6[0] \n" + "vmla.f32 q13, q6, d6[1] \n" + "vmla.f32 q14, q6, d7[0] \n" + "vmla.f32 q15, q6, d7[1] \n" + + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d0[1] \n" + "vmla.f32 q10, q4, d1[0] \n" + "vmla.f32 q11, q4, d1[1] \n" + + "vld1.f32 {d4-d7}, [%12 :128]! \n" + + "vmla.f32 q12, q4, d2[0] \n" + "vmla.f32 q13, q4, d2[1] \n" + "vmla.f32 q14, q4, d3[0] \n" + "vmla.f32 q15, q4, d3[1] \n" + + "pld [%11, #128] \n" + "vld2.f32 {d12-d13}, [%11] \n"// q6 + + "vmla.f32 q8, q5, d4[0] \n" + "vmla.f32 q9, q5, d4[1] \n" + "vmla.f32 q10, q5, d5[0] \n" + "vmla.f32 q11, q5, d5[1] \n" + + "vext.f32 q6, q4, q6, #1 \n"// q6=22 + + "vld1.f32 {d0-d3}, [%12 :128]! \n" + + "vmla.f32 q12, q5, d6[0] \n" + "vmla.f32 q13, q5, d6[1] \n" + "vmla.f32 q14, q5, d7[0] \n" + "vmla.f32 q15, q5, d7[1] \n" + + "vmla.f32 q8, q6, d0[0] \n" + "vmla.f32 q9, q6, d0[1] \n" + "vmla.f32 q10, q6, d1[0] \n" + "vmla.f32 q11, q6, d1[1] \n" + + "vmla.f32 q12, q6, d2[0] \n" + "vmla.f32 q13, q6, d2[1] \n" + + "vst1.f32 {d16-d17}, [%1]! \n" + "vst1.f32 {d18-d19}, [%2]! \n" + + "vmla.f32 q14, q6, d3[0] \n" + "vmla.f32 q15, q6, d3[1] \n" + + "vst1.f32 {d20-d21}, [%3]! \n" + "vst1.f32 {d22-d23}, [%4]! \n" + + "sub %12, %12, #288 \n" + + "vst1.f32 {d24-d25}, [%5]! \n" + "vst1.f32 {d26-d27}, [%6]! \n" "subs %0, #1 \n" - "vst1.f32 {d12-d13}, [%1]! \n" - "vst1.f32 {d14-d15}, [%2]! \n" + "vst1.f32 {d28-d29}, [%7]! \n" + "vst1.f32 {d30-d31}, [%8]! 
\n" "bne 0b \n" - "sub %3, #32 \n" - : "=r"(nn), // %0 "=r"(outptr0), // %1 "=r"(outptr1), // %2 - "=r"(r0), // %3 - "=r"(r1), // %4 - "=r"(r2) // %5 + "=r"(outptr2), // %3 + "=r"(outptr3), // %4 + "=r"(outptr4), // %5 + "=r"(outptr5), // %6 + "=r"(outptr6), // %7 + "=r"(outptr7), // %8 + "=r"(r0), // %9 + "=r"(r1), // %10 + "=r"(r2), // %11 + "=r"(ktmp) // %12 : "0"(nn), "1"(outptr0), "2"(outptr1), - "3"(r0), - "4"(r1), - "5"(r2), - "w"(_k00), // %12 - "w"(_k03), // %13 - "w"(_k06), // %14 - "w"(_k10), // %15 - "w"(_k13), // %16 - "w"(_k16) // %17 - : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15" + "3"(outptr2), + "4"(outptr3), + "5"(outptr4), + "6"(outptr5), + "7"(outptr6), + "8"(outptr7), + "9"(r0), + "10"(r1), + "11"(r2), + "12"(ktmp) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15" ); } #endif // __aarch64__ @@ -11790,63 +12085,344 @@ static void conv3x3s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke for (; remain>0; remain--) { #if __ARM_NEON - float32x4_t _r00 = vld1q_f32(r0); - float32x4_t _r10 = vld1q_f32(r1); - float32x4_t _r20 = vld1q_f32(r2); +#if __aarch64__ + asm volatile( + "ld1 {v10.4s, v11.4s}, [%11], #32 \n" - float32x4_t _sum0 = vmulq_f32(_r00, _k00); - float32x4_t _sum1 = vmulq_f32(_r00, _k10); - _sum0 = vmlaq_f32(_sum0, _r10, _k03); - _sum1 = vmlaq_f32(_sum1, _r10, _k13); - _sum0 = vmlaq_f32(_sum0, _r20, _k06); - _sum1 = vmlaq_f32(_sum1, _r20, _k16); + "prfm pldl1keep, [%8, #128] \n" + "ld1 {v0.4s}, [%8] \n" - _sum0 = vsetq_lane_f32(*outptr0, _sum0, 3); - _sum1 = vsetq_lane_f32(*outptr1, _sum1, 3); -#if __aarch64__ - *outptr0 = vaddvq_f32(_sum0); - *outptr1 = vaddvq_f32(_sum1); -#else - float32x2_t _ss0 = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0)); - float32x2_t _ss1 = vadd_f32(vget_low_f32(_sum1), vget_high_f32(_sum1)); - float32x2_t _ss01 = vpadd_f32(_ss0, _ss1); + "ld1 {v12.4s, v13.4s}, [%11], #32 \n" - *outptr0 = vget_lane_f32(_ss01, 0); - *outptr1 = vget_lane_f32(_ss01, 1); + "ld1 {v8.s}[0], [%0] \n" + "ld1 {v8.s}[1], [%1] \n" + "ld1 {v8.s}[2], [%2] \n" + "ld1 {v8.s}[3], [%3] \n" + + "fmul v14.4s, v10.4s, v0.s[0] \n" + "fmul v15.4s, v11.4s, v0.s[0] \n" + + "ld1 {v9.s}[0], [%4] \n" + "ld1 {v9.s}[1], [%5] \n" + "ld1 {v9.s}[2], [%6] \n" + "ld1 {v9.s}[3], [%7] \n" + + "ld1 {v10.4s, v11.4s}, [%11], #32 \n" + + "fmla v8.4s, v12.4s, v0.s[1] \n" + "fmla v9.4s, v13.4s, v0.s[1] \n" + + "ld1 {v12.4s, v13.4s}, [%11], #32 \n" + + "fmla v14.4s, v10.4s, v0.s[2] \n" + "fmla v15.4s, v11.4s, v0.s[2] \n" + + "prfm pldl1keep, [%9, #128] \n" + "ld1 {v1.4s}, [%9] \n" + + "ld1 {v10.4s, v11.4s}, [%11], #32 \n" + + "fmla v8.4s, v12.4s, v1.s[0] \n" + "fmla v9.4s, v13.4s, v1.s[0] \n" + + "ld1 {v12.4s, v13.4s}, [%11], #32 \n" + + "fmla v14.4s, v10.4s, v1.s[1] \n" + "fmla v15.4s, v11.4s, v1.s[1] \n" + + "ld1 {v10.4s, v11.4s}, [%11], #32 \n" + + "fmla v8.4s, v12.4s, v1.s[2] \n" + "fmla v9.4s, v13.4s, v1.s[2] \n" + + "prfm pldl1keep, [%10, #128] \n" + "ld1 {v0.4s}, [%10] \n" + + "ld1 {v12.4s, v13.4s}, [%11], #32 \n" + + "fmla v14.4s, v10.4s, v0.s[0] \n" + "fmla v15.4s, v11.4s, v0.s[0] \n" + + "ld1 {v10.4s, v11.4s}, [%11], #32 \n" + + "fmla v8.4s, v12.4s, v0.s[1] \n" + "fmla v9.4s, v13.4s, v0.s[1] \n" + + "fmla v14.4s, v10.4s, v0.s[2] \n" + "fmla v15.4s, v11.4s, v0.s[2] \n" + + "fadd v8.4s, v8.4s, v14.4s \n" + "fadd v9.4s, v9.4s, v15.4s \n" + + "sub %11, %11, #288 \n" + + "st1 {v8.s}[0], [%0], #4 \n" + "st1 {v8.s}[1], [%1], #4 \n" + "st1 {v8.s}[2], [%2], #4 \n" + "st1 
{v8.s}[3], [%3], #4 \n" + + "st1 {v9.s}[0], [%4], #4 \n" + "st1 {v9.s}[1], [%5], #4 \n" + "st1 {v9.s}[2], [%6], #4 \n" + "st1 {v9.s}[3], [%7], #4 \n" + + : "=r"(outptr0), // %0 + "=r"(outptr1), // %1 + "=r"(outptr2), // %2 + "=r"(outptr3), // %3 + "=r"(outptr4), // %4 + "=r"(outptr5), // %5 + "=r"(outptr6), // %6 + "=r"(outptr7), // %7 + "=r"(r0), // %8 + "=r"(r1), // %9 + "=r"(r2), // %10 + "=r"(ktmp) // %11 + : "0"(outptr0), + "1"(outptr1), + "2"(outptr2), + "3"(outptr3), + "4"(outptr4), + "5"(outptr5), + "6"(outptr6), + "7"(outptr7), + "8"(r0), + "9"(r1), + "10"(r2), + "11"(ktmp) + : "memory", "v0", "v1", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + ); +#else // __aarch64__ + asm volatile( + "vld1.f32 {d20-d23}, [%11 :128]! \n" + + "pld [%8, #128] \n" + "vld1.f32 {d0-d1}, [%8] \n" + + "vld1.f32 {d24-d27}, [%11 :128]! \n" + + "vld1.f32 {d16[0]}, [%0] \n" + "vld1.f32 {d16[1]}, [%1] \n" + "vld1.f32 {d17[0]}, [%2] \n" + "vld1.f32 {d17[1]}, [%3] \n" + + "vmul.f32 q14, q10, d0[0] \n" + "vmul.f32 q15, q11, d0[0] \n" + + "vld1.f32 {d18[0]}, [%4] \n" + "vld1.f32 {d18[1]}, [%5] \n" + "vld1.f32 {d19[0]}, [%6] \n" + "vld1.f32 {d19[1]}, [%7] \n" + + "vld1.f32 {d20-d23}, [%11 :128]! \n" + + "vmla.f32 q8, q12, d0[1] \n" + "vmla.f32 q9, q13, d0[1] \n" + + "vld1.f32 {d24-d27}, [%11 :128]! \n" + + "vmla.f32 q14, q10, d1[0] \n" + "vmla.f32 q15, q11, d1[0] \n" + + "pld [%9, #128] \n" + "vld1.f32 {d2-d3}, [%9] \n" + + "vld1.f32 {d20-d23}, [%11 :128]! \n" + + "vmla.f32 q8, q12, d2[0] \n" + "vmla.f32 q9, q13, d2[0] \n" + + "vld1.f32 {d24-d27}, [%11 :128]! \n" + + "vmla.f32 q14, q10, d2[1] \n" + "vmla.f32 q15, q11, d2[1] \n" + + "vld1.f32 {d20-d23}, [%11 :128]! \n" + + "vmla.f32 q8, q12, d3[0] \n" + "vmla.f32 q9, q13, d3[0] \n" + + "pld [%10, #128] \n" + "vld1.f32 {d0-d1}, [%10] \n" + + "vld1.f32 {d24-d27}, [%11 :128]! \n" + + "vmla.f32 q14, q10, d0[0] \n" + "vmla.f32 q15, q11, d0[0] \n" + + "vld1.f32 {d20-d23}, [%11 :128]! \n" + + "vmla.f32 q8, q12, d0[1] \n" + "vmla.f32 q9, q13, d0[1] \n" + + "vmla.f32 q14, q10, d1[0] \n" + "vmla.f32 q15, q11, d1[0] \n" + + "vadd.f32 q8, q8, q14 \n" + "vadd.f32 q9, q9, q15 \n" + + "sub %11, %11, #288 \n" + + "vst1.f32 {d16[0]}, [%0]! \n" + "vst1.f32 {d16[1]}, [%1]! \n" + "vst1.f32 {d17[0]}, [%2]! \n" + "vst1.f32 {d17[1]}, [%3]! \n" + + "vst1.f32 {d18[0]}, [%4]! \n" + "vst1.f32 {d18[1]}, [%5]! \n" + "vst1.f32 {d19[0]}, [%6]! \n" + "vst1.f32 {d19[1]}, [%7]! 
\n" + + : "=r"(outptr0), // %0 + "=r"(outptr1), // %1 + "=r"(outptr2), // %2 + "=r"(outptr3), // %3 + "=r"(outptr4), // %4 + "=r"(outptr5), // %5 + "=r"(outptr6), // %6 + "=r"(outptr7), // %7 + "=r"(r0), // %8 + "=r"(r1), // %9 + "=r"(r2), // %10 + "=r"(ktmp) // %11 + : "0"(outptr0), + "1"(outptr1), + "2"(outptr2), + "3"(outptr3), + "4"(outptr4), + "5"(outptr5), + "6"(outptr6), + "7"(outptr7), + "8"(r0), + "9"(r1), + "10"(r2), + "11"(ktmp) + : "memory", "q0", "q1", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15" + ); #endif // __aarch64__ -#else +#else // __ARM_NEON float sum0 = 0.f; float sum1 = 0.f; + float sum2 = 0.f; + float sum3 = 0.f; + float sum4 = 0.f; + float sum5 = 0.f; + float sum6 = 0.f; + float sum7 = 0.f; - sum0 += r0[0] * k0[0]; - sum0 += r0[1] * k0[1]; - sum0 += r0[2] * k0[2]; - sum0 += r1[0] * k0[3]; - sum0 += r1[1] * k0[4]; - sum0 += r1[2] * k0[5]; - sum0 += r2[0] * k0[6]; - sum0 += r2[1] * k0[7]; - sum0 += r2[2] * k0[8]; - - sum1 += r0[0] * k1[0]; - sum1 += r0[1] * k1[1]; - sum1 += r0[2] * k1[2]; - sum1 += r1[0] * k1[3]; - sum1 += r1[1] * k1[4]; - sum1 += r1[2] * k1[5]; - sum1 += r2[0] * k1[6]; - sum1 += r2[1] * k1[7]; - sum1 += r2[2] * k1[8]; + sum0 += r0[0] * ktmp[0]; + sum1 += r0[0] * ktmp[1]; + sum2 += r0[0] * ktmp[2]; + sum3 += r0[0] * ktmp[3]; + sum4 += r0[0] * ktmp[4]; + sum5 += r0[0] * ktmp[5]; + sum6 += r0[0] * ktmp[6]; + sum7 += r0[0] * ktmp[7]; + ktmp += 8; + + sum0 += r0[1] * ktmp[0]; + sum1 += r0[1] * ktmp[1]; + sum2 += r0[1] * ktmp[2]; + sum3 += r0[1] * ktmp[3]; + sum4 += r0[1] * ktmp[4]; + sum5 += r0[1] * ktmp[5]; + sum6 += r0[1] * ktmp[6]; + sum7 += r0[1] * ktmp[7]; + ktmp += 8; + + sum0 += r0[2] * ktmp[0]; + sum1 += r0[2] * ktmp[1]; + sum2 += r0[2] * ktmp[2]; + sum3 += r0[2] * ktmp[3]; + sum4 += r0[2] * ktmp[4]; + sum5 += r0[2] * ktmp[5]; + sum6 += r0[2] * ktmp[6]; + sum7 += r0[2] * ktmp[7]; + ktmp += 8; + + sum0 += r1[0] * ktmp[0]; + sum1 += r1[0] * ktmp[1]; + sum2 += r1[0] * ktmp[2]; + sum3 += r1[0] * ktmp[3]; + sum4 += r1[0] * ktmp[4]; + sum5 += r1[0] * ktmp[5]; + sum6 += r1[0] * ktmp[6]; + sum7 += r1[0] * ktmp[7]; + ktmp += 8; + + sum0 += r1[1] * ktmp[0]; + sum1 += r1[1] * ktmp[1]; + sum2 += r1[1] * ktmp[2]; + sum3 += r1[1] * ktmp[3]; + sum4 += r1[1] * ktmp[4]; + sum5 += r1[1] * ktmp[5]; + sum6 += r1[1] * ktmp[6]; + sum7 += r1[1] * ktmp[7]; + ktmp += 8; + + sum0 += r1[2] * ktmp[0]; + sum1 += r1[2] * ktmp[1]; + sum2 += r1[2] * ktmp[2]; + sum3 += r1[2] * ktmp[3]; + sum4 += r1[2] * ktmp[4]; + sum5 += r1[2] * ktmp[5]; + sum6 += r1[2] * ktmp[6]; + sum7 += r1[2] * ktmp[7]; + ktmp += 8; + + sum0 += r2[0] * ktmp[0]; + sum1 += r2[0] * ktmp[1]; + sum2 += r2[0] * ktmp[2]; + sum3 += r2[0] * ktmp[3]; + sum4 += r2[0] * ktmp[4]; + sum5 += r2[0] * ktmp[5]; + sum6 += r2[0] * ktmp[6]; + sum7 += r2[0] * ktmp[7]; + ktmp += 8; + + sum0 += r2[1] * ktmp[0]; + sum1 += r2[1] * ktmp[1]; + sum2 += r2[1] * ktmp[2]; + sum3 += r2[1] * ktmp[3]; + sum4 += r2[1] * ktmp[4]; + sum5 += r2[1] * ktmp[5]; + sum6 += r2[1] * ktmp[6]; + sum7 += r2[1] * ktmp[7]; + ktmp += 8; + + sum0 += r2[2] * ktmp[0]; + sum1 += r2[2] * ktmp[1]; + sum2 += r2[2] * ktmp[2]; + sum3 += r2[2] * ktmp[3]; + sum4 += r2[2] * ktmp[4]; + sum5 += r2[2] * ktmp[5]; + sum6 += r2[2] * ktmp[6]; + sum7 += r2[2] * ktmp[7]; + ktmp += 8; *outptr0 += sum0; *outptr1 += sum1; -#endif // __ARM_NEON + *outptr2 += sum2; + *outptr3 += sum3; + *outptr4 += sum4; + *outptr5 += sum5; + *outptr6 += sum6; + *outptr7 += sum7; + ktmp -= 8*9; + + outptr0++; + outptr1++; + outptr2++; + outptr3++; + outptr4++; + outptr5++; + outptr6++; + 
            outptr7++;
+#endif // __ARM_NEON
                r0 += 2;
                r1 += 2;
                r2 += 2;
-                outptr0++;
-                outptr1++;
            }
            r0 += tailstep;
@@ -11854,8 +12430,7 @@ static void conv3x3s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke
            r2 += tailstep;
        }
-        k0 += 9;
-        k1 += 9;
+        ktmp += 8*9;
    }
}
@@ -11868,7 +12443,7 @@ static void conv3x3s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _ke
        out.fill(bias0);
-        const float* kernel0 = kernel + p*inch*9;
+        const float* ktmp = _kernel.channel(p/8 + p%8);
        for (int q=0; q
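
Note on the packed layout behind the new ktmp pointer: for each group of 8 output channels and each input channel, the 9 kernel taps are stored with the 8 channels interleaved (8*9 = 72 floats, hence the 288-byte `sub ..., #288` rewind of ktmp in the assembly), so one broadcast input value feeds eight accumulators at once. Below is a minimal plain-C++ sketch of that layout as implied by the scalar fallback path; the helper names pack_kernel_group8 and conv3x3s2_dot8 are illustrative and are not identifiers from this patch.

// Sketch only: restates the packed weight layout in plain C++.
// Assumed input layout: kernel is [outch][inch][9] floats, as for _kernel above.
// Packed layout: ktmp[(q*9 + k)*8 + j] = tap k of output channel (p + j), input channel q.
static void pack_kernel_group8(const float* kernel, float* ktmp, int inch, int p)
{
    for (int q = 0; q < inch; q++)
    {
        for (int k = 0; k < 9; k++)          // 9 taps of the 3x3 kernel
        {
            for (int j = 0; j < 8; j++)      // 8 output channels interleaved per tap
            {
                ktmp[(q * 9 + k) * 8 + j] = kernel[((p + j) * inch + q) * 9 + k];
            }
        }
    }
}

// Reference inner product for one output position and one input channel,
// computing 8 output channels at once, mirroring the scalar fallback path.
// r0/r1/r2 point at three consecutive input rows of the stride-2 convolution.
static void conv3x3s2_dot8(const float* r0, const float* r1, const float* r2,
                           const float* ktmp, float sum[8])
{
    const float* rows[3] = { r0, r1, r2 };
    for (int i = 0; i < 3; i++)
    {
        for (int x = 0; x < 3; x++)
        {
            for (int j = 0; j < 8; j++)
                sum[j] += rows[i][x] * ktmp[j];
            ktmp += 8;                       // next tap: next block of 8 packed weights
        }
    }
    // ktmp has now advanced by 8*9 floats = 288 bytes; the assembly rewinds it
    // with "sub ..., #288" so the same input channel's block is reused for the
    // next output position(s).
}

The NEON main loops apply the same packed weights to four output columns per iteration, using ld2/vld2 de-interleaving loads so the stride-2 input access stays vector-friendly.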