Skip to content

Commit

Permalink
Better doc in run_classifier.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisyeh96 authored Aug 28, 2022
1 parent 042d457 commit 2588cab
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions classification/run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class SimpleDataset(torch.utils.data.Dataset):
"""Very simple dataset."""

def __init__(self, img_files: Sequence[str],
images_dir: Optional[str] = None,
transform: Optional[Callable[[PIL.Image.Image], Any]] = None):
images_dir: str | None = None,
transform: Callable[[PIL.Image.Image], Any] | None = None):
"""Creates a SimpleDataset."""
self.img_files = img_files
self.images_dir = images_dir
Expand All @@ -68,7 +68,7 @@ def __len__(self) -> int:


def create_loader(cropped_images_dir: str,
detections_json_path: Optional[str],
detections_json_path: str | None,
img_size: int,
batch_size: int,
num_workers: int
Expand All @@ -77,7 +77,11 @@ def create_loader(cropped_images_dir: str,
Args:
cropped_images_dir: str, path to image crops
detections: optional dict, detections JSON
detections_json_path: optional str, path to detections JSON
img_size: int, resizes smallest side of image to img_size,
then center-crops to (img_size, img_size)
batch_size: int, batch size in dataloader
num_workers: int, # of workers in dataloader
"""
crop_files = []

Expand All @@ -96,7 +100,7 @@ def create_loader(cropped_images_dir: str,
js = json.load(f)
detections = {img['file']: img for img in js['images']}
detector_version = js['info']['detector']

for img_file, info_dict in tqdm(detections.items()):
if 'detections' not in info_dict or info_dict['detections'] is None:
continue
Expand Down Expand Up @@ -127,12 +131,12 @@ def create_loader(cropped_images_dir: str,
def main(model_path: str,
cropped_images_dir: str,
output_csv_path: str,
detections_json_path: Optional[str],
classifier_categories_json_path: Optional[str],
detections_json_path: str | None,
classifier_categories_json_path: str | None,
img_size: int,
batch_size: int,
num_workers: int,
device_id:int=None) -> None:
device_id: int | None = None) -> None:
"""Main function."""
# evaluating with accimage is much faster than Pillow or Pillow-SIMD
try:
Expand All @@ -155,7 +159,7 @@ def main(model_path: str,
# create model
print('Loading saved model')
model = torch.jit.load(model_path)
model, device = train_classifier.prep_device(model,device_id=device_id)
model, device = train_classifier.prep_device(model, device_id=device_id)

test_epoch(model, loader, device=device, label_names=label_names,
output_csv_path=output_csv_path)
Expand All @@ -164,7 +168,7 @@ def main(model_path: str,
def test_epoch(model: torch.nn.Module,
loader: torch.utils.data.DataLoader,
device: torch.device,
label_names: Optional[Sequence[str]],
label_names: Sequence[str] | None,
output_csv_path: str) -> None:
"""Runs for 1 epoch.
Expand Down Expand Up @@ -249,5 +253,5 @@ def _parse_args() -> argparse.Namespace:
classifier_categories_json_path=args.classifier_categories,
img_size=args.image_size,
batch_size=args.batch_size,
num_workers=args.num_workers,
num_workers=args.num_workers,
device_id=args.device)

0 comments on commit 2588cab

Please sign in to comment.