forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_vulkan_codegen.py
100 lines (92 loc) · 3.38 KB
/
test_vulkan_codegen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import tempfile
import unittest
from tools.gen_vulkan_spv import VulkanShaderGenerator
from yaml.constructor import ConstructorError
class TestVulkanShaderCodegen(unittest.TestCase):
def test_assert_on_duplicate_key_yaml(self) -> None:
yaml_with_duplicate_keys = """
conv2d_pw:
parameter_names_with_default_values:
TILE_SIZE_X: 1
TILE_SIZE_Y: 1
parameter_values:
- TILE_SIZE_X: 2
TILE_SIZE_Y: 2
- TILE_SIZE_X: 2
TILE_SIZE_Y: 4
- TILE_SIZE_X: 4
TILE_SIZE_Y: 2
- TILE_SIZE_X: 4
TILE_SIZE_Y: 4
conv2d_pw:
parameter_names_with_default_values:
- TILE_SIZE_X: 1
- TILE_SIZE_Y: 1
parameter_values:
- TILE_SIZE_X: 2
TILE_SIZE_Y: 2
- TILE_SIZE_X: 2
TILE_SIZE_Y: 4
- TILE_SIZE_X: 4
TILE_SIZE_Y: 2
- TILE_SIZE_X: 4
TILE_SIZE_Y: 4
"""
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_duplicate_keys)
fp.flush()
with self.assertRaisesRegex(
ConstructorError, r"while constructing a mapping"
):
generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call]
def test_assert_keys_mismatch(self) -> None:
yaml_with_key_mismatch = """
conv2d_pw:
parameter_names_with_default_values:
TILE_SIZE_X: 1
TILE_SIZE_Y: 1
parameter_values:
- TILE_SIZE_X: 2
TILE_SIZE_Z: 2
"""
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_key_mismatch)
fp.flush()
with self.assertRaisesRegex(KeyError, r"Invalid keys {'TILE_SIZE_Z'}"):
generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call]
def test_missing_key_default_val(self) -> None:
yaml_with_key_mismatch = """
conv2d_pw:
parameter_names_with_default_values:
TILE_SIZE_X: 1
TILE_SIZE_Y: 1
parameter_values:
- TILE_SIZE_Y: 2
"""
file_content = """
x = $TILE_SIZE_X + $TILE_SIZE_Y
"""
generator = VulkanShaderGenerator() # type: ignore[no-untyped-call]
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_with_key_mismatch)
fp.flush()
generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call]
with tempfile.TemporaryDirectory() as tmp_dir:
template_file_name = os.path.join(tmp_dir, "conv2d_pw.glslt")
with open(template_file_name, "w") as template_file:
template_file.write(file_content)
template_file.flush()
generator.generate(template_file.name, tmp_dir) # type: ignore[no-untyped-call]
file_name_1 = os.path.join(tmp_dir, "conv2d_pw_1x1.glsl")
file_name_2 = os.path.join(tmp_dir, "conv2d_pw_1x2.glsl")
self.assertTrue(os.path.exists(file_name_1))
self.assertTrue(os.path.exists(file_name_2))
with open(file_name_1, "r") as f:
contents = f.read()
self.assertTrue("1 + 1" in contents)
with open(file_name_2, "r") as f:
contents = f.read()
self.assertTrue("1 + 2" in contents)