Skip to content

Commit

Permalink
Add Tensor argmin and argmax + min/max that return indices (flashlight#732)

Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: flashlight#732

See title - mirror numpy's [`argmax`](https://numpy.org/doc/stable/reference/generated/numpy.argmax.html) and [`argmin`](https://numpy.org/doc/stable/reference/generated/numpy.argmin.html) per naming

Also mirrors Torch/adds [`min`](https://pytorch.org/docs/stable/generated/torch.min.html) and [`max`](https://pytorch.org/docs/stable/generated/torch.max.html) funcs that return both values and indices.

Requires an axis argument - can add an overload later - one can use `Tensor::flat()` to perform the op over the entire tensor.

Reviewed By: benoitsteiner

Differential Revision: D30545502

fbshipit-source-id: 415755c46237d81028232ba398b8bd0302764f4a
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Aug 26, 2021
1 parent 91c8c90 commit ee05861
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 20 deletions.
7 changes: 5 additions & 2 deletions flashlight/fl/tensor/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ range::range(idx start, idx end, Dim stride)
start_(std::visit([](Dim idx) -> Dim { return idx; }, start)),
// fl::end --> -1, else idx as Dim
end_(
std::holds_alternative<fl::end_t>(end) ? std::get<fl::end_t>(end)
: std::get<Dim>(end) - 1),
std::holds_alternative<fl::end_t>(end)
? std::get<fl::end_t>(end)
// If start == end, set start_ == end_, else end_ = end - 1
: (std::get<Dim>(end) == start_ ? start_
: std::get<Dim>(end) - 1)),
stride_(stride) {}

Dim range::start() const {
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace fl {
Shape::Shape(std::vector<Dim> d) : dims_(std::move(d)) {}
Shape::Shape(std::initializer_list<Dim> d) : Shape(std::vector<Dim>(d)) {}

size_t Shape::elements() const {
Dim Shape::elements() const {
if (dims_.size() == 0) {
return 0;
}
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Shape {
/**
* @return the number of elements in a tensor that has the given shape.
*/
size_t elements() const;
Dim elements() const;

/**
* @return Number of dimensions in the shape.
Expand Down
14 changes: 14 additions & 0 deletions flashlight/fl/tensor/TensorBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,23 @@ class TensorBackend {
virtual Tensor
amax(const Tensor& input, const std::vector<int>& axes, bool keepDims) = 0;
virtual double amax(const Tensor& input) = 0; // TODO: consolidate w/ above
virtual void min(
Tensor& values,
Tensor& indices,
const Tensor& input,
const unsigned axis,
bool keepDims) = 0;
virtual void max(
Tensor& values,
Tensor& indices,
const Tensor& input,
const unsigned axis,
bool keepDims) = 0;
virtual Tensor
sum(const Tensor& input, const std::vector<int>& axes, bool keepDims) = 0;
virtual double sum(const Tensor& input) = 0; // TODO: consolidate w/ above
virtual Tensor argmax(const Tensor& input, unsigned axis, bool keepDims) = 0;
virtual Tensor argmin(const Tensor& input, unsigned axis, bool keepDims) = 0;
virtual Tensor
mean(const Tensor& input, const std::vector<int>& axes, bool keepDims) = 0;
virtual double mean(const Tensor& input) = 0; // TODO: consolidate w/ above
Expand Down
77 changes: 65 additions & 12 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,32 +629,74 @@ Tensor matmul(

/************************** Reductions ***************************/

Tensor amin(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
Tensor amin(
const Tensor& input,
const std::vector<int>& axes,
bool keepDims /* = false */) {
return input.backend().amin(input, axes, keepDims);
}

Tensor amax(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
Tensor amax(
const Tensor& input,
const std::vector<int>& axes,
bool keepDims /* = false */) {
return input.backend().amax(input, axes, keepDims);
}

Tensor sum(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/**
 * Compute the minimum values and their locations along a single axis,
 * writing results into the provided output tensors. Dispatches to the
 * backend owning `input`.
 *
 * @param[out] values populated with the min values along `axis`
 * @param[out] indices populated with the indices of those values
 * @param[in] input the input tensor
 * @param[in] axis the axis along which to reduce
 * @param[in] keepDims keep the reduced axis as a singleton dimension
 */
void min(
    Tensor& values,
    Tensor& indices,
    const Tensor& input,
    const unsigned axis,
    bool keepDims /* = false */) {
  // All three tensors must share a backend before dispatching.
  FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input);
  // No `return` — this function and the backend hook are both void.
  input.backend().min(values, indices, input, axis, keepDims);
}

/**
 * Backend dispatch for fl::max — writes the per-axis maximum values and
 * their locations into the provided output tensors.
 */
void max(
    Tensor& values,
    Tensor& indices,
    const Tensor& input,
    const unsigned axis,
    bool keepDims /* = false */) {
  FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input);
  auto& backend = input.backend();
  backend.max(values, indices, input, axis, keepDims);
}

/** Sum reduction over the given axes; dispatches to the input's backend. */
Tensor sum(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.sum(input, axes, keepDims);
}

Tensor mean(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Indices of the maximum values along `axis`; dispatches to the backend. */
Tensor argmax(const Tensor& input, unsigned axis, bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.argmax(input, axis, keepDims);
}

/** Indices of the minimum values along `axis`; dispatches to the backend. */
Tensor argmin(const Tensor& input, unsigned axis, bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.argmin(input, axis, keepDims);
}

/** Mean reduction over the given axes; dispatches to the input's backend. */
Tensor mean(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.mean(input, axes, keepDims);
}

Tensor
median(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Median reduction over the given axes; dispatches to the input's backend. */
Tensor median(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.median(input, axes, keepDims);
}

Tensor var(
const Tensor& input,
const std::vector<int>& axes,
const bool bias,
bool keepDims) {
bool keepDims /* = false */) {
return input.backend().var(input, axes, bias, keepDims);
}

Expand All @@ -679,28 +721,39 @@ GENERATE_VAR(short);
GENERATE_VAR(unsigned short);
#undef GENERATE_VAR

Tensor std(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Standard-deviation reduction over the given axes via the backend. */
Tensor std(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.std(input, axes, keepDims);
}

/** Scalar norm of the entire tensor; dispatches to the input's backend. */
double norm(const Tensor& input) {
  auto& backend = input.backend();
  return backend.norm(input);
}

Tensor
countNonzero(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Count nonzero elements along the given axes via the backend. */
Tensor countNonzero(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.countNonzero(input, axes, keepDims);
}

Tensor any(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Axis-wise logical-any reduction; dispatches to the input's backend. */
Tensor any(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.any(input, axes, keepDims);
}

/** True if any element of the whole tensor is truthy; backend dispatch. */
bool any(const Tensor& input) {
  auto& backend = input.backend();
  return backend.any(input);
}

Tensor all(const Tensor& input, const std::vector<int>& axes, bool keepDims) {
/** Axis-wise logical-all reduction; dispatches to the input's backend. */
Tensor all(
    const Tensor& input,
    const std::vector<int>& axes,
    bool keepDims /* = false */) {
  auto& backend = input.backend();
  return backend.all(input, axes, keepDims);
}

Expand Down
87 changes: 83 additions & 4 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,8 @@ Tensor matmul(
*
* @param[in] input the input along which to operate
* @param[in] dim the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the minimum values
*/
Tensor
Expand All @@ -1014,8 +1016,9 @@ T amin(const Tensor& input);
*
* @param[in] input the input along which to operate
* @param[in] dim the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the max
*
*/
Tensor
amax(const Tensor& input, const std::vector<int>& axes, bool keepDims = false);
Expand All @@ -1030,11 +1033,75 @@ amax(const Tensor& input, const std::vector<int>& axes, bool keepDims = false);
template <typename T>
T amax(const Tensor& input);

/**
 * Compute the minimum value along a single axis of a tensor, returning both
 * the minimum values and the indices of the input tensor in which they appear.
 *
 * @param[out] values a Tensor into which to populate the min values from the
 * tensor along the specified axis
 * @param[out] indices a Tensor into which to populate the indices of the min
 * values from the tensor along the specified axis
 * @param[in] input the input tensor
 * @param[in] axis the axis along which to find minimum values
 * @param[in] keepDims defaults false. Keeps the dimensions being reduced over
 * as singleton dimensions rather than collapsing them
 */
void min(
    Tensor& values,
    Tensor& indices,
    const Tensor& input,
    const unsigned axis,
    bool keepDims = false);

/**
 * Compute the maximum value along a single axis of a tensor, returning both
 * the maximum values and the indices of the input tensor in which they appear.
 *
 * @param[out] values a Tensor into which to populate the max values from the
 * tensor along the specified axis
 * @param[out] indices a Tensor into which to populate the indices of the max
 * values from the tensor along the specified axis
 * @param[in] input the input tensor
 * @param[in] axis the axis along which to find maximum values
 * @param[in] keepDims defaults false. Keeps the dimensions being reduced over
 * as singleton dimensions rather than collapsing them
 */
void max(
    Tensor& values,
    Tensor& indices,
    const Tensor& input,
    const unsigned axis,
    bool keepDims = false);

/**
 * Return the indices of the maximum values along an axis.
 *
 * @param[in] input the input tensor
 * @param[in] axis the axis along which to find maximum values
 * @param[in] keepDims defaults false. Keeps the dimensions being reduced over
 * as singleton dimensions rather than collapsing them
 * @return a tensor containing the indices of the max values along the given
 * axis
 */
Tensor argmax(const Tensor& input, unsigned axis, bool keepDims = false);

/**
 * Return the indices of the minimum values along an axis.
 *
 * @param[in] input the input tensor
 * @param[in] axis the axis along which to find minimum values
 * @param[in] keepDims defaults false. Keeps the dimensions being reduced over
 * as singleton dimensions rather than collapsing them
 * @return a tensor containing the indices of the minimum values along the
 * given axis
 */
Tensor argmin(const Tensor& input, unsigned axis, bool keepDims = false);

/**
* Sum of tensor over given axes.
*
* @param[in] input the input along which to operate
* @param[in] axes the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the sum across given axes
*/
Tensor
Expand All @@ -1055,6 +1122,8 @@ T sum(const Tensor& input);
*
* @param[in] input the input along which to operate
* @param[in] axes the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the mean across given axes
*/
Tensor
Expand All @@ -1075,6 +1144,8 @@ T mean(const Tensor& input);
*
* @param[in] input the input along which to operate
* @param[in] axes the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the median across given axes
*/
Tensor median(
Expand All @@ -1098,6 +1169,9 @@ T median(const Tensor& input);
*
* @param[in] input the input along which to operate
* @param[in] axes the dimension along which to reduce.
* @param[in] bias defaults false. Compute biased or unbiased variance
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the var across given axes
*/
Tensor var(
Expand All @@ -1121,6 +1195,8 @@ T var(const Tensor& input, const bool bias = false);
*
* @param[in] input the input along which to operate
* @param[in] axes the dimension along which to reduce.
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the var across given axes
*/
Tensor
Expand All @@ -1143,7 +1219,8 @@ double norm(const Tensor& input);
*
* @param[in] input the tensor on which to operate.
* @param[in] dims (optional) the axis along which to give nonzeros.
*
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a tensor containing the number of nonzero elements along each axis or
* over the entire tensor.
*/
Expand All @@ -1159,7 +1236,8 @@ Tensor countNonzero(
*
* @param[in] input the input tensor
* @param[in] axes the axes along which to check for truthy values
*
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a bool tensor containing axis-wise values denoting truthy values
* along that axis in the input tensor.
*/
Expand All @@ -1184,7 +1262,8 @@ bool any(const Tensor& input);
*
* @param[in] input the input tensor
* @param[in] axes the axes along which to
*
* @param[in] keepDims defaults false. Keeps the dimensions being reduced over
* as singleton dimensions rather than collapsing them
* @return a bool tensor containing axis-wise values with true along
* axes that contain only true values.
*/
Expand Down
Loading

0 comments on commit ee05861

Please sign in to comment.