Skip to content

Commit

Permalink
Merge pull request projectmesa#779 from rht/mypy
Browse files Browse the repository at this point in the history
space: Add type annotation to Grid class
  • Loading branch information
dmasad authored Feb 2, 2020
2 parents cf2f409 + ad149e6 commit 4cfd991
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 38 deletions.
1 change: 1 addition & 0 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, unique_id, model):
""" Create a new agent. """
self.unique_id = unique_id
self.model = model
self.pos = None

def step(self):
""" A single step of the agent. """
Expand Down
83 changes: 45 additions & 38 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

import numpy as np

from typing import Iterable, Iterator, List, Optional, Set, Tuple, Union
from .agent import Agent
Coordinate = Tuple[int, int]
GridContent = Union[Optional[Agent], Set[Agent]]


def accept_tuple_argument(wrapped_function):
""" Decorator to allow grid methods that take a list of (x, y) coord tuples
Expand Down Expand Up @@ -70,7 +75,7 @@ class Grid:
"""

def __init__(self, width, height, torus):
def __init__(self, width: int, height: int, torus: bool) -> None:
""" Create a new grid.
Args:
Expand All @@ -82,10 +87,10 @@ def __init__(self, width, height, torus):
self.width = width
self.torus = torus

self.grid = []
self.grid = [] # type: List[List[GridContent]]

for x in range(self.width):
col = []
col = [] # type: List[GridContent]
for y in range(self.height):
col.append(self.default_val())
self.grid.append(col)
Expand All @@ -95,25 +100,27 @@ def __init__(self, width, height, torus):
*(range(self.width), range(self.height))))

@staticmethod
def default_val():
def default_val() -> None:
""" Default value for new cell elements. """
return None

def __getitem__(self, index):
def __getitem__(self, index: int) -> List[GridContent]:
return self.grid[index]

def __iter__(self):
# create an iterator that chains the
# rows of grid together as if one list:
def __iter__(self) -> Iterator[GridContent]:
"""
create an iterator that chains the
rows of grid together as if one list:
"""
return itertools.chain(*self.grid)

def coord_iter(self):
def coord_iter(self) -> Iterator[Tuple[GridContent, int, int]]:
""" An iterator that returns coordinates as well as cell contents. """
for row in range(self.width):
for col in range(self.height):
yield self.grid[row][col], row, col # agent, x, y

def neighbor_iter(self, pos, moore=True):
def neighbor_iter(self, pos: Coordinate, moore: bool = True) -> Iterator[GridContent]:
""" Iterate over position neighbors.
Args:
Expand All @@ -125,8 +132,8 @@ def neighbor_iter(self, pos, moore=True):
neighborhood = self.iter_neighborhood(pos, moore=moore)
return self.iter_cell_list_contents(neighborhood)

def iter_neighborhood(self, pos, moore,
include_center=False, radius=1):
def iter_neighborhood(self, pos: Coordinate, moore: bool,
include_center: bool = False, radius: int = 1) -> Iterator[Coordinate]:
""" Return an iterator over cell coordinates that are in the
neighborhood of a certain point.
Expand All @@ -148,7 +155,7 @@ def iter_neighborhood(self, pos, moore,
"""
x, y = pos
coordinates = set()
coordinates = set() # type: Set[Coordinate]
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
if dx == 0 and dy == 0 and not include_center:
Expand All @@ -171,8 +178,8 @@ def iter_neighborhood(self, pos, moore,
coordinates.add(coords)
yield coords

def get_neighborhood(self, pos, moore,
include_center=False, radius=1):
def get_neighborhood(self, pos: Coordinate, moore: bool,
include_center: bool = False, radius: int = 1) -> List[Coordinate]:
""" Return a list of cells that are in the neighborhood of a
certain point.
Expand All @@ -194,8 +201,8 @@ def get_neighborhood(self, pos, moore,
"""
return list(self.iter_neighborhood(pos, moore, include_center, radius))

def iter_neighbors(self, pos, moore,
include_center=False, radius=1):
def iter_neighbors(self, pos: Coordinate, moore: bool,
include_center: bool = False, radius: int = 1) -> Iterator[GridContent]:
""" Return an iterator over neighbors to a certain point.
Args:
Expand All @@ -219,8 +226,8 @@ def iter_neighbors(self, pos, moore,
pos, moore, include_center, radius)
return self.iter_cell_list_contents(neighborhood)

def get_neighbors(self, pos, moore,
include_center=False, radius=1):
def get_neighbors(self, pos: Coordinate, moore: bool,
include_center: bool = False, radius: int = 1) -> List[Coordinate]:
""" Return a list of neighbors to a certain point.
Args:
Expand All @@ -243,7 +250,7 @@ def get_neighbors(self, pos, moore,
return list(self.iter_neighbors(
pos, moore, include_center, radius))

def torus_adj(self, pos):
def torus_adj(self, pos: Coordinate) -> Coordinate:
""" Convert coordinate, handling torus looping. """
if not self.out_of_bounds(pos):
return pos
Expand All @@ -253,7 +260,7 @@ def torus_adj(self, pos):
x, y = pos[0] % self.width, pos[1] % self.height
return x, y

def out_of_bounds(self, pos):
def out_of_bounds(self, pos: Coordinate) -> bool:
"""
Determines whether position is off the grid, returns the out of
bounds coordinate.
Expand All @@ -262,7 +269,7 @@ def out_of_bounds(self, pos):
return x < 0 or x >= self.width or y < 0 or y >= self.height

@accept_tuple_argument
def iter_cell_list_contents(self, cell_list):
def iter_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> Iterator[GridContent]:
"""
Args:
cell_list: Array-like of (x, y) tuples, or single tuple.
Expand All @@ -275,7 +282,7 @@ def iter_cell_list_contents(self, cell_list):
self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y)))

