Skip to content

Commit

Permalink
Merge pull request projectmesa#823 from projectmesa/cached_neighborhood
Browse files Browse the repository at this point in the history
[PERF] Add neighborhood cache to grids and improve iter_cell_list_contents
  • Loading branch information
dmasad authored Mar 13, 2021
2 parents ed637bf + e4544a9 commit b87ab89
Showing 1 changed file with 39 additions and 34 deletions.
73 changes: 39 additions & 34 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __init__(self, width: int, height: int, torus: bool) -> None:
# Add all cells to the empties list.
self.empties = set(itertools.product(*(range(self.width), range(self.height))))

# Neighborhood Cache
self._neighborhood_cache: Dict[Any, List[Coordinate]] = dict()

@staticmethod
def default_val() -> None:
""" Default value for new cell elements. """
Expand Down Expand Up @@ -206,7 +209,7 @@ def neighbor_iter(
diagonals) or Von Neumann (only up/down/left/right).
"""
neighborhood = self.iter_neighborhood(pos, moore=moore)
neighborhood = self.get_neighborhood(pos, moore=moore)
return self.iter_cell_list_contents(neighborhood)

def iter_neighborhood(
Expand Down Expand Up @@ -236,31 +239,7 @@ def iter_neighborhood(
including the center).
"""
x, y = pos
coordinates: Set[Coordinate] = set()
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
if dx == 0 and dy == 0 and not include_center:
continue
# Skip coordinates that are outside manhattan distance
if not moore and abs(dx) + abs(dy) > radius:
continue
# Skip if not a torus and new coords out of bounds.
if not self.torus and (
not (0 <= dx + x < self.width) or not (0 <= dy + y < self.height)
):
continue

px, py = self.torus_adj((x + dx, y + dy))

# Skip if new coords out of bounds.
if self.out_of_bounds((px, py)):
continue

coords = (px, py)
if coords not in coordinates:
coordinates.add(coords)
yield coords
yield from self.get_neighborhood(pos, moore, include_center, radius)

def get_neighborhood(
self,
Expand Down Expand Up @@ -288,7 +267,35 @@ def get_neighborhood(
if not including the center).
"""
return list(self.iter_neighborhood(pos, moore, include_center, radius))
cache_key = (pos, moore, include_center, radius)
neighborhood = self._neighborhood_cache.get(cache_key, None)

if neighborhood is None:
coordinates: Set[Coordinate] = set()

x, y = pos
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
if dx == 0 and dy == 0 and not include_center:
continue
# Skip coordinates that are outside manhattan distance
if not moore and abs(dx) + abs(dy) > radius:
continue

coord = (x + dx, y + dy)

if self.out_of_bounds(coord):
# Skip if not a torus and new coords out of bounds.
if not self.torus:
continue
coord = self.torus_adj(coord)

coordinates.add(coord)

neighborhood = sorted(coordinates)
self._neighborhood_cache[cache_key] = neighborhood

return neighborhood

def iter_neighbors(
self,
Expand All @@ -314,9 +321,8 @@ def iter_neighbors(
An iterator of non-None objects in the given neighborhood;
at most 9 if Moore, 5 if Von-Neumann
(8 and 4 if not including the center).
"""
neighborhood = self.iter_neighborhood(pos, moore, include_center, radius)
neighborhood = self.get_neighborhood(pos, moore, include_center, radius)
return self.iter_cell_list_contents(neighborhood)

def get_neighbors(
Expand All @@ -325,8 +331,8 @@ def get_neighbors(
moore: bool,
include_center: bool = False,
radius: int = 1,
) -> List[Coordinate]:
"""Return a list of neighbors to a certain point.
) -> List[GridContent]:
""" Return a list of neighbors to a certain point.
Args:
pos: Coordinate tuple for the neighborhood to get.
Expand Down Expand Up @@ -354,8 +360,7 @@ def torus_adj(self, pos: Coordinate) -> Coordinate:
elif not self.torus:
raise Exception("Point out of bounds, and space non-toroidal.")
else:
x, y = pos[0] % self.width, pos[1] % self.height
return x, y
return pos[0] % self.width, pos[1] % self.height

def out_of_bounds(self, pos: Coordinate) -> bool:
"""
Expand All @@ -377,7 +382,7 @@ def iter_cell_list_contents(
An iterator of the contents of the cells identified in cell_list
"""
return (self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y)))
return filter(None, (self.grid[x][y] for x, y in cell_list))

@accept_tuple_argument
def get_cell_list_contents(
Expand Down

0 comments on commit b87ab89

Please sign in to comment.