forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_digraph.py
131 lines (104 loc) · 3.65 KB
/
test_digraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Owner(s): ["oncall: package/deploy"]
from torch.package._digraph import DiGraph
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter"""

    def test_successors(self):
        """Targets of a node's outgoing edges are its successors; a lone node has none."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        foo_successors = list(g.successors("foo"))
        self.assertIn("bar", foo_successors)
        self.assertIn("baz", foo_successors)
        self.assertEqual(len(list(g.successors("qux"))), 0)

    def test_predecessors(self):
        """Sources of a node's incoming edges are its predecessors; a lone node has none."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertIn("foo", list(g.predecessors("bar")))
        self.assertIn("foo", list(g.predecessors("baz")))
        self.assertEqual(len(list(g.predecessors("qux"))), 0)

    def test_successor_not_in_graph(self):
        """Asking for the successors of an unknown node raises ValueError."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        """Asking for the predecessors of an unknown node raises ValueError."""
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        """Keyword arguments to add_node are readable through ``g.nodes[node]``."""
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)

        attrs = g.nodes["foo"]
        self.assertEqual(attrs["my_attr"], 1)
        self.assertEqual(attrs["other_attr"], 2)

    def test_node_attr_update(self):
        """Re-adding an existing node overwrites the attribute value given again."""
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        """``g.edges`` yields every added edge as a ``(source, target)`` tuple."""
        g = DiGraph()
        added = [(1, 2), (2, 3), (1, 3), (4, 5)]
        for src, dst in added:
            g.add_edge(src, dst)

        edge_list = list(g.edges)
        self.assertEqual(len(edge_list), 4)
        for edge in added:
            self.assertIn(edge, edge_list)

    def test_iter(self):
        """Iterating over the graph visits each node that was added."""
        g = DiGraph()
        g.add_node(1)
        g.add_node(2)
        g.add_node(3)

        # set(g) drives the graph's iterator directly; order is irrelevant here.
        self.assertEqual(set(g), {1, 2, 3})

    def test_contains(self):
        """The ``in`` operator reports whether a node has been added."""
        g = DiGraph()
        g.add_node("yup")

        # assertIn/assertNotIn still exercise ``in`` but print the operands on failure.
        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        """A membership test with an unhashable value is False rather than an error."""
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        """The forward closure is the start node plus everything reachable from it."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        # assertEqual (not assertTrue(a == b)) so a failure shows both sets.
        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        """``all_paths("1", "3")`` reports only the edges lying on a path from "1" to "3"."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        # These two edges reach "3" but not from "1", so they must be absent.
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")

        # The result is a ";"-separated string. The first two pieces and the
        # last one are skipped (presumably the graph preamble and closing
        # text), leaving the edge statements. They are compared as a set to
        # get rid of indeterminism in the edge order.
        actual = {piece.strip("\n") for piece in result.split(";")[2:-1]}
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()