forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaccumulate.h
134 lines (117 loc) · 4.11 KB
/
accumulate.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include <c10/util/ArrayRef.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
namespace c10 {
/// Adds up every element of a container of integers.
///
/// The running total is kept in an int64_t whatever the container's integral
/// value_type is, so the result is not limited to the element type's range
/// (e.g. summing a container of int past INT_MAX is fine).
///
/// @param container  object providing begin()/end() and an integral
///                   value_type
/// @return the sum as int64_t; 0 for an empty container
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(const C& container) {
  // std::accumulate's accumulator has the type of its init argument, which is
  // why the starting zero is spelled as an int64_t rather than a literal 0.
  const int64_t zero = 0;
  return std::accumulate(container.begin(), container.end(), zero);
}
/// Adds up the integers in the iterator range [begin, end).
///
/// The total is accumulated in int64_t rather than in the iterator's
/// value_type, so narrow element types do not bound the result.
///
/// @param begin  iterator to the first element
/// @param end    iterator one past the last element
/// @return the sum as int64_t; 0 for an empty range
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t sum_integers(Iter begin, Iter end) {
  // Seeding with int64_t{0} (not plain 0) makes std::accumulate total in
  // int64_t instead of int.
  return std::accumulate(begin, end, int64_t{0});
}
/// Multiplies together every element of a container of integers.
///
/// The running product is kept in an int64_t regardless of the container's
/// integral value_type. Overflow beyond int64_t is not detected.
///
/// @param container  object providing begin()/end() and an integral
///                   value_type
/// @return the product as int64_t; 1 for an empty container
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(const C& container) {
  // The accumulator type follows std::accumulate's init argument; start from
  // an int64_t one so every multiplication happens in int64_t.
  const int64_t one = 1;
  const auto first = container.begin();
  const auto last = container.end();
  return std::accumulate(first, last, one, std::multiplies<>{});
}
/// Multiplies together the integers in the iterator range [begin, end).
///
/// The product is accumulated in int64_t rather than in the iterator's
/// value_type. Overflow beyond int64_t is not detected.
///
/// @param begin  iterator to the first element
/// @param end    iterator one past the last element
/// @return the product as int64_t; 1 for an empty range
template <
    typename Iter,
    typename std::enable_if<
        std::is_integral<
            typename std::iterator_traits<Iter>::value_type>::value,
        int>::type = 0>
inline int64_t multiply_integers(Iter begin, Iter end) {
  // Naming the accumulator type makes explicit that std::accumulate's init
  // argument (an int64_t one) decides the type the product is computed in.
  using Product = int64_t;
  return std::accumulate(begin, end, Product{1}, std::multiplies<>{});
}
/// Return product of all dimensions starting from k, i.e.
/// dims[k] * dims[k + 1] * ... * dims[dims.size() - 1].
/// Returns 1 if k >= dims.size() (the product over an empty range).
///
/// @param k     index of the first dimension included in the product; must
///              be >= 0 (enforced only by TORCH_INTERNAL_ASSERT_DEBUG_ONLY)
/// @param dims  container of integral dimension sizes
/// @return the product accumulated as int64_t
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t numelements_from_dim(const int k, const C& dims) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
  // Compare in the unsigned size domain instead of narrowing dims.size() to
  // int. If the DEBUG_ONLY assertion above is compiled out, a negative k
  // converts to a very large value and takes this early return instead of
  // moving the iterator before the start of dims.
  if (static_cast<std::size_t>(k) >= dims.size()) {
    return 1;
  }
  return multiply_integers(std::next(dims.cbegin(), k), dims.cend());
}
/// Product of the leading dimensions dims[0] * ... * dims[k - 1]
/// (dims[k] itself is not included); yields 1 when k == 0.
/// Throws an error (via TORCH_INTERNAL_ASSERT) if k < 0 or k > dims.size().
///
/// @param k     exclusive end index of the leading dimensions to multiply
/// @param dims  container of integral dimension sizes
/// @return the product accumulated as int64_t
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t numelements_to_dim(const int k, const C& dims) {
  TORCH_INTERNAL_ASSERT(0 <= k);
  TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
  const auto leading_begin = dims.cbegin();
  const auto leading_end = std::next(leading_begin, k);
  return multiply_integers(leading_begin, leading_end);
}
/// Product of all dims between k and l, including dims[k] and excluding
/// dims[l], i.e. over the half-open range [min(k, l), max(k, l)).
/// k and l may be supplied in either order; returns 1 when k == l.
/// Throws an error (via TORCH_INTERNAL_ASSERT) if either index is negative
/// or the larger index exceeds dims.size().
///
/// @param k     one end of the range (inclusive if it is the smaller index)
/// @param l     other end of the range (exclusive if it is the larger index)
/// @param dims  container of integral dimension sizes
/// @return the product accumulated as int64_t
template <
    typename C,
    typename std::enable_if<
        std::is_integral<typename C::value_type>::value,
        int>::type = 0>
inline int64_t numelements_between_dim(int k, int l, const C& dims) {
  TORCH_INTERNAL_ASSERT(0 <= k);
  TORCH_INTERNAL_ASSERT(0 <= l);
  if (k > l) {
    std::swap(k, l);
  }
  // l is an exclusive end bound, so l == dims.size() is a valid end position
  // (same convention as numelements_to_dim's exclusive bound).
  TORCH_INTERNAL_ASSERT(static_cast<std::size_t>(l) <= dims.size());
  const auto first = std::next(dims.cbegin(), k);
  const auto last = std::next(dims.cbegin(), l);
  return multiply_integers(first, last);
}
} // namespace c10