@accept_tuple_argument
def get_cell_list_contents(self, cell_list):
def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> List[GridContent]:
"""
Args:
cell_list: Array-like of (x, y) tuples, or single tuple.
Expand All @@ -286,7 +293,7 @@ def get_cell_list_contents(self, cell_list):
"""
return list(self.iter_cell_list_contents(cell_list))

def move_agent(self, agent, pos):
def move_agent(self, agent: Agent, pos: Coordinate) -> None:
"""
Move an agent from its current position to a new position.
Expand All @@ -301,35 +308,35 @@ def move_agent(self, agent, pos):
self._place_agent(pos, agent)
agent.pos = pos

def place_agent(self, agent, pos):
def place_agent(self, agent: Agent, pos: Coordinate) -> None:
""" Position an agent on the grid, and set its pos variable. """
self._place_agent(pos, agent)
agent.pos = pos

def _place_agent(self, pos, agent):
def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
""" Place the agent at the correct location. """
x, y = pos
self.grid[x][y] = agent
self.empties.discard(pos)

def remove_agent(self, agent):
def remove_agent(self, agent: Agent) -> None:
""" Remove the agent from the grid and set its pos variable to None. """
pos = agent.pos
self._remove_agent(pos, agent)
agent.pos = None

def _remove_agent(self, pos, agent):
def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
""" Remove the agent from the given location. """
x, y = pos
self.grid[x][y] = None
self.empties.add(pos)

def is_cell_empty(self, pos):
def is_cell_empty(self, pos: Coordinate) -> bool:
""" Returns a bool of the contents of a cell. """
x, y = pos
return self.grid[x][y] == self.default_val()

def move_to_empty(self, agent):
def move_to_empty(self, agent: Agent) -> None:
""" Moves agent to a random empty cell, vacating agent's old cell. """
pos = agent.pos
if len(self.empties) == 0:
Expand All @@ -339,7 +346,7 @@ def move_to_empty(self, agent):
agent.pos = new_pos
self._remove_agent(pos, agent)

def find_empty(self):
def find_empty(self) -> Optional[Coordinate]:
""" Pick a random empty cell. """
from warnings import warn
import random
Expand All @@ -348,24 +355,24 @@ def find_empty(self):
"`random` instead of the model-level random-number generator. "
"Consider replacing it with having a model or agent object "
"explicitly pick one of the grid's list of empty cells."),
DeprecationWarning)
DeprecationWarning)

if self.exists_empty_cells():
pos = random.choice(sorted(self.empties))
return pos
else:
return None

def exists_empty_cells(self):
def exists_empty_cells(self) -> bool:
""" Return True if any cells empty else False. """
return len(self.empties) > 0


class SingleGrid(Grid):
""" Grid where each cell contains exactly at most one object. """
empties = set()
empties = set() # type: Set[Coordinate]

def __init__(self, width, height, torus):
def __init__(self, width: int, height: int, torus: bool) -> None:
""" Create a new single-item grid.
Args:
Expand Down Expand Up @@ -394,7 +401,7 @@ def position_agent(self, agent, x="random", y="random"):
agent.pos = coords
self._place_agent(coords, agent)

def _place_agent(self, pos, agent):
def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
if self.is_cell_empty(pos):
super()._place_agent(pos, agent)
else:
Expand Down Expand Up @@ -426,13 +433,13 @@ def default_val():
""" Default value for new cell elements. """
return set()

def _place_agent(self, pos, agent):
def _place_agent(self, pos: Coordinate, agent: Agent) -> None:
""" Place the agent at the correct location. """
x, y = pos
self.grid[x][y].add(agent)
self.empties.discard(pos)

def _remove_agent(self, pos, agent):
def _remove_agent(self, pos: Coordinate, agent: Agent) -> None:
""" Remove the agent from the given location. """
x, y = pos
self.grid[x][y].remove(agent)
Expand Down

0 comments on commit 4cfd991

Please sign in to comment.