forked from projectmesa/mesa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_examples.py
63 lines (55 loc) · 2.33 KB
/
test_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
import sys
import os.path
import unittest
import contextlib
import importlib
def classcase(name):
    """Turn an example directory name into a CamelCase identifier.

    Hyphens are treated as underscores. The name is split on underscores,
    each piece is passed through ``str.capitalize`` (first character upper-cased,
    the rest lower-cased), and the pieces are joined with no separator.
    Examples: ``'forest_fire'`` -> ``'ForestFire'`` and
    ``'boltzmann-wealth_model'`` -> ``'BoltzmannWealthModel'``.
    """
    words = name.replace('-', '_').split('_')
    return ''.join(map(str.capitalize, words))
class TestExamples(unittest.TestCase):
    '''
    Test examples' models. This creates a model object and iterates it through
    some steps. The idea is to get code coverage, rather than to test the
    details of each example's model.
    '''

    # Absolute path of the ``examples`` directory that sits next to the
    # directory containing this file.
    EXAMPLES = os.path.abspath(os.path.join(os.path.dirname(__file__), '../examples'))

    # Number of times each example's model is stepped.
    STEPS = 100

    @contextlib.contextmanager
    def active_example_dir(self, example):
        '''Temporarily make ``example``'s directory the active one.

        While the context is active, the example directory is both the
        current working directory and the first entry of ``sys.path``. This
        lets its modules be imported by plain name.

        On exit, the previous working directory, ``sys.path`` and
        ``sys.modules`` are restored. Modules first imported inside the
        context are removed from ``sys.modules``, so the next example's
        identically named modules are imported fresh rather than served from
        the cache.
        '''
        old_sys_path = sys.path[:]
        old_sys_modules = sys.modules.copy()
        old_cwd = os.getcwd()
        example_path = os.path.abspath(os.path.join(self.EXAMPLES, example))
        try:
            sys.path.insert(0, example_path)
            os.chdir(example_path)
            yield
        finally:
            os.chdir(old_cwd)
            # Forget every module that was first imported inside the context...
            added = [m for m in sys.modules if m not in old_sys_modules]
            for mod in added:
                del sys.modules[mod]
            # ...and restore any pre-existing entries that were replaced.
            sys.modules.update(old_sys_modules)
            sys.path[:] = old_sys_path

    def test_examples(self):
        '''Import, render and step every standard example under ``EXAMPLES``.

        Two layouts are tried:
        1. ``model``/``server`` modules importable directly from the example
           directory;
        2. a ``<example>.model``/``<example>.server`` package, where
           ``<example>`` has ``-`` replaced by ``_``.

        For each example, ``server.server.render_model()`` is called. The
        class named ``classcase(example)`` is then instantiated with no
        arguments, and its ``step()`` is called ``STEPS`` times.

        Directories for which this class defines a ``test_<name>`` method are
        skipped.
        '''
        # sorted() makes the visiting order (and thus failures) reproducible.
        for example in sorted(os.listdir(self.EXAMPLES)):
            if not os.path.isdir(os.path.join(self.EXAMPLES, example)):
                continue
            package = example.replace('-', '_')
            if hasattr(self, 'test_{}'.format(package)):
                # non-standard example; it has its own test_<name> method
                continue

            print("testing example {!r}".format(example))
            with self.active_example_dir(example):
                # Only the imports decide the layout. An ImportError raised
                # while rendering must surface as-is, not trigger the fallback.
                try:
                    # model.py at the top level
                    mod = importlib.import_module('model')
                    server = importlib.import_module('server')
                except ImportError:
                    # <example>/model.py
                    mod = importlib.import_module('{}.model'.format(package))
                    server = importlib.import_module('{}.server'.format(package))
                server.server.render_model()
                Model = getattr(mod, classcase(example))
                model = Model()
                # Call step() eagerly. The previous bare generator expression
                # was never consumed, so no step was ever executed.
                for _ in range(self.STEPS):
                    model.step()