Better groupwise conv for small number of channels per group (#145)
Summary:
Pull Request resolved: #145

Groupwise convolution now supports C_per_G = 2, 4, 8, 16 and stride = 1 and 2. Groupwise convolution can now also be parallelized.

~~TODO: Add performance numbers.~~

The following results are on T1.

**Existing Implementation on ResNeXt shapes**
```
MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K, GOPS
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 3136, 4, 1152, 0.75
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, direct, 3136, 4, 1152, 24.01
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 8, 2304, 2.93
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 8, 2304, 34.34
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.34
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 41.46
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.35
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 41.82
```

**This diff**
```
MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K, GOPS
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 3136, 4, 1152, 0.75
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, direct, 3136, 4, 1152, 14.82
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 8, 2304, 2.92
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 8, 2304, 28.93
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.32
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 41.80
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.35
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 41.89
```

Note: The first shape is slower with this diff (14.82 vs. 24.01 GOPS for direct), probably because output processing is called more frequently. Needs further analysis.
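For readers checking the tables, the M, N and K columns can be reproduced from the shape columns. The sketch below is inferred from the rows shown, not taken from the benchmark source. It assumes the value 32 that follows IW in each row is the group count G, that padding is applied symmetrically, and that M = MB * OH * OW, N = OC / G and K = KH * KW * IC. `ConvShape`, `out_dim` and `gemm_dims` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvShape:
    """One benchmark row's shape columns (field names are illustrative)."""
    mb: int        # mini-batch size
    ic: int        # input channels
    oc: int        # output channels
    ih: int        # input height
    iw: int        # input width
    g: int         # number of groups (assumed meaning of the value after IW)
    kh: int        # kernel height
    kw: int        # kernel width
    stride_h: int
    stride_w: int
    pad_h: int     # assumed to be applied on both top and bottom
    pad_w: int     # assumed to be applied on both left and right


def out_dim(size: int, kernel: int, stride: int, pad: int) -> int:
    """Standard convolution output size with symmetric padding."""
    return (size + 2 * pad - kernel) // stride + 1


def gemm_dims(s: ConvShape) -> tuple:
    """Return (M, N, K) as they appear in the benchmark tables."""
    oh = out_dim(s.ih, s.kh, s.stride_h, s.pad_h)
    ow = out_dim(s.iw, s.kw, s.stride_w, s.pad_w)
    m = s.mb * oh * ow      # one GEMM row per output pixel
    n = s.oc // s.g         # output channels per group (C_per_G when IC == OC)
    k = s.kh * s.kw * s.ic  # matches the K column; note it uses all of IC
    return m, n, k


if __name__ == "__main__":
    first = ConvShape(1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1)
    print(gemm_dims(first))
```

With these assumptions, the first ResNeXt shape gives C_per_G = 128 / 32 = 4, which is the N column of that row.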

**Extra shapes supported in this diff**
```
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 784, 8, 2304, 2.72
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, direct, 784, 8, 2304, 23.50
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 196, 16, 4608, 9.84
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, direct, 196, 16, 4608, 35.08
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 2, 576, 0.18
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 2, 576, 7.50
```

**This diff with OMP_NUM_THREADS=2**
```
MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K, GOPS
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 3136, 4, 1152, 0.75
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, direct, 3136, 4, 1152, 30.57
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 8, 2304, 2.92
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 8, 2304, 52.47
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.27
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 68.01
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.28
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 66.12
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 784, 8, 2304, 2.77
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, direct, 784, 8, 2304, 47.90
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 196, 16, 4608, 9.59
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, direct, 196, 16, 4608, 60.01
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 2, 576, 0.18
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 2, 576, 9.03
```

**This diff with OMP_NUM_THREADS=4**
```
MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K, GOPS
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 3136, 4, 1152, 0.75
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, direct, 3136, 4, 1152, 48.39
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 8, 2304, 2.92
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 8, 2304, 80.02
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.26
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 91.99
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.24
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 93.38
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 784, 8, 2304, 2.85
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, direct, 784, 8, 2304, 75.54
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 196, 16, 4608, 9.84
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, direct, 196, 16, 4608, 86.84
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 2, 576, 0.18
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 2, 576, 9.88
```

**This diff with OMP_NUM_THREADS=8**
```
MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K, GOPS
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 3136, 4, 1152, 0.75
1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1, direct, 3136, 4, 1152, 81.98
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 8, 2304, 2.92
1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 8, 2304, 104.51
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.25
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 115.44
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 196, 16, 4608, 10.24
1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1, direct, 196, 16, 4608, 113.22
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 784, 8, 2304, 2.85
1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1, direct, 784, 8, 2304, 97.82
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, FusedIm2Col, 196, 16, 4608, 9.81
1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1, direct, 196, 16, 4608, 113.20
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, FusedIm2Col, 784, 2, 576, 0.18
1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1, direct, 784, 2, 576, 9.99
```

Reviewed By: jspark1105

Differential Revision: D17934344

fbshipit-source-id: cc4f8100f6f2c4e7e66c40f332121bfb7355c551