Skip to content

Commit

Permalink
add neighborhood cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince committed Feb 5, 2021
1 parent ef250a7 commit 2a917aa
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def __init__(self, width: int, height: int, torus: bool) -> None:
# Add all cells to the empties list.
self.empties = set(itertools.product(*(range(self.width), range(self.height))))

# Neighborhood Cache
self._neighborhood_cache = dict() # type: Dict[Any, List[Coordinate]]

@staticmethod
def default_val() -> None:
""" Default value for new cell elements. """
Expand Down Expand Up @@ -162,31 +165,7 @@ def iter_neighborhood(
including the center).
"""
x, y = pos
coordinates: Set[Coordinate] = set()
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
if dx == 0 and dy == 0 and not include_center:
continue
# Skip coordinates that are outside manhattan distance
if not moore and abs(dx) + abs(dy) > radius:
continue
# Skip if not a torus and new coords out of bounds.
if not self.torus and (
not (0 <= dx + x < self.width) or not (0 <= dy + y < self.height)
):
continue

px, py = self.torus_adj((x + dx, y + dy))

# Skip if new coords out of bounds.
if self.out_of_bounds((px, py)):
continue

coords = (px, py)
if coords not in coordinates:
coordinates.add(coords)
yield coords
yield from self.get_neighborhood(pos, moore, include_center, radius)

def get_neighborhood(
self,
Expand Down Expand Up @@ -214,7 +193,35 @@ def get_neighborhood(
if not including the center).
"""
return list(self.iter_neighborhood(pos, moore, include_center, radius))
cache_key = (pos, moore, include_center, radius)
neighborhood = self._neighborhood_cache.get(cache_key, None)

if neighborhood is None:
coordinates = set() # type: Set[Coordinate]

x, y = pos
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
if dx == 0 and dy == 0 and not include_center:
continue
# Skip coordinates that are outside manhattan distance
if not moore and abs(dx) + abs(dy) > radius:
continue

coord = (x + dx, y + dy)

if self.out_of_bounds(coord):
# Skip if not a torus and new coords out of bounds.
if not self.torus:
continue
coord = self.torus_adj(coord)

coordinates.add(coord)

neighborhood = sorted(coordinates)
self._neighborhood_cache[cache_key] = neighborhood

return neighborhood

def iter_neighbors(
self,
Expand Down

0 comments on commit 2a917aa

Please sign in to comment.