Revert "Revert D24335982: explicitly error out in comparison ops when…
Browse files Browse the repository at this point in the history
… the types don't match" (pytorch#47288)

Summary:
Pull Request resolved: pytorch#47288

This reverts commit b3eb0c8.

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D24706531

Pulled By: bdhirsh

fbshipit-source-id: f3bf34ddba7882932155819251b6c7dcb5c6b56c
bdhirsh authored and facebook-github-bot committed Nov 4, 2020
1 parent e4bc785 commit fe17269
Showing 9 changed files with 145 additions and 61 deletions.
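For context, a minimal ATen usage sketch (not part of this diff; the op names are assumed to be the standard generated ATen bindings) of the two call shapes the restored change distinguishes: the functional form of a comparison always produces a bool tensor, which the Note [special-case bool outputs] added below treats as the fast path, while an explicit out= tensor of another dtype goes through the common-dtype casting path. Per the commit title, type combinations that "don't match" are meant to raise explicitly; that check lives in files not expanded on this page, so the exact erroring condition is not shown here.

#include <ATen/ATen.h>

// Illustrative only: which dtype combinations actually error is defined by the
// underlying change (D24335982), not by this sketch.
void comparison_call_shapes() {
  at::Tensor a = at::rand({4});                 // float inputs
  at::Tensor b = at::rand({4});
  at::Tensor functional = at::lt(a, b);         // functional form: bool output (fast path)
  at::Tensor out = at::empty({4}, at::kFloat);  // explicit non-bool out= target
  at::lt_out(out, a, b);                        // computed in the common dtype, cast into `out`
}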
aten/src/ATen/cpu/vec256/vec256_bfloat16.h (2 changes: 1 addition & 1 deletion)
@@ -465,7 +465,7 @@ Vec256<BFloat16> inline Vec256<BFloat16>::operator==(const Vec256<BFloat16>& oth
}
Vec256<BFloat16> inline Vec256<BFloat16>::operator!=(const Vec256<BFloat16>& other) const {
return bfloat16_binary_op_as_fp32(*this, other, [](__m256 x, __m256 y) {
return _mm256_cmp_ps(x, y, _CMP_NEQ_OQ);
return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ);
});
}

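The only change in this header (and in the four vec256 headers that follow) is the comparison predicate used by operator!=: _CMP_NEQ_OQ ("ordered, quiet") evaluates to false whenever either operand is NaN, whereas _CMP_NEQ_UQ ("unordered, quiet") evaluates to true, which is what IEEE-754 and the scalar operator!= require. A standalone scalar sketch (not from the diff) of the behavior the new predicate matches:

#include <cassert>
#include <cmath>

int main() {
  const float nan = std::nanf("");
  const float one = 1.0f;
  // IEEE-754: a NaN operand makes the pair "unordered"; != must report true.
  assert((nan != one) == true);
  // == must report false, which _CMP_EQ_OQ already handles correctly.
  assert((nan == one) == false);
  // _CMP_NEQ_OQ would report false for (nan, one) because the pair is unordered;
  // _CMP_NEQ_UQ reports true, so the vectorized != now agrees with the scalar path.
  return 0;
}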
aten/src/ATen/cpu/vec256/vec256_complex_double.h (2 changes: 1 addition & 1 deletion)
@@ -309,7 +309,7 @@ template <> class Vec256<c10::complex<double>> {
return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
}
Vec256<c10::complex<double>> operator!=(const Vec256<c10::complex<double>>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
}
Vec256<c10::complex<double>> operator<(const Vec256<c10::complex<double>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
aten/src/ATen/cpu/vec256/vec256_complex_float.h (2 changes: 1 addition & 1 deletion)
@@ -347,7 +347,7 @@ template <> class Vec256<c10::complex<float>> {
return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
}
Vec256<c10::complex<float>> operator!=(const Vec256<c10::complex<float>>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
}
Vec256<c10::complex<float>> operator<(const Vec256<c10::complex<float>>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
aten/src/ATen/cpu/vec256/vec256_double.h (2 changes: 1 addition & 1 deletion)
@@ -237,7 +237,7 @@ template <> class Vec256<double> {
}

Vec256<double> operator!=(const Vec256<double>& other) const {
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
}

Vec256<double> operator<(const Vec256<double>& other) const {
aten/src/ATen/cpu/vec256/vec256_float.h (2 changes: 1 addition & 1 deletion)
@@ -244,7 +244,7 @@ template <> class Vec256<float> {
}

Vec256<float> operator!=(const Vec256<float>& other) const {
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
}

Vec256<float> operator<(const Vec256<float>& other) const {
aten/src/ATen/native/ReduceOps.cpp (8 changes: 4 additions & 4 deletions)
@@ -1008,8 +1008,8 @@ Tensor var(const Tensor& self, bool unbiased) {
return trivial_return.value();
}

// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// See https://github.com/pytorch/pytorch/pull/43858.
if (self.device().type() == kCPU) {
return at::_var(self, unbiased);
@@ -1040,8 +1040,8 @@ Tensor std(const Tensor& self, bool unbiased) {
return trivial_return.value();
}

// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// NOTE: CPU performance significantly regressed when attempting to port to ATen,
// so this dispatches differently based on device type.
// See https://github.com/pytorch/pytorch/pull/43858.
if (self.device().type() == kCPU) {
return at::_std(self, unbiased);
aten/src/ATen/native/TensorIterator.cpp (23 changes: 22 additions & 1 deletion)
@@ -406,13 +406,15 @@ void TensorIterator::compute_types(const TensorIteratorConfig& config) {
op.tensor.options().dtype(common_dtype_),
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}

// Promotes inputs by creating temporaries of the correct dtype
if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
op.original_tensor = op.tensor;
op.tensor = op.tensor.to(common_dtype_);
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}
}
}
@@ -847,14 +849,33 @@ TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a,

TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
const Tensor& b) {
return TensorIteratorConfig()
// Note [special-case bool outputs]
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
// has `bool` dtype. This is a performance optimization: the functional
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
// avoid creating a temporary copy of the output.
// However, note that all kernels using this TensorIterator will need to special-case when
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
if (out.scalar_type() == kBool) {
return TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.build();
} else {
return TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.build();
}
}

TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
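A condensed sketch of the two branches added to comparison_op above (illustrative; make_comparison_iter is a hypothetical helper, not part of the diff). The only difference between the branches is cast_common_dtype_to_outputs: for a non-bool out tensor the kernel writes into a common-dtype temporary that is then cast back into out, while a bool output skips that temporary, which is why the CPU kernels below special-case a bool output dtype:

TensorIterator make_comparison_iter(Tensor& out, const Tensor& a, const Tensor& b) {
  TensorIteratorConfig config;
  config.set_check_mem_overlap(true)
      .add_output(out)
      .add_input(a)
      .add_input(b)
      .allow_cpu_scalars(true)
      .promote_inputs_to_common_dtype(true);
  if (out.scalar_type() != kBool) {
    // Non-bool out: results are produced in the common dtype, then cast into `out`.
    config.cast_common_dtype_to_outputs(true);
  }
  return config.build();
}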
aten/src/ATen/native/cpu/BinaryOpsKernel.cpp (103 changes: 53 additions & 50 deletions)
@@ -234,17 +234,16 @@ void lshift_kernel(TensorIterator& iter) {
}

void logical_and_kernel(TensorIterator& iter) {
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_and_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a && b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_and_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_and_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(a && b);
@@ -254,37 +253,35 @@ void logical_and_kernel(TensorIterator& iter) {
}
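A rough, simplified sketch (an assumption about the macro's general shape, not its exact expansion) of what the AT_DISPATCH_* invocations in these kernels do: switch on the runtime dtype of the iterator and instantiate the lambda with the matching C++ type bound to scalar_t. Dispatching on iter.common_dtype() instead of iter.dtype()/iter.input_dtype() is the part this commit changes:

// Simplified, hand-written analogue of the bool-output branch of logical_and_kernel.
void logical_and_dispatch_sketch(TensorIterator& iter) {
  switch (iter.common_dtype()) {
    case ScalarType::Bool: {
      using scalar_t = bool;
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a && b; });
      break;
    }
    case ScalarType::Float: {
      using scalar_t = float;
      cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a && b; });
      break;
    }
    // ...one case per dtype the real macro covers (all types, complex, Bool,
    // BFloat16, Half)...
    default:
      TORCH_CHECK(false, "logical_and_cpu not implemented for this dtype");
  }
}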

void logical_or_kernel(TensorIterator& iter) {
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_or_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a || b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.dtype(), "logical_or_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_or_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(a || b);
});
});
});
}
}

void logical_xor_kernel(TensorIterator& iter) {
// We use if-else here specifically for bool instead of using iter.common_dtype() like the CUDA implementation because
// common_dtype() is unavailable for bfloat16.
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "logical_xor_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return bool(a) != bool(b);
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "logical_xor_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "logical_xor_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(bool(a) != bool(b));
@@ -311,21 +308,22 @@ void rshift_kernel(TensorIterator& iter) {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return a >> b;
});
});
});
}
}

void lt_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a < b;
});
[](scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -334,20 +332,21 @@ void lt_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.lt(b);
});
});
});
}
}
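A small standalone sketch of the output contract shared by the two lambdas passed to cpu_kernel_vec in the non-bool branch above; the statement about the Vec256 helpers is an assumption inferred from the scalar lambda's behavior, not something shown in this diff:

#include <cassert>

int main() {
  float a = 1.0f, b = 2.0f;
  // Scalar lambda: [](scalar_t a, scalar_t b) -> scalar_t { return a < b; }
  // stores "true" as the value 1 of the output dtype.
  float stored = static_cast<float>(a < b);
  assert(stored == 1.0f);
  // The vectorized lambda returns a.lt(b); the assumption here is that Vec256's
  // named comparison helpers (lt/le/gt/ge/eq/ne) likewise yield 0/1 values per
  // lane, unlike operator<, which yields a raw all-ones comparison mask, so both
  // paths write identical element values into a non-bool output tensor.
  return 0;
}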

void le_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "le_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a <= b;
});
[](scalar_t a, scalar_t b) -> bool {
return a <= b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "le_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -356,20 +355,21 @@ void le_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.le(b);
});
});
});
}
}

void gt_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "gt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
cpu_kernel(iter,
[=](scalar_t a, scalar_t b) -> bool {
return a > b;
});
[](scalar_t a, scalar_t b) -> bool {
return a > b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "gt_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -378,20 +378,21 @@ void gt_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.gt(b);
});
});
});
}
}

void ge_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ge_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a >= b;
});
[](scalar_t a, scalar_t b) -> bool {
return a >= b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "ge_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -400,20 +401,21 @@ void ge_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.ge(b);
});
});
});
}
}

void eq_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a == b;
});
[](scalar_t a, scalar_t b) -> bool {
return a == b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -422,20 +424,21 @@ void eq_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.eq(b);
});
});
});
}
}

void ne_kernel(TensorIterator& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a != b;
});
[](scalar_t a, scalar_t b) -> bool {
return a != b;
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@@ -444,7 +447,7 @@ void ne_kernel(TensorIterator& iter) {
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t> {
return a.ne(b);
});
});
});
}
}

