[PyTorch Edge] Add Quantized Softmax Op (Naive Implementation) (pytorch#75017)

Summary:
Pull Request resolved: pytorch#75017

This version simply dequantizes the input, runs fp32 softmax, and requantizes the result. A second version that computes softmax directly in the quantized domain using QNNPACK will be added next.
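
For illustration, a minimal usage sketch of the new op (the input values, scale, and zero point here are made up for the example; the call matches the `quantized::softmax` schema registered in library.cpp below):

```python
import torch

x = torch.rand(2, 3)
# Quantize the input; scale/zero_point are arbitrary for the example.
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
# Softmax outputs lie in [0, 1], so an output scale of 1/256 covers the full quint8 range.
qy = torch.ops.quantized.softmax(qx, dim=-1, output_scale=1.0 / 256, output_zero_point=0)
```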

Test Plan:
From fbcode:
```buck test caffe2/test:quantization -- test_qsoftmax```

Benchmarking: See summary of D34996486

Reviewed By: kimishpatel

Differential Revision: D34943147

fbshipit-source-id: 426a0780803597a21460139c67960891d6e9cc81
(cherry picked from commit 524eede)
salilsdesai authored and pytorchmergebot committed Mar 31, 2022
1 parent 8b8f3e8 commit 8d7242a
Showing 4 changed files with 63 additions and 0 deletions.
27 changes: 27 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qsoftmax.cpp
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>
#include <torch/library.h>

namespace at {
namespace native {

namespace {

// Naive implementation: dequantize, run fp32 softmax, and requantize the
// result with the requested output scale and zero point.
Tensor qsoftmax(
    const Tensor& qx,
    const int64_t dim,
    const double output_scale,
    const int64_t output_zero_point) {
  Tensor rx = at::dequantize(qx);
  Tensor ry = at::softmax(rx, dim);
  return at::quantize_per_tensor(
      ry, output_scale, output_zero_point, qx.scalar_type());
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::softmax"), TORCH_FN(qsoftmax));
}

} // namespace

} // namespace native
} // namespace at
1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -188,6 +188,7 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::softmax(Tensor qx, int dim, float output_scale, int output_zero_point) -> Tensor"));
}

// According to #33294: The "_" prefix registration will be
34 changes: 34 additions & 0 deletions test/quantization/core/test_quantized_op.py
@@ -1101,6 +1101,40 @@ def test_qmatmul(self, num_dims, outer_dims, m, k, n, dtypes):
scale_C,
zero_point_C)

"""Tests the correctness of the quantized softmax op."""
@given(num_dims=st.integers(2, 4),
dims=st.lists(st.integers(2, 5), min_size=5, max_size=5))
def test_qsoftmax(self, num_dims, dims):
size = dims[:num_dims]
torch_dtype = torch.quint8
np_dtype = np.uint8
dim = num_dims - 1

scale_X = 1.3
zero_point_X = 0
X = torch.rand(size=size, dtype=torch.float32) * 8 + zero_point_X

scale_Y = 1 / 256
zero_point_Y = 0

qX = torch.quantize_per_tensor(X,
scale=scale_X,
zero_point=zero_point_X,
dtype=torch_dtype)


# softmax ground truth
Y = torch.softmax(qX.dequantize(), dim=dim).numpy()
qY = _quantize(Y, scale_Y, zero_point_Y, dtype=np_dtype)
qY_hat = torch.ops.quantized.softmax(qX,
dim=dim,
output_scale=scale_Y,
output_zero_point=zero_point_Y)

np.testing.assert_equal(qY, qY_hat.int_repr(),
"Quantized softmax failed.")


"""Tests the correctness of the mul and mul_relu op."""
def test_qmul_broadcast(self):
mul_relu = torch.ops.quantized.mul_relu
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -1220,6 +1220,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/quantized/cpu/qreduction.cpp",
"aten/src/ATen/native/quantized/cpu/qrelu.cpp",
"aten/src/ATen/native/quantized/cpu/qsigmoid.cpp",
"aten/src/ATen/native/quantized/cpu/qsoftmax.cpp",
"aten/src/ATen/native/quantized/cpu/qsort.cpp",
"aten/src/ATen/native/quantized/cpu/qtanh.cpp",
"aten/src/ATen/native/quantized/cpu/qthreshold.cpp",
