Skip to content

Commit

Permalink
Updates collapseDims() function and documentation (#7056)
Browse files Browse the repository at this point in the history
* Updates collapseDims() function and documentation

* Adds C++ tests, validates input, updates names for readability

* Removes invalid test

* stashing to merge AT_CHECK macro

* Updates asserts, removes tests on Windows
  • Loading branch information
mruberry authored and ezyang committed May 13, 2018
1 parent cfc1d92 commit 37b9d09
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 280 deletions.
204 changes: 63 additions & 141 deletions aten/src/ATen/cuda/detail/TensorInfo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@ struct TensorInfo {
// slice)
void reduceDim(int dim);

// Collapses all runs of successive dimensions if the size/strides
// match up within the run and there are no holes between the
// dimensions.
// If excludeDim is set (not -1), then excludeDim will not be
// collapsed with any other dimension.
// Function returns the new dimension index that excludeDim maps to,
// since the collapsed dimensions are <= the input dimensions.
int collapseDims(int excludeDim = -1);
/*
Updates the TensorInfo's dims, sizes, and strides to reflect a "collapse" of
the info, possibly excluding the optional excludeDim. A "collapsed" version
of the info is the fewest dims that order the tensor's elements in the same
way as the original info. If excludeDim is specified, the collapse is the
fewest dims that order the tensor's elements as the original and preserve the
excluded dimension, unless the tensor collapses to a point.
Returns the (new) index of the preserved dimension if excludeDim is
specified. Returns 0 if the tensor is collapsed to a point. Returns -1
otherwise.
*/
int collapseDims(const int excludeDim = -1);

// Contiguous tensors of more than one dimension are collapsed down
// to one tensor
Expand All @@ -49,7 +54,7 @@ TensorInfo<T, IndexType>::TensorInfo(T* p,
IndexType st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
assert(dims < MAX_TENSORINFO_DIMS);
AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
Expand All @@ -60,163 +65,80 @@ TensorInfo<T, IndexType>::TensorInfo(T* p,
template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
assert(dim < dims && dim >= 0);
AT_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
sizes[dim] = 1;
}

template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(int excludeDim) {
// Find the innermost dimension not of size 1, since dimensions of size 1 are
// collapsible.
int firstNonOneDim = -1;
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {

for (int i = dims - 1; i >= 0; --i) {
if (i == excludeDim) {
// We cannot collapse this dimension, even if it is size 1
firstNonOneDim = i;
break;
}

if (sizes[i] != 1) {
firstNonOneDim = i;
break;
}
}

// Special case: if all dimensions are of size 1, then this is a
// single-point tensor that we still have to operate on. Reduce to a
// single point.
if (firstNonOneDim == -1) {
assert(excludeDim == -1);

dims = 1;
sizes[0] = 1;
strides[0] = 1;

// Everything effectively got collapsed into this dimension
return 0;
}

// Count the number of successive dimensions that can be collapsed, from
// innermost to outermost.
int numCollapsed = 0;

// Skip the leading size 1 dims
numCollapsed += dims - 1 - firstNonOneDim;

// We perform one pass through to determine how many dimensions we
// can collapse, before calculating the actual size of the collapsed
// dimensions.
// size/strideInner are the size/strides of the previous inner
// non-collapsible dim we encounter.
int64_t sizeInner = sizes[firstNonOneDim];
int64_t strideInner = strides[firstNonOneDim];

for (int i = firstNonOneDim - 1; i >= 0; --i) {
int64_t sizeOuter = sizes[i];
int64_t strideOuter = strides[i];

// Don't collapse this dimension if we want to exclude it from
// collapsing.
// Since this code is attempting to collapse a subsequent
// dimension (i) with the preceding dimension (i + 1), we can only
// perform collapsing if the preceding dimension can be collapsed
// (i.e., not excludeDim)
if ((excludeDim != i) && (excludeDim != i + 1)) {
// The next outermost dimension can be skipped if size 1
if (sizeOuter == 1) {
++numCollapsed;
continue;
}
AT_CHECK(excludeDim >= -1 && excludeDim < dims,
"expected excluded dim between -1 and dims - 1");

// If the next outermost dimension is contiguous with the
// previous non-collapsed one, collapse it
if (strideOuter == strideInner * sizeInner) {
++numCollapsed;
int stopDim = (excludeDim == -1) ? dims : excludeDim;
int newIndex = -1;
int oldIndex = 0;
int remappedExcludedDim = -1;

// This is the run of collapsed dimensions' size
sizeInner = sizeInner * sizeOuter;
while (oldIndex < dims) {
// Finds a dimension to collapse into
for (; oldIndex < stopDim; ++oldIndex) {
if (sizes[oldIndex] == 1) {
continue;
}

++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
++oldIndex;
break;
}

// Otherwise, this new outer dimension at `i` cannot be collapsed
// because it is excluded from collapsing, or it is not contiguous
// with the previous inner dimension.
sizeInner = sizeOuter;
strideInner = strideOuter;
}

// This will be our new size/stride and dimension.
IndexType newSizes[MAX_TENSORINFO_DIMS];
IndexType newStrides[MAX_TENSORINFO_DIMS];

assert(numCollapsed < dims);
int newDims = dims - numCollapsed;

// We return the index of the excluded dimension that is excluded
// from being collapsed here.
int returnDim = -1;

// We perform a second pass through the dimensions to actually
// calculate the size of the collapsed dimensions.
int collapsedIndex = dims - numCollapsed - 1;
newSizes[collapsedIndex] = sizes[firstNonOneDim];
newStrides[collapsedIndex] = strides[firstNonOneDim];

if (firstNonOneDim == excludeDim) {
returnDim = collapsedIndex;
}

for (int i = firstNonOneDim - 1; i >= 0; --i) {
IndexType sizeOuter = sizes[i];
IndexType strideOuter = strides[i];

if ((excludeDim != i) && (excludeDim != i + 1)) {
if (sizeOuter == 1) {
// skip
// Collapses dims
for (; oldIndex < stopDim; ++oldIndex) {
if (sizes[oldIndex] == 1) {
continue;
}

if (strideOuter == newSizes[collapsedIndex] * newStrides[collapsedIndex]) {
// collapse
newSizes[collapsedIndex] *= sizeOuter;
continue;

if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
sizes[newIndex] *= sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
} else {
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
}
}

// Otherwise, strides don't match, or dim `i` is excluded from
// collapsing.
--collapsedIndex;
assert(collapsedIndex >= 0);
assert(collapsedIndex < newDims);
newSizes[collapsedIndex] = sizeOuter;
newStrides[collapsedIndex] = strideOuter;

if (excludeDim == i) {
returnDim = collapsedIndex;
// Handles excludeDim being set (oldIndex == excludeDim)
if (oldIndex != dims) {

// Preserves excluded dimension
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
remappedExcludedDim = newIndex;

// Restarts iteration after excludeDim
++oldIndex;
stopDim = dims;
}
}

// We must have filled all the dimensions we're looking for
assert(collapsedIndex == 0);
assert((excludeDim == -1) || (returnDim != -1));

dims = newDims;
// Handles special case of all dims size 1
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
dims = 1;
sizes[0] = 1;
strides[0] = 1;

for (int i = 0; i < dims; ++i) {
sizes[i] = newSizes[i];
strides[i] = newStrides[i];
return 0;
}

// After collapsing, the original `excludeDim` may have been
// renumbered to this new `returnDim`, since some dimensions could
// have been collapsed.
return returnDim;
dims = newIndex + 1;
return remappedExcludedDim;
}


// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ if(NOT NO_CUDA)
target_link_libraries(cuda_rng_test ATen_cpu ATen_cuda_library)
endif()

if(NOT NO_CUDA)
cuda_add_executable(apply_test apply_test.cpp)
target_link_libraries(apply_test ATen)
endif()

if (CUDNN_FOUND)
add_executable(cudnn_test cudnn_test.cpp)
target_link_libraries(cudnn_test ATen_cpu ATen_cuda_library)
Expand Down
121 changes: 121 additions & 0 deletions aten/src/ATen/test/apply_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

#include "cuda.h"
#include "cuda_runtime.h"

#include "ATen/cuda/detail/TensorInfo.cuh"

/*
Tests related to tensor indexing and applying operations.
*/
#ifndef _WIN32

TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
int sizes[] = {4, 4};
int strides[] = {4, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (4 * 4));
}

TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
int sizes[] = {6, 3, 7};
int strides[] = {3 * 7, 7, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (6 * 3 * 7));
}

TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
int sizes[] = {4, 3, 2};
int strides[] = {3 * 3, 3, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (4 * 3));
REQUIRE(ti.sizes[1] == 2);
}

TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
int sizes[] = {3, 2};
int strides[] = {2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (3 * 2));
REQUIRE(ti.strides[0] == 2);
}

TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (3 * 6));
REQUIRE(ti.strides[0] == 22);
REQUIRE(ti.sizes[1] == (5 * 2));
REQUIRE(ti.strides[1] == 2);
}

TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
int sizes[] = {1, 10, 1, 5, 4};
int strides[] = {4, 0, 16, 0, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 5, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (10 * 5));
REQUIRE(ti.strides[0] == 0);
REQUIRE(ti.sizes[1] == 4);
REQUIRE(ti.strides[1] == 1);
}

TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
REQUIRE(ti.collapseDims() == 0);
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == 1);
REQUIRE(ti.strides[0] == 1);
}

TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
REQUIRE(ti.collapseDims(1) == 1);
REQUIRE(ti.dims == 3);
REQUIRE(ti.sizes[0] == 3);
REQUIRE(ti.strides[0] == (6 * 22));
REQUIRE(ti.sizes[1] == 6);
REQUIRE(ti.strides[1] == 22);
REQUIRE(ti.sizes[2] == (5 * 2));
REQUIRE(ti.strides[2] == 2);
}

TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
REQUIRE(ti.collapseDims(2) == 1);
REQUIRE(ti.dims == 3);
REQUIRE(ti.sizes[0] == (3 * 6));
REQUIRE(ti.strides[0] == 22);
REQUIRE(ti.sizes[1] == 5);
REQUIRE(ti.strides[1] == 4);
REQUIRE(ti.sizes[2] == 2);
REQUIRE(ti.strides[2] == 2);
}

TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
REQUIRE_THROWS(ti.collapseDims(5));
}

#endif
Loading

0 comments on commit 37b9d09

Please sign in to comment.