Skip to content

Commit

Permalink
ultralytics 8.0.21 Windows, segments, YAML fixes (ultralytics#655)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: corey-nm <[email protected]>
  • Loading branch information
3 people authored Jan 26, 2023
1 parent dc9502c commit 6c44ce2
Show file tree
Hide file tree
Showing 16 changed files with 148 additions and 147 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug-report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ body:
label: Environment
description: Please specify the software and hardware you used to produce the bug.
placeholder: |
- YOLO: YOLOv8 🚀 v6.0-67-g60e42e1 torch 1.9.0+cu111 CUDA:0 (A100-SXM4-40GB, 40536MiB)
- YOLO: Ultralytics YOLOv8.0.21 🚀 Python-3.8.10 torch-1.13.1+cu117 CUDA:0 (A100-SXM-80GB, 81251MiB)
- OS: Ubuntu 20.04
- Python: 3.9.0
- Python: 3.8.10
validations:
required: false

Expand Down
13 changes: 7 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,29 @@ def test_train_cls():

# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32')


def test_val_segment():
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32')


def test_val_classify():
pass
run(f'yolo val classify model={MODEL}-cls.pt data=mnist160 imgsz=32')


# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=320 conf=0.25")
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")


def test_predict_segment():
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")


def test_predict_classify():
pass
run(f"yolo predict segment model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")


# Export checks --------------------------------------------------------------------------------------------------------
Expand Down
8 changes: 5 additions & 3 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ def test_export_coreml():
model.export(format='coreml')


def test_export_paddle():
model = YOLO(MODEL)
model.export(format='paddle')
def test_export_paddle(enabled=False):
# Paddle protobuf requirements conflicting with onnx protobuf requirements
if enabled:
model = YOLO(MODEL)
model.export(format='paddle')


def test_all_model_yamls():
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

__version__ = "8.0.20"
__version__ = "8.0.21"

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops
Expand Down
26 changes: 14 additions & 12 deletions ultralytics/yolo/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Dict, List, Union

from ultralytics import __version__
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
USER_CONFIG_DIR, IterableSimpleNamespace, colorstr, emojis, yaml_load, yaml_print)
from ultralytics.yolo.utils.checks import check_yolo

CLI_HELP_MSG = \
Expand Down Expand Up @@ -69,7 +69,7 @@ def cfg2dict(cfg):
return cfg


def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
Expand Down Expand Up @@ -214,39 +214,41 @@ def entrypoint(debug=False):
# Mode
mode = overrides.pop('mode', None)
model = overrides.pop('model', None)
if mode == 'checks':
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes:
if mode != 'checks':
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
check_yolo()
return
elif mode is None:
mode = DEFAULT_CFG_DICT['mode'] or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")

# Model
if model is None:
model = DEFAULT_CFG_DICT['model'] or 'yolov8n.pt'
model = DEFAULT_CFG.model or 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO
model = YOLO(model)
task = model.task

# Task
if mode == 'predict' and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG_DICT['source'] or ROOT / "assets" if (ROOT / "assets").exists() \
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = DEFAULT_CFG_DICT['data'] or 'mnist160' if task == 'classify' \
overrides['data'] = DEFAULT_CFG.data or 'mnist160' if task == 'classify' \
else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG_DICT['format'] or 'torchscript'
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")

# Run command in python
getattr(model, mode)(verbose=True, **overrides)
getattr(model, mode)(**overrides)


# Special modes --------------------------------------------------------------------------------------------------------
Expand Down
11 changes: 7 additions & 4 deletions ultralytics/yolo/data/dataloaders/stream_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, tran
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
cap = cv2.VideoCapture(s)
assert cap.isOpened(), f'{st}Failed to open {s}'
if not cap.isOpened():
raise ConnectionError(f'{st}Failed to open {s}')
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
Expand Down Expand Up @@ -188,8 +189,9 @@ def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_s
self._new_video(videos[0]) # new video
else:
self.cap = None
assert self.nf > 0, f'No images or videos found in {p}. ' \
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
if self.nf == 0:
raise FileNotFoundError(f'No images or videos found in {p}. '
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')

def __iter__(self):
self.count = 0
Expand Down Expand Up @@ -223,7 +225,8 @@ def __next__(self):
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
if im0 is None:
raise FileNotFoundError(f'Image Not Found {path}')
s = f'image {self.count}/{self.nf} {path}: '

if self.transforms:
Expand Down
10 changes: 4 additions & 6 deletions ultralytics/yolo/data/dataloaders/v5loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@
import psutil
import torch
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
is_kaggle)
is_kaggle, yaml_load)
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
Expand Down Expand Up @@ -1056,10 +1055,9 @@ def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
try:
with open(check_yaml(yaml_path), errors='ignore') as f:
data = yaml.safe_load(f) # data dict
if zipped:
data['path'] = data_dir
data = yaml_load(check_yaml(yaml_path)) # data dict
if zipped:
data['path'] = data_dir
except Exception as e:
raise Exception("error/HUB/dataset_stats/yaml_load") from e

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/yolo/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None):
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)

