Skip to content

Commit

Permalink
Feature/team classification callback (AtomScott#119)
Browse files Browse the repository at this point in the history
* add TeamClassificationCallback
* add BaseVectorModel, SKLearnVectorModel
* add WIP callback notebook

Co-authored-by: IkumaUchida <[email protected]>
Co-authored-by: Atom Scott <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 9c56887 commit 66c8033
Show file tree
Hide file tree
Showing 21 changed files with 796 additions and 1,422 deletions.
1,402 changes: 271 additions & 1,131 deletions notebooks/02_user_guide/02_dataframe_visualization.ipynb

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions notebooks/02_user_guide/wip_09_callbacks.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ follow_imports = "skip"
strict_optional = true

[tool.pep257]
ignore = ["D100", "D101", "D105", "D200", "D202", "D411", "D213", "D413", "D406", "D407"]
ignore = ["D100", "D101", "D105", "D200", "D202", "D203", "D411", "D213", "D413", "D406", "D407"]
add_ignore = ["D102", "D103"]

[tool.pytest.ini_options]
Expand Down
49 changes: 0 additions & 49 deletions sportslabkit/callbacks.py

This file was deleted.

10 changes: 5 additions & 5 deletions sportslabkit/camera/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vidgear.gears.stabilizer import Stabilizer

from sportslabkit.logger import logger, tqdm
from sportslabkit.types.types import _pathlike
from sportslabkit.types.types import PathLike
from sportslabkit.utils import make_video


Expand Down Expand Up @@ -170,10 +170,10 @@ def calibrate_camera_fisheye(objpoints, imgpoints, dim, balance=1):


def find_intrinsic_camera_parameters(
media_path: _pathlike,
media_path: PathLike,
fps: int = 1,
scale: int = 4,
save_path: _pathlike | None = None,
save_path: PathLike | None = None,
draw_on_save: bool = False,
points_to_use: int = 50,
calibration_method: str = "zhang",
Expand Down Expand Up @@ -237,10 +237,10 @@ def find_intrinsic_camera_parameters(


def calibrate_video_from_mappings(
media_path: _pathlike,
media_path: PathLike,
mapx: NDArray,
mapy: NDArray,
save_path: _pathlike,
save_path: PathLike,
stabilize: bool = True,
):
"""
Expand Down
4 changes: 2 additions & 2 deletions sportslabkit/camera/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

from sportslabkit.camera.calibrate import find_intrinsic_camera_parameters
from sportslabkit.camera.videoreader import VideoReader
from sportslabkit.types.types import _pathlike
from sportslabkit.types.types import PathLike
from sportslabkit.utils import logger


class Camera(VideoReader):
def __init__(
self,
video_path: _pathlike,
video_path: PathLike,
threaded: bool = False,
queue_size: int = 10,
keypoint_xml: str | None = None,
Expand Down
8 changes: 4 additions & 4 deletions sportslabkit/dataframe/coordinatesdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from sportslabkit.dataframe.base import BaseSLKDataFrame
from sportslabkit.logger import logger
from sportslabkit.types.types import _pathlike
from sportslabkit.types.types import PathLike


def merge_dicts(*dicts):
Expand Down Expand Up @@ -55,7 +55,7 @@ def set_keypoints(
source_keypoints: ArrayLike | None = None,
target_keypoints: ArrayLike | None = None,
mapping: Mapping | None = None,
mapping_file: _pathlike | None = None,
mapping_file: PathLike | None = None,
) -> None:
"""Set the keypoints for the homography transformation. Make sure that
the target keypoints are the pitch coordinates. Also each keypoint must
Expand Down Expand Up @@ -235,7 +235,7 @@ def from_dict(d: dict, attributes: Iterable[str] | None = ("x", "y")):
def visualize_frame(
self,
frame_idx: int,
save_path: _pathlike | None = None,
save_path: PathLike | None = None,
ball_key: str = "ball",
home_key: str = "0",
away_key: str = "1",
Expand Down Expand Up @@ -341,7 +341,7 @@ def visualize_frame(

def visualize_frames(
self,
save_path: _pathlike,
save_path: PathLike,
ball_key: str = "ball",
home_key: str = "0",
away_key: str = "1",
Expand Down
6 changes: 3 additions & 3 deletions sportslabkit/datasets/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sportslabkit.logger import inspect, logger


_pathlike = Union[str, Path]
PathLike = Union[str, Path]
_module_path = Path(__file__).parent


Expand All @@ -31,7 +31,7 @@ def dataset_list_files(self) -> None:
def download(
self,
file_name: str | None = None,
path: _pathlike | None = _module_path,
path: PathLike | None = _module_path,
force: bool = False,
quiet: bool = False,
unzip: bool = True,
Expand All @@ -40,7 +40,7 @@ def download(
Args:
file_name (Optional[str], optional): Name of the file to download. If None, downloads all data. Defaults to None.
path (Optional[_pathlike], optional): Path to download the data to. If None, downloads to soccertrack/datasets/data. Defaults to None.
path (Optional[PathLike], optional): Path to download the data to. If None, downloads to soccertrack/datasets/data. Defaults to None.
force (bool, optional): If True, overwrites the existing file. Defaults to False.
quiet (bool, optional): If True, suppresses the output. Defaults to True.
unzip (bool, optional): If True, unzips the file. Defaults to True.
Expand Down
50 changes: 25 additions & 25 deletions sportslabkit/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..dataframe import BBoxDataFrame, CoordinatesDataFrame


_pathlike = Union[str, os.PathLike]
PathLike = Union[str, os.PathLike]


def auto_string_parser(value: str) -> Any:
Expand Down Expand Up @@ -55,7 +55,7 @@ def auto_string_parser(value: str) -> Any:
return value


def infer_metadata_from_filename(filename: _pathlike) -> Mapping[str, int]:
def infer_metadata_from_filename(filename: PathLike) -> Mapping[str, int]:
"""Try to infer metadata from filename.
Args:
Expand All @@ -82,7 +82,7 @@ def infer_metadata_from_filename(filename: _pathlike) -> Mapping[str, int]:


def load_gpsports(
filename: _pathlike,
filename: PathLike,
playerid: int | None = None,
teamid: int | None = None,
) -> CoordinatesDataFrame:
Expand Down Expand Up @@ -118,7 +118,7 @@ def load_gpsports(


def load_statsports(
filename: _pathlike,
filename: PathLike,
playerid: int | None = None,
teamid: int | None = None,
) -> CoordinatesDataFrame:
Expand Down Expand Up @@ -155,7 +155,7 @@ def load_statsports(


def load_soccertrack_coordinates(
filename: _pathlike,
filename: PathLike,
playerid: int | None = None,
teamid: int | None = None,
) -> CoordinatesDataFrame:
Expand All @@ -182,11 +182,11 @@ def load_soccertrack_coordinates(
return df


def is_soccertrack_coordinates(filename: _pathlike) -> bool:
def is_soccertrack_coordinates(filename: PathLike) -> bool:
return True


def infer_gps_format(filename: _pathlike) -> str:
def infer_gps_format(filename: PathLike) -> str:
"""Try to infer GPS format from filename.
Args:
Expand All @@ -206,14 +206,14 @@ def infer_gps_format(filename: _pathlike) -> str:

def get_gps_loader(
format: str,
) -> Callable[[_pathlike, int, int], CoordinatesDataFrame]:
) -> Callable[[PathLike, int, int], CoordinatesDataFrame]:
"""Get GPS loader function for a given format.
Args:
format (str): GPS format.
Returns:
Callable[[_pathlike, int, int], CoordinatesDataFrame]: GPS loader function.
Callable[[PathLike, int, int], CoordinatesDataFrame]: GPS loader function.
"""
format = format.lower()
if format == "gpsports":
Expand All @@ -226,7 +226,7 @@ def get_gps_loader(


def load_codf(
filename: _pathlike,
filename: PathLike,
format: str | None = None,
playerid: int | None = None,
teamid: int | None = None,
Expand Down Expand Up @@ -254,7 +254,7 @@ def load_codf(


def load_gps(
filenames: (Sequence[_pathlike,] | _pathlike),
filenames: (Sequence[PathLike,] | PathLike),
playerids: Sequence[int] | int = (),
teamids: Sequence[int] | int = (),
) -> CoordinatesDataFrame:
Expand Down Expand Up @@ -319,7 +319,7 @@ def load_gps_from_yaml(yaml_path: str) -> CoordinatesDataFrame:
return load_gps(filepaths, playerids, teamids)


def load_labelbox(filename: _pathlike) -> CoordinatesDataFrame:
def load_labelbox(filename: PathLike) -> CoordinatesDataFrame:
"""Load labelbox format file to CoordinatesDataFrame.
Args:
Expand Down Expand Up @@ -383,7 +383,7 @@ def load_labelbox(filename: _pathlike) -> CoordinatesDataFrame:
return merged_dataframe


def load_mot(filename: _pathlike) -> CoordinatesDataFrame:
def load_mot(filename: PathLike) -> CoordinatesDataFrame:
"""Load MOT format file to CoordinatesDataFrame.
Args:
Expand Down Expand Up @@ -432,12 +432,12 @@ def load_mot(filename: _pathlike) -> CoordinatesDataFrame:


def load_soccertrack_bbox(
filename: _pathlike,
filename: PathLike,
) -> pd.DataFrame:
"""Load a dataframe from a file.
Args:
filename (_pathlike): Path to load the dataframe.
filename (PathLike): Path to load the dataframe.
Returns:
df (pd.DataFrame): Dataframe loaded from the file.
"""
Expand Down Expand Up @@ -468,11 +468,11 @@ def load_soccertrack_bbox(
return df


def is_mot(filename: _pathlike) -> bool:
def is_mot(filename: PathLike) -> bool:
"""Return True if the file is MOT format.
Args:
filename(_pathlike): Path to file.
filename(PathLike): Path to file.
Returns:
is_mot(bool): True if the file is MOT format.
Expand All @@ -494,11 +494,11 @@ def is_mot(filename: _pathlike) -> bool:
return False


def infer_bbox_format(filename: _pathlike) -> str:
def infer_bbox_format(filename: PathLike) -> str:
"""Try to infer the format of a given bounding box file.
Args:
filename(_pathlike): Path to bounding box file.
filename(PathLike): Path to bounding box file.
Returns:
format(str): Inferred format of the bounding box file.
Expand All @@ -516,14 +516,14 @@ def infer_bbox_format(filename: _pathlike) -> str:

def get_bbox_loader(
format: str,
) -> Callable[[_pathlike], BBoxDataFrame]:
) -> Callable[[PathLike], BBoxDataFrame]:
"""Returns a function that loads the corresponding bbox format.
Args:
format(str): bbox format to load.
Returns:
bbox_loader(Callable[[_pathlike], BBoxDataFrame]): Function that loads the corresponding bbox format.
bbox_loader(Callable[[PathLike], BBoxDataFrame]): Function that loads the corresponding bbox format.
"""
format = format.lower()
if format == "mot":
Expand All @@ -535,11 +535,11 @@ def get_bbox_loader(
raise ValueError(f"Unknown format {format}")


def load_bbox(filename: _pathlike) -> BBoxDataFrame:
def load_bbox(filename: PathLike) -> BBoxDataFrame:
"""Load a BBoxDataFrame from a file.
Args:
filename(_pathlike): Path to bounding box file.
filename(PathLike): Path to bounding box file.
Returns:
bbox(BBoxDataFrame): BBoxDataFrame loaded from the file.
Expand All @@ -552,7 +552,7 @@ def load_bbox(filename: _pathlike) -> BBoxDataFrame:
return df


def load_df(filename: _pathlike, df_type: str = "bbox") -> BBoxDataFrame | CoordinatesDataFrame:
def load_df(filename: PathLike, df_type: str = "bbox") -> BBoxDataFrame | CoordinatesDataFrame:
"""Loads either a BBoxDataFrame or a CoordinatesDataFrame from a file.
Args:
Expand All @@ -572,7 +572,7 @@ def load_df(filename: _pathlike, df_type: str = "bbox") -> BBoxDataFrame | Coord
return df


# def load_bboxes_from_yaml(yaml_path: _pathlike) -> BBoxDataFrame:
# def load_bboxes_from_yaml(yaml_path: PathLike) -> BBoxDataFrame:
# """
# Args:
# yaml_path(str): Path to yaml file.
Expand Down
Loading

0 comments on commit 66c8033

Please sign in to comment.