forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add direct conv to fbgemm (pytorch#901)
Summary: Pull Request resolved: pytorch#901 Integrate direct convolution code path into Fbgemm. The direct convolution entrance is integrated through FbgemmConv.cc. ConvFastPath will automatically determine if this case can use our direct convolution branch: - if spatial_dim=2, kh=2, kw<=6, stride=1 or 2, padding=0 Reviewed By: jspark1105 Differential Revision: D32273614 fbshipit-source-id: 16255395b4e14fae10c129ad98dd4445a5106989
- Loading branch information
1 parent
969941d
commit f0f6ca7
Showing
11 changed files
with
1,639 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* All rights reserved. | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <array> | ||
#include <cstdint> | ||
#include <vector> | ||
#include "fbgemm/ConvUtils.h" | ||
#include "fbgemm/FbgemmBuild.h" | ||
#include "fbgemm/UtilsAvx2.h" | ||
|
||
namespace fbgemm { | ||
|
||
class FBGEMM_API PackedDirectConvMatrix { | ||
public: | ||
/** | ||
* @param IC the number of input channels | ||
* @param OC the number of output channels | ||
* @param kernel_prod the product of all kernels. For example, kernel_prod = | ||
* 9 for 3x3 conv, and 27 for 3x3x3 conv. | ||
* @param smat the source unpacked weight in GRS layout | ||
*/ | ||
PackedDirectConvMatrix( | ||
int IC_per_G, | ||
int OC_per_G, | ||
int filter_prod, | ||
const std::int8_t* smat); | ||
virtual ~PackedDirectConvMatrix(); | ||
|
||
const std::int8_t* PackedMat() const { | ||
return pmat_; | ||
} | ||
|
||
const bool& is_first_call() const { | ||
return first_call; | ||
} | ||
|
||
/** | ||
compute the column offsets of the weight matrix. | ||
output of this function is the col_offsets vector | ||
col_offses dimension is the same as conv_p.OUT_DIM | ||
*/ | ||
template <int kSpatialDim> | ||
FBGEMM_API void col_offsets_with_zero_pt_s8acc32_DirectConvT( | ||
const fbgemm::conv_param_t<kSpatialDim>& conv_p, | ||
std::int32_t* B_zero_point, | ||
std::vector<int32_t>& col_offsets, | ||
int ncols_per_quant_group); | ||
|
||
private: | ||
std::int8_t* pmat_; /** packed weight */ | ||
bool first_call{true}; | ||
}; | ||
|
||
} // namespace fbgemm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,7 +58,8 @@ enum class optimized_conv_t { | |
groupwise, | ||
pointwise, | ||
fastpath1d, | ||
im2col | ||
im2col, | ||
directconv | ||
}; | ||
|
||
/** | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.