@smart_inference_mode()
Expand Down
30 changes: 11 additions & 19 deletions ultralytics/yolo/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(self, model='yolov8n.yaml', type="v8") -> None:
else:
raise NotImplementedError(f"'{suffix}' model loading not implemented")

def __call__(self, source=None, stream=False, verbose=False, **kwargs):
return self.predict(source, stream, verbose, **kwargs)
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)

def _new(self, cfg: str, verbose=True):
"""
Expand Down Expand Up @@ -118,15 +118,14 @@ def fuse(self):
self.model.fuse()

@smart_inference_mode()
def predict(self, source=None, stream=False, verbose=False, **kwargs):
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
verbose (bool): Whether to print verbose information or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Expand All @@ -143,7 +142,7 @@ def predict(self, source=None, stream=False, verbose=False, **kwargs):
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream, verbose=verbose)
return self.predictor(source=source, stream=stream)

@smart_inference_mode()
def val(self, data=None, **kwargs):
Expand Down Expand Up @@ -234,24 +233,17 @@ def names(self):
"""
return self.model.names

def add_callback(self, event: str, func):
@staticmethod
def add_callback(event: str, func):
"""
Add callback
"""
callbacks.default_callbacks[event].append(func)

@staticmethod
def _reset_ckpt_args(args):
args.pop("project", None)
args.pop("name", None)
args.pop("exist_ok", None)
args.pop("resume", None)
args.pop("batch", None)
args.pop("epochs", None)
args.pop("cache", None)
args.pop("save_json", None)
args.pop("half", None)
args.pop("v5loader", None)

# set device to '' to prevent from auto DDP usage
args["device"] = ''
for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
'half', 'v5loader':
args.pop(arg, None)

args["device"] = '' # set device to '' to prevent auto-DDP usage
20 changes: 10 additions & 10 deletions ultralytics/yolo/engine/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None):
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)

def preprocess(self, img):
Expand Down Expand Up @@ -151,19 +151,19 @@ def setup_source(self, source=None):
self.bs = bs

@smart_inference_mode()
def __call__(self, source=None, model=None, verbose=False, stream=False):
def __call__(self, source=None, model=None, stream=False):
if stream:
return self.stream_inference(source, model, verbose)
return self.stream_inference(source, model)
else:
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
return list(self.stream_inference(source, model)) # merge list of Result into one

def predict_cli(self):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
gen = self.stream_inference(verbose=True)
gen = self.stream_inference()
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass

def stream_inference(self, source=None, model=None, verbose=False):
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")

# setup model
Expand Down Expand Up @@ -201,7 +201,7 @@ def stream_inference(self, source=None, model=None, verbose=False):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p = Path(p)

if verbose or self.args.save or self.args.save_txt or self.args.show:
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
s += self.write_results(i, self.results, (p, im, im0))

if self.args.show:
Expand All @@ -214,11 +214,11 @@ def stream_inference(self, source=None, model=None, verbose=False):
yield from self.results

# Print time (inference-only)
if verbose:
if self.args.verbose:
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")

# Print results
if verbose and self.seen:
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
f'{(1, 3, *self.imgsz)}' % t)
Expand All @@ -243,7 +243,7 @@ def check_source(self, source):
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
Expand Down
3 changes: 1 addition & 2 deletions ultralytics/yolo/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None):
self.console = LOGGER
self.validator = None
self.model = None
self.callbacks = defaultdict(list)
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

# Dirs
Expand Down Expand Up @@ -141,7 +140,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None):
self.plot_idx = [0, 1, 2]

# Callbacks
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
if RANK in {0, -1}:
callbacks.add_integration_callbacks(self)

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/yolo/engine/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001

self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks

@smart_inference_mode()
def __call__(self, trainer=None, model=None):
Expand Down
Loading

0 comments on commit 6c44ce2

Please sign in to comment.