Skip to content

Commit 174303c

Browse files
AzeezIshAzeezIsh
and
AzeezIsh
authored
Inclusive Testing (#48)
* Fixed several errors with calls and tests. The inclusive_scan_operations.py file was not calling the correct function, had to add that in with the correct c pointer calls. Dove into the c++ files for the dim assertion error, realized that was just an issue with the c++ file, fixed that as well. The test_inclusive.py file didn't have F16 coverage, checked for that as well. * Adhered to checkstyle requirements. --------- Co-authored-by: AzeezIsh <[email protected]>
1 parent 9b3fb65 commit 174303c

File tree

2 files changed

+229
-1
lines changed

2 files changed

+229
-1
lines changed

arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def scan_by_key(key: AFArray, arr: AFArray, dim: int, op: BinaryOperator, inclus
2828
source: https://arrayfire.org/docs/group__scan__func__scanbykey.htm#gaaae150e0f197782782f45340d137b027
2929
"""
3030
out = AFArray.create_null_pointer()
31-
call_from_clib(scan.__name__, ctypes.pointer(out), key, arr, dim, op.value, inclusive_scan)
31+
call_from_clib(scan_by_key.__name__, ctypes.pointer(out), key, arr, dim, op.value, inclusive_scan)
3232
return out
3333

3434

tests/test_inclusive.py

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import pytest
2+
3+
import arrayfire_wrapper.dtypes as dtype
4+
import arrayfire_wrapper.lib as wrapper
5+
from tests.utility_functions import check_type_supported, get_all_types, get_real_types
6+
7+
8+
@pytest.mark.parametrize(
9+
"shape",
10+
[
11+
(),
12+
(3,),
13+
(3, 3),
14+
(3, 3, 3),
15+
(3, 3, 3, 3),
16+
],
17+
)
18+
@pytest.mark.parametrize("dtype_name", get_all_types())
19+
def test_accum_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
20+
"""Test accumulate operation across all supported data types."""
21+
check_type_supported(dtype_name)
22+
if dtype_name == dtype.f16:
23+
pytest.skip()
24+
values = wrapper.randu(shape, dtype_name)
25+
result = wrapper.accum(values, 0)
26+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
27+
28+
29+
@pytest.mark.parametrize(
30+
"dim",
31+
[
32+
0,
33+
1,
34+
2,
35+
3,
36+
],
37+
)
38+
def test_accum_dims(dim: int) -> None:
39+
"""Test accumulate dimensions operation"""
40+
shape = (3, 3)
41+
values = wrapper.randu(shape, dtype.f32)
42+
result = wrapper.accum(values, dim)
43+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
44+
45+
46+
@pytest.mark.parametrize(
47+
"invdim",
48+
[
49+
-1,
50+
5,
51+
],
52+
)
53+
def test_accum_invdims(invdim: int) -> None:
54+
"""Test accumulate invalid dimensions operation"""
55+
with pytest.raises(RuntimeError):
56+
shape = (3, 3)
57+
values = wrapper.randu(shape, dtype.f32)
58+
result = wrapper.accum(values, invdim)
59+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
60+
61+
62+
@pytest.mark.parametrize(
63+
"shape",
64+
[
65+
(),
66+
(3,),
67+
(3, 3),
68+
(3, 3, 3),
69+
(3, 3, 3, 3),
70+
],
71+
)
72+
@pytest.mark.parametrize("dtype_name", get_all_types())
73+
def test_scan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
74+
"""Test scan operation across all supported data types."""
75+
check_type_supported(dtype_name)
76+
if dtype_name == dtype.f16:
77+
pytest.skip()
78+
values = wrapper.randu(shape, dtype_name)
79+
result = wrapper.scan(values, 0, wrapper.BinaryOperator.ADD, True)
80+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}, dtype {dtype_name}" # noqa
81+
82+
83+
@pytest.mark.parametrize(
84+
"dim",
85+
[
86+
0,
87+
1,
88+
2,
89+
3,
90+
],
91+
)
92+
def test_scan_dims(dim: int) -> None:
93+
"""Test scan dimensions operation"""
94+
shape = (3, 3)
95+
values = wrapper.randu(shape, dtype.f32)
96+
result = wrapper.scan(values, dim, wrapper.BinaryOperator.ADD, True)
97+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for dimension: {dim}" # noqa
98+
99+
100+
@pytest.mark.parametrize(
101+
"invdim",
102+
[
103+
-1,
104+
5,
105+
],
106+
)
107+
def test_scan_invdims(invdim: int) -> None:
108+
"""Test scan invalid dimensions operation"""
109+
with pytest.raises(RuntimeError):
110+
shape = (3, 3)
111+
values = wrapper.randu(shape, dtype.f32)
112+
result = wrapper.scan(values, invdim, wrapper.BinaryOperator.ADD, True)
113+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
114+
115+
116+
@pytest.mark.parametrize(
117+
"binaryOp",
118+
[
119+
0,
120+
1,
121+
2,
122+
3,
123+
],
124+
)
125+
def test_scan_binaryOp(binaryOp: int) -> None:
126+
"""Test scan dimensions operation"""
127+
shape = (3, 3)
128+
values = wrapper.randu(shape, dtype.f32)
129+
result = wrapper.scan(values, 0, wrapper.BinaryOperator(binaryOp), True)
130+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for operation: {binaryOp}" # noqa
131+
132+
133+
@pytest.mark.parametrize(
134+
"shape",
135+
[
136+
(),
137+
(3,),
138+
(3, 3),
139+
(3, 3, 3),
140+
(3, 3, 3, 3),
141+
],
142+
)
143+
@pytest.mark.parametrize("dtype_name", get_real_types())
144+
def test_scan_by_key_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
145+
"""Test scan_by_key operation across all supported data types."""
146+
check_type_supported(dtype_name)
147+
if (
148+
dtype_name == dtype.f16
149+
or dtype_name == dtype.f32
150+
or dtype_name == dtype.uint16
151+
or dtype_name == dtype.uint8
152+
or dtype_name == dtype.int16
153+
):
154+
pytest.skip()
155+
values = wrapper.randu(shape, dtype_name)
156+
result = wrapper.scan_by_key(values, values, 0, wrapper.BinaryOperator.ADD, True)
157+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}, dtype {dtype_name}" # noqa
158+
159+
160+
@pytest.mark.parametrize(
161+
"dim",
162+
[
163+
0,
164+
1,
165+
2,
166+
3,
167+
],
168+
)
169+
def test_scan_by_key_dims(dim: int) -> None:
170+
"""Test scan_by_key dimensions operation"""
171+
shape = (3, 3)
172+
values = wrapper.randu(shape, dtype.int32)
173+
result = wrapper.scan_by_key(values, values, dim, wrapper.BinaryOperator.ADD, True)
174+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for dimension: {dim}" # noqa
175+
176+
177+
@pytest.mark.parametrize(
178+
"invdim",
179+
[
180+
-1,
181+
5,
182+
],
183+
)
184+
def test_scan_by_key_invdims(invdim: int) -> None:
185+
"""Test scan_by_key invalid dimensions operation"""
186+
with pytest.raises(RuntimeError):
187+
shape = (3, 3)
188+
values = wrapper.randu(shape, dtype.int32)
189+
result = wrapper.scan_by_key(values, values, invdim, wrapper.BinaryOperator.ADD, True)
190+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
191+
192+
193+
@pytest.mark.parametrize(
194+
"binaryOp",
195+
[
196+
0,
197+
1,
198+
2,
199+
3,
200+
],
201+
)
202+
def test_scan_by_key_binaryOp(binaryOp: int) -> None:
203+
"""Test scan_by_key dimensions operation"""
204+
shape = (3, 3)
205+
values = wrapper.randu(shape, dtype.int32)
206+
result = wrapper.scan_by_key(values, values, 0, wrapper.BinaryOperator(binaryOp), True)
207+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for operation: {binaryOp}" # noqa
208+
209+
210+
@pytest.mark.parametrize(
211+
"shape",
212+
[
213+
(),
214+
(3,),
215+
(3, 3),
216+
(3, 3, 3),
217+
(3, 3, 3, 3),
218+
],
219+
)
220+
@pytest.mark.parametrize("dtype_name", get_real_types())
221+
def test_where_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
222+
"""Test where operation across all supported data types."""
223+
check_type_supported(dtype_name)
224+
if dtype_name == dtype.f16:
225+
pytest.skip()
226+
values = wrapper.randu(shape, dtype_name)
227+
result = wrapper.where(values)
228+
assert wrapper.get_dims(result)[0] == 3 ** len(shape), f"failed for shape: {shape}, dtype {dtype_name}" # noqa

0 commit comments

Comments
 (0)