Skip to content

Commit

Permalink
Fix descending sort when minimum values are present
Browse files Browse the repository at this point in the history
Applying a unary minus transformation to the input overflows when the
minimum value of the element type is present. Use a greater-than comparator instead.

Fixes pytorch#865.
  • Loading branch information
asuhan committed Sep 5, 2019
1 parent f7fc05a commit e0eff17
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
13 changes: 13 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,19 @@ TEST_F(AtenXlaTensorTest, TestSort) {
}
}

// Regression test: a descending sort must be correct even when the input
// contains the dtype's minimum value (negating INT8_MIN overflows, so the
// lowering cannot rely on a unary-minus trick — see commit message).
TEST_F(AtenXlaTensorTest, TestSortDescWithMinValue) {
  torch::Tensor input = torch::tensor(std::vector<int8_t>{-128, 100},
                                      torch::TensorOptions(torch::kChar));
  // Reference result computed on CPU by ATen.
  auto cpu_result = torch::sort(input, /*dim=*/0, /*descending=*/true);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(input, device);
    auto device_result =
        torch::sort(device_input, /*dim=*/0, /*descending=*/true);
    // Both the sorted values and the permutation indices must match.
    AllEqual(std::get<0>(cpu_result), std::get<0>(device_result));
    AllEqual(std::get<1>(cpu_result), std::get<1>(device_result));
  });
}

TEST_F(AtenXlaTensorTest, TestArgSort) {
torch::Tensor a = torch::rand({4, 5, 3}, torch::TensorOptions(torch::kFloat));
for (int k = 1; k <= 3; ++k) {
Expand Down
22 changes: 10 additions & 12 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,31 +168,29 @@ std::vector<xla::XlaOp> CreateKthValue(const xla::XlaOp& input, xla::int64 k,
std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
xla::int64 dim, bool largest,
bool /* sorted */) {
auto identity = [](const xla::XlaOp& op) -> xla::XlaOp { return op; };
auto neg = [](const xla::XlaOp& op) -> xla::XlaOp { return xla::Neg(op); };
auto input_transform = largest ? neg : identity;

// Here 'k' is 1 based (1...).
xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
XLA_CHECK_LE(k, shape.dimensions(dim));
xla::Shape iota_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
xla::XlaOp iota = xla::Iota(input.builder(), iota_shape, dim);
xla::XlaOp sort_result = xla::Sort(
{input_transform(input), iota},
xla::CreateScalarLtComputation(
{shape.element_type(), xla::PrimitiveType::S32}, input.builder()),
dim);
xla::XlaComputation comparator =
largest ? xla::CreateScalarGtComputation(
{shape.element_type(), xla::PrimitiveType::S32},
input.builder())
: xla::CreateScalarLtComputation(
{shape.element_type(), xla::PrimitiveType::S32},
input.builder());
xla::XlaOp sort_result = xla::Sort({input, iota}, comparator, dim);

std::vector<xla::int64> start_indices(shape.rank(), 0);
std::vector<xla::int64> limit_indices(shape.dimensions().begin(),
shape.dimensions().end());
limit_indices[dim] = k;
std::vector<xla::int64> strides(shape.rank(), 1);

xla::XlaOp values =
input_transform(xla::Slice(xla::GetTupleElement(sort_result, 0),
start_indices, limit_indices, strides));
xla::XlaOp values = xla::Slice(xla::GetTupleElement(sort_result, 0),
start_indices, limit_indices, strides);
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
start_indices, limit_indices, strides);
// aten::topk() wants Long tensors as indices.
Expand Down

0 comments on commit e0eff17

Please sign in to comment.