Skip to content

Commit

Permalink
fix cpplint error for the atomic max/min
Browse files Browse the repository at this point in the history
fix cpplint error for the atomic max/min
  • Loading branch information
ZHUI authored Sep 26, 2020
1 parent ecfdfc9 commit a85592b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
17 changes: 8 additions & 9 deletions paddle/fluid/operators/math/segment_pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include <algorithm>
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -100,7 +99,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
Index segment_offset, dim_index_base, actual_height;
Index inner_dim_size = h.inner_dim_size;
h.calculate(stripe_index, segment_offset, dim_index_base, actual_height);
h.calculate(stripe_index, &segment_offset, &dim_index_base, &actual_height);

T minmax = pool.initial();
Index first_segment_id = segment_ids[dim_index_base];
Expand Down Expand Up @@ -154,7 +153,7 @@ __global__ void SegmentIndexGradKernel(const Index* segment_ids, const T* input,
T* in_grad, Helper h) {
CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
Index segment_offset, dim_index_base, actual_height;
h.calculate(stripe_index, segment_offset, dim_index_base, actual_height);
h.calculate(stripe_index, &segment_offset, &dim_index_base, &actual_height);

for (Index j = 0; j < actual_height; j++) {
Index current_segment_id = segment_ids[dim_index_base + j];
Expand Down Expand Up @@ -217,11 +216,11 @@ class ArrangeHelper {
total_stripe_count = inner_dim_size * input_outer_dim_num_stripe;
}

DEVICE inline void calculate(T stripe_index, T& segment_offset,
T& dim_index_base, T& actual_height) {
segment_offset = stripe_index % inner_dim_size;
dim_index_base = stripe_index / inner_dim_size * DimTileSize;
actual_height = min(DimTileSize, input_length_size - dim_index_base);
// Decomposes a flat stripe index into its per-stripe coordinates, returned
// via out-pointers (pointer parameters rather than non-const references to
// satisfy cpplint, per this commit).
//   *segment_offset  - this stripe's position within the inner dimension
//                      (stripe_index % inner_dim_size).
//   *dim_index_base  - first outer-dimension index covered by this stripe;
//                      stripes tile the outer dimension in chunks of
//                      DimTileSize rows.
//   *actual_height   - number of rows this stripe actually covers, clamped
//                      by input_length_size so the final partial tile does
//                      not run past the end of the input.
// NOTE(review): inner_dim_size, input_length_size and DimTileSize are
// members of the enclosing ArrangeHelper, declared outside this view.
DEVICE inline void calculate(T stripe_index, T* segment_offset,
T* dim_index_base, T* actual_height) {
*segment_offset = stripe_index % inner_dim_size;
*dim_index_base = stripe_index / inner_dim_size * DimTileSize;
// min() clamps the last tile, which may be shorter than DimTileSize.
*actual_height = min(DimTileSize, input_length_size - *dim_index_base);
}
};

Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/platform/cuda_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ USE_CUDA_ATOMIC(Max, unsigned int);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Max, unsigned long long int); // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) {
CUDA_ATOMIC_WRAPPER(Max, unsigned long long int) { // NOLINT
if (*address >= val) {
return;
}

unsigned long long int old = *address, assumed;
unsigned long long int old = *address, assumed; // NOLINT

do {
assumed = old;
Expand All @@ -169,7 +169,7 @@ CUDA_ATOMIC_WRAPPER(Max, float) {
return;
}

int *const address_as_i = (int *)address;
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;

do {
Expand All @@ -187,9 +187,9 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
return;
}

unsigned long long int *const address_as_ull =
(unsigned long long int *)address;
unsigned long long int old = *address_as_ull, assumed;
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT

do {
assumed = old;
Expand All @@ -209,12 +209,12 @@ USE_CUDA_ATOMIC(Min, unsigned int);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Min, unsigned long long int); // NOLINT
#else
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) {
CUDA_ATOMIC_WRAPPER(Min, unsigned long long int) { // NOLINT
if (*address <= val) {
return;
}

unsigned long long int old = *address, assumed;
unsigned long long int old = *address, assumed; // NOLINT

do {
assumed = old;
Expand All @@ -241,7 +241,7 @@ CUDA_ATOMIC_WRAPPER(Min, float) {
return;
}

int *const address_as_i = (int *)address;
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;

do {
Expand All @@ -259,9 +259,9 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
return;
}

unsigned long long int *const address_as_ull =
(unsigned long long int *)address;
unsigned long long int old = *address_as_ull, assumed;
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT

do {
assumed = old;
Expand Down

0 comments on commit a85592b

Please sign in to comment.