forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestUtils.cc
165 lines (151 loc) · 4.66 KB
/
TestUtils.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "./TestUtils.h"
#include <gtest/gtest.h>
#include "fbgemm/Fbgemm.h"
namespace fbgemm {
// Validates an m x n block of `test` against `ref`, recording every
// differing element as a non-fatal gtest expectation failure. Element (i, j)
// of each buffer is read at offset i * ld + j, so `ld` is the row stride.
//   - Integral T: values must be identical; `atol` is not consulted.
//   - Otherwise: |ref - test| must not exceed `atol`.
// The return value is always 0; mismatches surface only through gtest.
template <typename T>
int compare_validate_buffers(
    const T* ref,
    const T* test,
    int m,
    int n,
    int ld,
    T atol) {
  constexpr bool exact_match_required = std::is_integral<T>::value;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      if (!exact_match_required) {
        // Floating-point path: tolerate an absolute error up to `atol`.
        EXPECT_LE(std::abs(ref[i * ld + j] - test[i * ld + j]), atol)
            << "GEMM results differ at (" << i << ", " << j
            << ") reference: " << ref[i * ld + j]
            << ", FBGEMM: " << test[i * ld + j];
        continue;
      }
      // Integral path: exact equality. Values are widened to int64_t so that
      // 8-bit types print as numbers rather than characters.
      EXPECT_EQ(test[i * ld + j], ref[i * ld + j])
          << "GEMM results differ at (" << i << ", " << j
          << ") reference: " << static_cast<int64_t>(ref[i * ld + j])
          << ", FBGEMM: " << static_cast<int64_t>(test[i * ld + j]);
    }
  }
  return 0;
}
// Explicit instantiations for the element types this translation unit
// provides.
template int compare_validate_buffers<float>(
    const float*, const float*, int, int, int, float);
template int compare_validate_buffers<int32_t>(
    const int32_t*, const int32_t*, int, int, int, int32_t);
template int compare_validate_buffers<uint8_t>(
    const uint8_t*, const uint8_t*, int, int, int, uint8_t);
template int compare_validate_buffers<int64_t>(
    const int64_t*, const int64_t*, int, int, int, int64_t);
// Scans an m x n row-major buffer (row stride n).
// NOTE: despite the name, the result is "has a non-zero entry": it returns
// true as soon as any entry differs from 0, and false only when every entry
// in the m x n range is zero (or the range is empty).
template <typename T>
bool check_all_zero_entries(const T* test, int m, int n) {
  for (int row = 0; row < m; ++row) {
    const int row_offset = row * n;
    for (int col = 0; col < n; ++col) {
      if (test[row_offset + col] != 0) {
        return true;
      }
    }
  }
  return false;
}
// Explicit instantiations for the element types this translation unit
// provides.
template bool check_all_zero_entries<float>(const float*, int, int);
template bool check_all_zero_entries<int32_t>(const int32_t*, int, int);
template bool check_all_zero_entries<uint8_t>(const uint8_t*, int, int);
// Checks that the test vector `b` is element-wise close to the reference
// vector `a`.
// atol: absolute tolerance. <=0 means do not consider atol.
// rtol: relative tolerance, measured against |a[i]| (the reference value).
//       <=0 means do not consider rtol.
// When both tolerances are active, an element passes if it satisfies either
// one. The relative check is skipped for an element whose reference value is
// within float epsilon of zero (the ratio would be meaningless). If no
// tolerance applies to an element, that element is accepted.
// Returns an AssertionFailure describing a size mismatch or the first
// mismatching element; AssertionSuccess otherwise.
template <>
::testing::AssertionResult floatCloseAll<float, float>(
    const std::vector<float>& a,
    const std::vector<float>& b,
    const float atol,
    const float rtol) {
  if (a.size() != b.size()) {
    return ::testing::AssertionFailure()
        << " results do not match. " << " size mismatch ";
  }
  const bool consider_absDiff = atol > 0;
  for (size_t i = 0; i < a.size(); i++) {
    // relDiff divides by |a[i]| only, so only the reference has to be
    // non-negligible. Also requiring |b[i]| > epsilon would skip the relative
    // check exactly when the test value collapsed to ~0 (e.g. ref 1.0 vs
    // test 0.0), letting a gross mismatch pass when only rtol is given.
    const bool consider_relDiff =
        rtol > 0 && std::fabs(a[i]) > std::numeric_limits<float>::epsilon();
    const float absDiff = std::fabs(a[i] - b[i]);
    // Only divide when the ratio is actually used (avoids dividing by ~0).
    const float relDiff = consider_relDiff ? absDiff / std::fabs(a[i]) : 0.0f;
    // Each element is judged on its own; it is accepted when no tolerance
    // applies to it.
    bool match = true;
    if (consider_absDiff && consider_relDiff) {
      match = absDiff <= atol || relDiff <= rtol;
    } else if (consider_absDiff) {
      match = absDiff <= atol;
    } else if (consider_relDiff) {
      match = relDiff <= rtol;
    }
    if (!match) {
      std::stringstream ss;
      ss << " mismatch at (" << i << ") " << '\n';
      ss << "\t ref: " << a[i] << " test: " << b[i] << '\n';
      if (consider_absDiff) {
        ss << "\t absolute diff: " << absDiff << " > " << atol << '\n';
      }
      if (consider_relDiff) {
        ss << "\t relative diff: " << relDiff << " > " << rtol << '\n';
      }
      return ::testing::AssertionFailure()
          << " results do not match. " << ss.str();
    }
  }
  return ::testing::AssertionSuccess();
}
// Mixed-precision overload: the half-precision test values in `b` are widened
// to float, and the comparison is delegated to the float/float
// specialization with the same tolerances.
template <>
::testing::AssertionResult floatCloseAll<float, float16>(
    const std::vector<float>& a,
    const std::vector<float16>& b,
    const float atol,
    const float rtol) {
  std::vector<float> widened;
  widened.reserve(b.size());
  for (const float16 half : b) {
    widened.push_back(cpu_half2float(half));
  }
  return floatCloseAll<float, float>(a, widened, atol, rtol);
}
// Half/half overload: both inputs are widened to float, and the comparison is
// delegated to the float/float specialization with the same tolerances.
template <>
::testing::AssertionResult floatCloseAll<float16, float16>(
    const std::vector<float16>& a,
    const std::vector<float16>& b,
    const float atol,
    const float rtol) {
  const auto widen = [](const std::vector<float16>& halves) {
    std::vector<float> out;
    out.reserve(halves.size());
    for (const float16 h : halves) {
      out.push_back(cpu_half2float(h));
    }
    return out;
  };
  return floatCloseAll<float, float>(widen(a), widen(b), atol, rtol);
}
} // namespace fbgemm