Skip to content

Commit

Permalink
unit tests and ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel authored and tpike3 committed Jan 3, 2024
1 parent 4a5961e commit f05454f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# mypy
from typing import Any

from mesa.agent import AgentSet, Agent
from mesa.agent import Agent, AgentSet
from mesa.datacollection import DataCollector


Expand Down
24 changes: 12 additions & 12 deletions mesa/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import weakref

# mypy
from typing import Union, Iterable
from typing import Iterable, Union

from mesa.agent import Agent, AgentSet
from mesa.model import Model
Expand Down Expand Up @@ -60,7 +60,7 @@ class BaseScheduler:
- agents (property): Returns a list of all agent instances.
"""

def __init__(self, model: Model, agents: Iterable[Agent] = None) -> None:
def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None:
"""Create a new, empty BaseScheduler."""
self.model = model
self.steps = 0
Expand Down Expand Up @@ -226,7 +226,7 @@ class StagedActivation(BaseScheduler):
def __init__(
self,
model: Model,
agents: Iterable[Agent] = None,
agents: Iterable[Agent] | None = None,
stage_list: list[str] | None = None,
shuffle: bool = False,
shuffle_between_stages: bool = False,
Expand Down Expand Up @@ -294,7 +294,7 @@ class RandomActivationByType(BaseScheduler):
- get_type_count: Returns the count of agents of a specific type.
"""

def __init__(self, model: Model, agents: Iterable[Agent] = None) -> None:
def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None:
super().__init__(model, agents)

# can't be a defaultdict because we need to pass model to AgentSet
Expand All @@ -321,14 +321,14 @@ def add(self, agent: Agent) -> None:
except KeyError:
self.agents_by_type[type(agent)] = AgentSet([agent], self.model)

def remove(self, agent: Agent) -> None:
"""
Remove all instances of a given agent from the schedule.
"""
super().remove(agent)
# redundant because of weakrefs. super call only done because of warning
# agent_class: type[Agent] = type(agent)
# del self.agents_by_type[agent_class][agent.unique_id]
# def remove(self, agent: Agent) -> None:
# """
# Remove all instances of a given agent from the schedule.
# """
# super().remove(agent)
# # redundant because of weakrefs. super call only done because of warning
# # agent_class: type[Agent] = type(agent)
# # del self.agents_by_type[agent_class][agent.unique_id]

def step(self, shuffle_types: bool = True, shuffle_agents: bool = True) -> None:
"""
Expand Down
36 changes: 36 additions & 0 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ class TestRandomActivation(TestCase):
Test the random activation.
"""

def test_init(self):
model = Model()
agents = [MockAgent(model.next_id(), model) for _ in range(10)]

scheduler = RandomActivation(model, agents)
assert all(agent in scheduler.agents for agent in agents)


def test_random_activation_step_shuffles(self):
"""
Test the random activation step
Expand Down Expand Up @@ -206,6 +214,18 @@ def test_intrastep_remove(self):
model.step()
assert len(model.log) == 1

def test_get_agent_keys(self):
model = MockModel(activation=RANDOM)

keys = model.schedule.get_agent_keys()
agent_ids = [agent.unique_id for agent in model.agents]
assert all(entry_i == entry_j for entry_i, entry_j in zip(keys, agent_ids))

keys = model.schedule.get_agent_keys(shuffle=True)
agent_ids = {agent.unique_id for agent in model.agents}
assert all(entry in agent_ids for entry in keys)



class TestSimultaneousActivation(TestCase):
"""
Expand Down Expand Up @@ -263,6 +283,17 @@ def test_random_activation_step_steps_each_agent(self):
# one step for each of 2 agents
assert all(x == 1 for x in agent_steps)

def test_random_activation_counts(self):
"""
Test the random activation by type step causes each agent to step
"""

model = MockModel(activation=RANDOM_BY_TYPE)

agent_types = model.agent_types
for agent_type in agent_types:
assert model.schedule.get_type_count(agent_type) == len(model.get_agents_of_type(agent_type))

# def test_add_non_unique_ids(self):
# """
# Test that adding agent with duplicate ids result in an error.
Expand Down Expand Up @@ -339,6 +370,11 @@ def test_invalid_event_time(self):
with self.assertRaises(ValueError):
self.scheduler.schedule_event(-1, self.agent1)

def test_invalid_aget_time(self):
with self.assertRaises(ValueError):
agent3 = MockAgent(3, self.model)
self.scheduler.schedule_event(2, agent3)

def test_immediate_event_execution(self):
# Current time of the scheduler
current_time = self.scheduler.time
Expand Down

0 comments on commit f05454f

Please sign in to comment.