Skip to content

Commit

Permalink
Revert to a sane approach
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince committed Sep 11, 2018
1 parent 8118107 commit 920c9ca
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 38 deletions.
57 changes: 21 additions & 36 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
* For collecting agent-level variables, agents must have a unique_id
"""
from collections import defaultdict
import itertools
from operator import attrgetter
import pandas as pd

Expand Down Expand Up @@ -136,34 +136,24 @@ def _new_table(self, table_name, table_columns):
new_table = {column: [] for column in table_columns}
self.tables[table_name] = new_table

@staticmethod
def _repgetter(reporters):
""" Get reports from a list of agents.
Args:
reporters: List of reporter functions.
"""
if all([rep.__name__ == 'attr_collector' for rep in reporters]):
# Fast path if all reporters are attribute collecters
def report(agent):
f = attrgetter(*[rep.attribute_name for rep in reporters])
return (agent.unique_id, ) + f(agent)
else:
def report(agent):
return (agent.unique_id, ) + tuple(
get_report(rep, agent) for rep in reporters)
return report

def collect(self, model):
""" Collect all the data for the given model object. """
if self.model_reporters:
for var, reporter in self.model_reporters.items():
self.model_vars[var].append(reporter(model))

if self.agent_reporters:
f = self._repgetter(self.agent_reporters.values())
agent_records = map(f, model.schedule.agents)
rep_funcs = self.agent_reporters.values()
if all([rep.__name__ == 'attr_collector' for rep in rep_funcs]):
attributes = [func.attribute_name for func in rep_funcs]
get_reports = attrgetter(
*['model.schedule.steps', 'unique_id'] + attributes)
else:
def get_reports(agent):
return (
agent.model.schedule.steps, agent.unique_id) + tuple(
rep(agent) for rep in rep_funcs)
agent_records = [*map(get_reports, model.schedule.agents)]
self._agent_records[model.schedule.steps] = agent_records

def add_table_row(self, table_name, row, ignore_missing=False):
Expand Down Expand Up @@ -222,15 +212,15 @@ def get_agent_vars_dataframe(self):
columns for tick and agent_id.
"""
data = defaultdict(dict)
for step, records in self._agent_records.items():
for record in records:
agent_id = record[0]
for i, value in enumerate(record[1:]):
var = list(self.agent_reporters)[i]
data[(step, agent_id)][var] = value
df = pd.DataFrame.from_dict(data, orient="index")
df.index.names = ["Step", "AgentID"]
all_records = itertools.chain.from_iterable(
self._agent_records.values())
rep_names = [rep_name for rep_name in self.agent_reporters]

df = pd.DataFrame.from_records(
data=all_records,
columns=["Step", "AgentID"] + rep_names,
)
df = df.set_index(["Step", "AgentID"])
return df

def get_table_dataframe(self, table_name):
Expand All @@ -243,8 +233,3 @@ def get_table_dataframe(self, table_name):
if table_name not in self.tables:
raise Exception("No such table.")
return pd.DataFrame(self.tables[table_name])


def get_report(rep, agent):
""" Get a report from an agent."""
return rep(agent)
3 changes: 1 addition & 2 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,9 @@ def test_agent_records(self):
data_collector = self.model.datacollector
assert len(data_collector._agent_records) == 7
for step, records in data_collector._agent_records.items():
records = list(records)
assert len(records) == 10
for values in records:
assert len(values) == 3
assert len(values) == 4

def test_table_rows(self):
'''
Expand Down

0 comments on commit 920c9ca

Please sign in to comment.