Skip to content

Commit

Permalink
Added two more test cases to test new utility classes batchrunner.Par…
Browse files Browse the repository at this point in the history
…ameter Product and batchrunner.ParameterSampler. Made a couple of fixes in ParameterSampler as a result.
  • Loading branch information
dcunning11235 committed Feb 25, 2019
1 parent e3d6ec7 commit 0fb8e76
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
8 changes: 6 additions & 2 deletions mesa/batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,20 +256,24 @@ def __next__(self):
class ParameterSampler:
def __init__(self, parameter_lists, n, random_state=None):
self.param_names, self.param_lists = \
zip( *(copy.deepcopy(variable_parameters)).items() )
zip( *(copy.deepcopy(parameter_lists)).items() )
self.n = n
if random_state is None:
self.random_state = random.Random()
elif isinstance(random_state, int):
self.random_state = random.Random(random_state)
else:
self.random_state = random_state
self.count = 0

def __iter__(self):
return self

def __next__(self):
return dict(zip(self.param_names, [self.random_state.choose(l) for l in self.param_lists]))
self.count += 1
if self.count <= self.n:
return dict(zip(self.param_names, [self.random_state.choice(l) for l in self.param_lists]))
raise StopIteration()


class BatchRunner(FixedBatchRunner):
Expand Down
34 changes: 33 additions & 1 deletion tests/test_batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mesa import Agent, Model
from mesa.time import BaseScheduler
from mesa.batchrunner import BatchRunner
from mesa.batchrunner import BatchRunner, ParameterProduct, ParameterSampler, FixedBatchRunner


NUM_AGENTS = 7
Expand Down Expand Up @@ -162,6 +162,38 @@ def test_model_with_variable_and_fixed_kwargs(self):
self.assertEqual(model_vars['reported_fixed_param'].iloc[0],
self.fixed_params['fixed_name'])

class TestParameters(unittest.TestCase):
def test_product(self):
params = ParameterProduct({
"var_alpha": ['a', 'b', 'c'],
"var_num": [10, 20]
})

lp = list(params)
print(lp)
self.assertCountEqual(lp, [{'var_alpha': 'a', 'var_num': 10},
{'var_alpha': 'a', 'var_num': 20},
{'var_alpha': 'b', 'var_num': 10},
{'var_alpha': 'b', 'var_num': 20},
{'var_alpha': 'c', 'var_num': 10},
{'var_alpha': 'c', 'var_num': 20}])

def test_sampler(self):
params1 = ParameterSampler({
"var_alpha": ['a', 'b', 'c', 'd', 'e'],
"var_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
n=10,
random_state=1)
params2 = ParameterSampler({
"var_alpha": ['a', 'b', 'c', 'd', 'e'],
"var_num": range(16)},
n=10,
random_state=1
)

lp = list(params1)
self.assertEqual(10, len(lp))
self.assertEqual(lp, list(params2))

if __name__ == '__main__':
unittest.main()

0 comments on commit 0fb8e76

Please sign in to comment.