[XLA:TPU] Allow user to override reduction input size in TPU ANN.
PiperOrigin-RevId: 407582874
Change-Id: I96e6fd2caafbf7e90c028382d53edab9bbc1c4e2
tensorflower-gardener committed Nov 4, 2021
1 parent 61e4fd4 commit b950ca5
Showing 5 changed files with 26 additions and 6 deletions.
14 changes: 13 additions & 1 deletion tensorflow/compiler/xla/client/lib/approx_topk.cc
@@ -104,7 +104,8 @@ XlaOp SortAndSliceBuilder(XlaBuilder* builder, absl::Span<const XlaOp> operands,
XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values, int64_t top_k,
int64_t reduction_dim, const XlaComputation& comparator,
- float recall_target, bool aggregate_to_topk) {
+ float recall_target, bool aggregate_to_topk,
+ int64_t reduction_input_size_override) {
// Validates shapes and ranks
if (operands.size() != init_values.size()) {
return builder->ReportError(
@@ -170,6 +171,17 @@ XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
}
return Tuple(builder, operands);
}
+ // Only override the input size when we really need to compute the ApproxTopK
+ // through the PartialReduce TPU Op.
+ if (reduction_input_size_override >= 0) {
+   if (n < reduction_input_size_override) {
+     return builder->ReportError(
+         InvalidArgument("reduction_input_size_override should be greater "
+                         "than or equal to operands[reduction_dim], which is %d",
+                         n));
+   }
+   n = reduction_input_size_override;
+ }

auto status_or_approx_output_size = ApproxTopKReductionOutputSize(
n, rank, top_k, recall_target, /*aggregate_to_topk=*/false);
10 changes: 8 additions & 2 deletions tensorflow/compiler/xla/client/lib/approx_topk.h
@@ -41,6 +41,11 @@ namespace xla {
// only keep the final k elements on TPU. This option is useful when the user
// wants to forward the approximate results to the host and aggregate them
// on CPU for better throughput.
+ // reduction_input_size_override: When set to a positive value, it overrides
+ //   the size determined by operands[reduction_dim] for evaluating the recall.
+ //   This option is useful when the given operand is only a subset of the
+ //   overall computation in SPMD or distributed pipelines, where the true input
+ //   size cannot be inferred from the operand shape.
//
// Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1},
// which contains the approximate top-ks from the input operands. When
@@ -53,9 +58,10 @@ namespace xla {
XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values, int64_t top_k,
int64_t reduction_dim, const XlaComputation& comparator,
- float recall_target = 0.9, bool aggregate_to_topk = true);
+ float recall_target = 0.9, bool aggregate_to_topk = true,
+ int64_t reduction_input_size_override = -1);

- // Determine the output size of the reduciton dimension. This is useful for jax
+ // Determine the output size of the reduction dimension. This is useful for jax
// abstract eval to determine the output size.
//
// input_size: Input size of the reduction dimension.
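For orientation, here is a minimal sketch of how a caller might thread the new argument through the C++ builder API, following the SPMD use case described in the header comment above. The helper name ShardedApproxTopK, the rank-2 [batch, db_shard] layout, the choice of comparator, and the true_db_size parameter are illustrative assumptions, not part of this change.

#include "tensorflow/compiler/xla/client/lib/approx_topk.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: approximate top-k over the last dimension of a rank-2
// [batch, db_shard] score matrix. true_db_size is the size of the overall
// reduction dimension (e.g. the unsharded database), which the operand shape
// alone cannot convey; it is forwarded as the new override argument.
xla::XlaOp ShardedApproxTopK(xla::XlaBuilder* builder,
                             const xla::Shape& scores_shape,
                             int64_t top_k, float recall_target,
                             int64_t true_db_size) {
  xla::XlaOp scores = xla::Parameter(builder, 0, scores_shape, "scores");
  // Running indices so the result reports positions as well as values.
  xla::Shape iota_shape =
      xla::ShapeUtil::MakeShape(xla::S32, scores_shape.dimensions());
  xla::XlaOp iota = xla::Iota(builder, iota_shape, /*iota_dimension=*/1);
  // Comparator on (score, index) pairs: the larger score wins.
  xla::XlaComputation comparator =
      xla::CreateScalarGtComputation({xla::F32, xla::S32}, builder);
  return xla::ApproxTopK(
      builder, {scores, iota},
      {xla::MinValue(builder, xla::F32), xla::ConstantR0<int32_t>(builder, -1)},
      top_k, /*reduction_dim=*/1, comparator, recall_target,
      /*aggregate_to_topk=*/true,
      /*reduction_input_size_override=*/true_db_size);
}
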
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/python/ops.cc
@@ -96,7 +96,8 @@ void BuildOpsSubmodule(py::module* m) {
ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
py::arg("comparator"), py::arg("recall_target") = 0.9,
py::arg("aggregate_to_topk") = true);
py::arg("aggregate_to_topk") = true,
py::arg("reduction_input_size_override") = -1);
ops.def("ApproxTopKReductionOutputSize", &ApproxTopKReductionOutputSize,
py::arg("input_size"), py::arg("rank"), py::arg("top_k"),
py::arg("recall_target"), py::arg("aggregate_to_topk") = true);
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/python/xla_client.py
@@ -46,7 +46,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
- _version = 42
+ _version = 43

xla_platform_names = {
'cpu': 'Host',
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/python/xla_extension/ops.pyi
@@ -82,7 +82,8 @@ def ApproxTopK(
reduction_dim: int,
comparator: XlaComputation,
recall_target: Optional[float],
- aggregate_to_topk: Optional[bool]) -> XlaOp: ...
+ aggregate_to_topk: Optional[bool],
+ reduction_input_size_override: Optional[int]) -> XlaOp: ...
def ApproxTopKReductionOutputSize(
input_size: int,
rank: int,
