Skip to content

Commit

Permalink
Merge pull request projectmesa#928 from tpike3/master
Browse files Browse the repository at this point in the history
Batchrunner_redux fixes
  • Loading branch information
tpike3 authored Oct 10, 2020
2 parents afd61e8 + 93b3aca commit 3d42c63
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/intro_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,13 @@ to see the distribution of the agent's wealth. We can get the wealth
values with list comprehension, and then use matplotlib (or another
graphics library) to visualize the data in a histogram.

If you are running from a text editor or IDE, you'll also need to add
this line, to make the graph appear.

.. code:: python
plt.show()
If you are running from a text editor or IDE, you'll also need to add
this line, to make the graph appear.

.. code:: ipython3
# For a jupyter notebook add the following line:
Expand Down
26 changes: 12 additions & 14 deletions mesa/batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __init__(

self.display_progress = display_progress

@property
def _make_model_args(self):
"""Prepare all combinations of parameter values for `run_all`
Expand All @@ -128,21 +127,21 @@ def _make_model_args(self):
all_kwargs = []
all_param_values = []

_count = len(self.parameters_list)
if _count:
count = len(self.parameters_list)
if count:
for params in self.parameters_list:
kwargs = params.copy()
kwargs.update(self.fixed_parameters)
all_kwargs.append(kwargs)
all_param_values.append(list(params.values()))

elif len(self.fixed_parameters):
_count = 1
count = 1
kwargs = self.fixed_parameters.copy()
all_kwargs.append(kwargs)
all_param_values.append(list(kwargs.values()))

total_iterations *= _count
total_iterations *= count

return total_iterations, all_kwargs, all_param_values

Expand Down Expand Up @@ -178,7 +177,7 @@ def _make_model_args_mp(self):
def run_all(self):
""" Run the model at all parameter combinations and store results. """
run_count = count()
total_iterations, all_kwargs, all_param_values = self._make_model_args
total_iterations, all_kwargs, all_param_values = self._make_model_args()

with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for i, kwargs in enumerate(all_kwargs):
Expand All @@ -188,8 +187,7 @@ def run_all(self):
pbar.update()

def run_iteration(self, kwargs, param_values, run_count):
kwargs_copy = copy.deepcopy(kwargs)
model = self.model_cls(**kwargs_copy)
model = self.model_cls(**kwargs)
results = self.run_model(model)
if param_values is not None:
model_key = tuple(param_values) + (run_count,)
Expand All @@ -215,7 +213,7 @@ def run_iteration(self, kwargs, param_values, run_count):
getattr(self, "datacollector_agent_reporters", None))

@staticmethod
def run_wrappermp(iter_args):
def _run_wrappermp(iter_args):
"""
Based on requirement of Python multiprocessing requires @staticmethod decorator;
this is primarily to ensure functionality on Windows OS and doe not impact MAC or Linux distros
Expand Down Expand Up @@ -498,7 +496,7 @@ def __init__(self, model_cls, nr_processes=None, **kwargs):
super().__init__(model_cls, **kwargs)
self.pool = Pool(self.processes)

def result_prep_mp(self, results):
def _result_prep_mp(self, results):
"""
Helper Function
:param results: Takes results dictionary from Processpool and single processor debug run and fixes format to
Expand Down Expand Up @@ -539,18 +537,18 @@ def run_all(self):

if self.processes > 1:
with tqdm(total_iterations, disable=not self.display_progress) as pbar:
for params, model in self.pool.imap_unordered(self.run_wrappermp, run_iter_args):
for params, model in self.pool.imap_unordered(self._run_wrappermp, run_iter_args):
results[params] = model
pbar.update()

self.result_prep_mp(results)
self._result_prep_mp(results)
# For debugging model due to difficulty of getting errors during multiprocessing
else:
for run in run_iter_args:
params, model_data = self.run_wrappermp(run)
params, model_data = self._run_wrappermp(run)
results[params] = model_data

self.result_prep_mp(results)
self._result_prep_mp(results)

# Close multi-processing
self.pool.close()
Expand Down
13 changes: 3 additions & 10 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
If your model includes a large number of agents, you should *only*
use attribute names for the agent reporter, it will be much faster.
Model reporters can take three types of arguments:
Model reporters can take four types of arguments:
lambda like above:
{"agent_count": lambda m: m.schedule.get_agent_count() }
method with @property decorators
{"agent_count": schedule.get_agent_count()
class attributes of model
{"model_attribute": "model_attribute"}
functions with paramters that have placed in a list
functions with parameters that have placed in a list
{"Model_Function":[function, [param_1, param_2]]}
"""
Expand Down Expand Up @@ -186,14 +186,7 @@ def collect(self, model):
elif isinstance(reporter, list):
self.model_vars[var].append(reporter[0](*reporter[1]))
else:
try:
self.model_vars[var].append(reporter)
except TypeError:
print("Reporters must be in dictionary in one of the following forms: /n\
- key_name: <class attribute> /n \
- key_name: function (e.g schedule.get_agent_count) /n \
- key_name: lambda function /n \
- key_name: [function, [<arguments for function>]")
self.model_vars[var].append(reporter)

if self.agent_reporters:
agent_records = self._record_agents(model)
Expand Down

0 comments on commit 3d42c63

Please sign in to comment.