Skip to content

Commit

Permalink
Added an option to count detection upon the center point of the bound…
Browse files Browse the repository at this point in the history
…ing box crossing the line counter.

Added an option to LineCounter class specifying the condition which determines whether a detection has crossed the line counter or not.

Additionally made it so that the line counters check whether if the corners (or optionally the center point) of a detection's bounding box are in the line counter's coordinate ranges. This way, line counters count only the detections that have precisely crossed the bounds that are drawn without failing and counting targets that have crossed invisible extensions of the lines.
  • Loading branch information
revtheundead committed Jan 16, 2024
1 parent 3b53124 commit 0e1f422
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 28 deletions.
130 changes: 102 additions & 28 deletions supervision/detection/line_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,48 @@ class LineZone:
to outside.
"""

def __init__(self, start: Point, end: Point):
def __init__(self, start: Point, end: Point, count_condition="whole_crossed"):
"""
Args:
start (Point): The starting point of the line.
end (Point): The ending point of the line.
count_condition (str): The condition which determines
how detections are counted as having crossed the line
counter. Can either be "whole_crossed" or "center_point_crossed".
If condition is set to "whole_crossed", trigger() determines
whether if the whole bounding box of the detection has crossed
the line or not. This is the default behaviour.
If condition is set to "center_point_crossed", trigger() determines
whether if the center point of the detection's bounding box has
crossed the line or not.
"""
self.vector = Vector(start=start, end=end)
self.tracker_state: Dict[str, bool] = {}
self.in_count: int = 0
self.out_count: int = 0
self.count_condition = count_condition
if count_condition not in ["whole_crossed", "center_point_crossed"]:
raise ValueError("Argument count_condition must be 'whole_crossed' or 'center_point_crossed'")

def is_point_in_line_range(self, point: Point) -> bool:
"""
Check if the given point is within the line's x and y range.
This should be used with trigger() to determine points that are
precisely within the range of the line counter's start and end points.
Args:
point (Point): The point to check
"""
line_min_x, line_max_x = min(self.vector.start.x, self.vector.end.x), max(self.vector.start.x, self.vector.end.x)
line_min_y, line_max_y = min(self.vector.start.y, self.vector.end.y), max(self.vector.start.y, self.vector.end.y)

within_line_range_x = line_min_x != line_max_x and line_min_x <= point.x <= line_max_x
within_line_range_y = line_min_y != line_max_y and line_min_y <= point.y <= line_max_y

return (within_line_range_x or line_min_x == line_max_x) and \
(within_line_range_y or line_min_y == line_max_y)

def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
"""
Expand All @@ -54,41 +86,83 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
crossed_in = np.full(len(detections), False)
crossed_out = np.full(len(detections), False)

for i, (xyxy, _, confidence, class_id, tracker_id) in enumerate(detections):
if tracker_id is None:
continue
if self.count_condition == "whole_crossed":
for i, (xyxy, _, confidence, class_id, tracker_id) in enumerate(detections):
if tracker_id is None:
continue

x1, y1, x2, y2 = xyxy
anchors = [
Point(x=x1, y=y1),
Point(x=x1, y=y2),
Point(x=x2, y=y1),
Point(x=x2, y=y2),
]
triggers = [self.vector.is_in(point=anchor) for anchor in anchors]
x1, y1, x2, y2 = xyxy

if len(set(triggers)) == 2:
continue
anchors = [
Point(x=x1, y=y1),
Point(x=x1, y=y2),
Point(x=x2, y=y1),
Point(x=x2, y=y2),
]

tracker_state = triggers[0]
triggers = [(self.vector.cross_product(point=anchor) < 0) for anchor in anchors]

if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = tracker_state
continue
if len(set(triggers)) == 2:
continue

if self.tracker_state.get(tracker_id) == tracker_state:
continue
tracker_state = triggers[0]

self.tracker_state[tracker_id] = tracker_state
if tracker_state:
self.in_count += 1
crossed_in[i] = True
else:
self.out_count += 1
crossed_out[i] = True
if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = tracker_state
continue

return crossed_in, crossed_out
if self.tracker_state.get(tracker_id) == tracker_state:
continue

self.tracker_state[tracker_id] = tracker_state

all_anchors_in_range = True
for anchor in anchors:
if not self.is_point_in_line_range(anchor):
all_anchors_in_range = False
break

if not all_anchors_in_range:
continue

if tracker_state:
self.in_count += 1
crossed_in[i] = True
else:
self.out_count += 1
crossed_out[i] = True

return self.in_count, self.out_count

elif self.count_condition == "center_point_crossed":
for i, (xyxy, _, confidence, class_id, tracker_id) in enumerate(detections):
if tracker_id is None:
continue

x1, y1, x2, y2 = xyxy

# Calculate the center point of the box
center_point = Point(x=(x1 + x2) / 2, y=(y1 + y2) / 2)

current_state = self.vector.cross_product(center_point)

if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = current_state
continue

previous_state = self.tracker_state[tracker_id]

# Update the tracker state and check for crossing
if previous_state * current_state < 0 and self.is_point_in_line_range(center_point):
self.tracker_state[tracker_id] = current_state
if current_state > 0:
self.in_count += 1
crossed_in[i] = True
elif current_state < 0:
self.out_count += 1
crossed_out[i] = True

return self.in_count, self.out_count

class LineZoneAnnotator:
def __init__(
Expand Down
9 changes: 9 additions & 0 deletions supervision/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def is_in(self, point: Point) -> bool:
) * (v2.end.x - v2.start.x)
return cross_product < 0

def cross_product(self, point: Point) -> int:
"""
Determine on which side of the vector a point lies.
Returns a positive number if on one side, negative if on the other, and 0 if on the line.
"""
cross_product = (self.end.x - self.start.x) * (point.y - self.start.y) - \
(self.end.y - self.start.y) * (point.x - self.start.x)
return cross_product


@dataclass
class Rect:
Expand Down

0 comments on commit 0e1f422

Please sign in to comment.