Skip to content

Commit

Permalink
ultralytics 8.0.26 new YOLOv5u models (ultralytics#771)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Boguszewski <[email protected]>
  • Loading branch information
3 people authored Feb 2, 2023
1 parent b83374b commit fa8811d
Show file tree
Hide file tree
Showing 23 changed files with 85 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
hooks:
- id: pyupgrade
name: Upgrade code
args: [ --py37-plus ]
args: [--py37-plus]

# - repo: https://github.com/PyCQA/isort
# rev: 5.11.4
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference

# Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
FROM nvcr.io/nvidia/pytorch:22.12-py3
FROM nvcr.io/nvidia/pytorch:23.01-py3

# Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/
Expand All @@ -26,7 +26,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics

# Install pip packages
RUN python -m pip install --upgrade pip wheel
RUN pip install --no-cache ultralytics albumentations comet gsutil notebook 'opencv-python<4.6.0.66'
RUN pip install --no-cache ultralytics albumentations comet gsutil notebook

# Set environment variables
ENV OMP_NUM_THREADS=1
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile-arm64
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ RUN pip install --no-cache ultralytics gsutil notebook \
tensorflow-aarch64
# tensorflowjs \
# onnx onnx-simplifier onnxruntime \
# coremltools openvino-dev \
# coremltools openvino-dev>=2022.3 \

# Cleanup
ENV DEBIAN_FRONTEND teletype
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile-cpu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ COPY requirements.txt .
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache ultralytics albumentations gsutil notebook \
coremltools onnx onnx-simplifier onnxruntime tensorflow-cpu \
# openvino-dev tensorflowjs \
# openvino-dev>=2022.3 tensorflowjs \
--extra-index-url https://download.pytorch.org/whl/cpu

# Cleanup
Expand Down
25 changes: 19 additions & 6 deletions examples/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"import ultralytics\n",
"ultralytics.checks()"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -116,7 +116,7 @@
"# Run inference on an image with YOLOv8n\n",
"!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -183,7 +183,7 @@
"# Validate YOLOv8n on COCO128 val\n",
"!yolo val model=yolov8n.pt data=coco128.yaml"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -306,7 +306,7 @@
"# Train YOLOv8n on COCO128 for 3 epochs\n",
"!yolo train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640"
],
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -495,7 +495,7 @@
"id": "CYIjW4igCjqD",
"outputId": "69cab2fb-cbfa-4acf-8e29-9c4fb6f4a38f"
},
"execution_count": 5,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -666,6 +666,19 @@
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Validate multiple models\n",
"for x in 'nsmlx':\n",
" !yolo val model=yolov8{x}.pt data=coco.yaml"
],
"metadata": {
"id": "Wdc6t_bfzDDk"
},
"execution_count": null,
"outputs": []
}
]
}
}
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ seaborn>=0.11.0
# scikit-learn==0.19.2 # CoreML quantization
# tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
# tensorflowjs>=3.9.0 # TF.js export
# openvino-dev>=2022.1 # OpenVINO export
# openvino-dev>=2022.3 # OpenVINO export

# Extras --------------------------------------
ipython # interactive notebook
Expand Down
4 changes: 2 additions & 2 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ def on_predict_batch_end(predictor):
# results -> List[batch_size]
path, _, im0s, _, _ = predictor.batch
# print('on_predict_batch_end', im0s[0].shape)
bs = [predictor.bs for i in range(0, len(path))]
bs = [predictor.bs for _ in range(len(path))]
predictor.results = zip(predictor.results, im0s, bs)

model = YOLO("yolov8n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)

dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
bs = dataset.bs # access predictor properties
bs = dataset.bs # noqa access predictor properties
results = model.predict(dataset, stream=True) # source already setup
for _, (result, im0, bs) in enumerate(results):
print('test_callback', im0.shape)
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

__version__ = "8.0.25"
__version__ = "8.0.26"

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 4 additions & 3 deletions ultralytics/yolo/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def argument_error(arg):
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")


def entrypoint(debug=False):
def entrypoint(debug=''):
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package.
Expand All @@ -163,7 +163,7 @@ def entrypoint(debug=False):
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg
"""
args = ['train', 'model=yolov8n.pt', 'data=coco128.yaml', 'imgsz=32', 'epochs=1'] if debug else sys.argv[1:]
args = (debug.split(' ') if debug else sys.argv)[1:]
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
Expand Down Expand Up @@ -275,4 +275,5 @@ def copy_default_cfg():


if __name__ == '__main__':
entrypoint(debug=True)
# entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug='')
22 changes: 12 additions & 10 deletions ultralytics/yolo/data/dataloaders/stream_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from PIL import Image

from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, tran
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0 and (is_colab() or is_kaggle()):
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks."
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
"Try running 'source=0' in a local environment.")
cap = cv2.VideoCapture(s)
if not cap.isOpened():
Expand All @@ -61,9 +61,11 @@ def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, tran
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback

_, self.imgs[i] = cap.read() # guarantee first frame
success, self.imgs[i] = cap.read() # guarantee first frame
if not success or self.imgs[i] is None:
raise ConnectionError(f'{st}Failed to read images from {s}')
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
LOGGER.info(f"{st}Success ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
LOGGER.info('') # newline

Expand Down Expand Up @@ -221,15 +223,15 @@ def __next__(self):
self.mode = 'video'
for _ in range(self.vid_stride):
self.cap.grab()
ret_val, im0 = self.cap.retrieve()
while not ret_val:
success, im0 = self.cap.retrieve()
while not success:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self._new_video(path)
ret_val, im0 = self.cap.read()
success, im0 = self.cap.read()

self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
Expand Down Expand Up @@ -330,14 +332,14 @@ def autocast_list(source):
Merges a list of source of different types into a list of numpy arrays or PIL images
"""
files = []
for _, im in enumerate(source):
for im in source:
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
raise TypeError(f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
f"See https://docs.ultralytics.com/predict for supported source types.")

return files

Expand Down
22 changes: 9 additions & 13 deletions ultralytics/yolo/data/scripts/download_weights.sh
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
#!/bin/bash
# Ultralytics YOLO 🚀, GPL-3.0 license
# Download latest models from https://github.com/ultralytics/yolov5/releases
# Example usage: bash data/scripts/download_weights.sh
# Download latest models from https://github.com/ultralytics/assets/releases
# Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh
# parent
# └── yolov5
# ├── yolov5s.pt ← downloads here
# ├── yolov5m.pt
# └── weights
# ├── yolov8n.pt ← downloads here
# ├── yolov8s.pt
# └── ...

python - <<EOF
from utils.downloads import attempt_download
from ultralytics.yolo.utils.downloads import attempt_download_asset
p5 = list('nsmlx') # P5 models
p6 = [f'{x}6' for x in p5] # P6 models
cls = [f'{x}-cls' for x in p5] # classification models
seg = [f'{x}-seg' for x in p5] # classification models
for x in p5 + p6 + cls + seg:
attempt_download(f'weights/yolov5{x}.pt')
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg')]
for x in assets:
attempt_download_asset(f'weights/{x}')
EOF
11 changes: 8 additions & 3 deletions ultralytics/yolo/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,20 @@ def _export_onnx(self, prefix=colorstr('ONNX:')):
@try_export
def _export_openvino(self, prefix=colorstr('OpenVINO:')):
# YOLOv8 OpenVINO export
check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
check_requirements('openvino-dev>=2022.3') # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa

LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
f_onnx = self.file.with_suffix('.onnx')
f_ov = str(Path(f) / self.file.with_suffix('.xml').name)

cmd = f"mo --input_model {f_onnx} --output_dir {f} {'--compress_to_fp16' * self.args.half}"
subprocess.run(cmd.split(), check=True, env=os.environ) # export
ov_model = mo.convert_model(f_onnx,
model_name=self.pretty_name,
framework="onnx",
compress_to_fp16=self.args.half) # export
ov.serialize(ov_model, f_ov) # save
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
return f, None

Expand Down
8 changes: 7 additions & 1 deletion ultralytics/yolo/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

# Convert image size to list if it is an integer
imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
if isinstance(imgsz, int):
imgsz = [imgsz]
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
else:
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")

# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
Expand Down
18 changes: 16 additions & 2 deletions ultralytics/yolo/utils/downloads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

import contextlib
import re
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
Expand Down Expand Up @@ -118,7 +119,18 @@ def github_assets(repository, version='latest'):
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets

file = Path(str(file).strip().replace("'", ''))
# YOLOv3/5u updates
file = str(file)
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
original_file = file
file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file:
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n")

file = Path(file.strip().replace("'", ''))
if file.exists():
return str(file)
elif (SETTINGS['weights_dir'] / file).exists():
Expand All @@ -136,7 +148,9 @@ def github_assets(repository, version='latest'):
return file

# GitHub assets
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
[f'yolov3{size}u.pt' for size in ('', '-spp', '-tiny')]
try:
tag, assets = github_assets(repo, release)
except Exception:
Expand Down
6 changes: 1 addition & 5 deletions ultralytics/yolo/v8/segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,11 @@ def write_results(self, idx, results, batch):
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
255 if self.args.retina_masks else im[idx])

# Segments
if self.args.save_txt:
segments = mask.segments

# Write results
for j, d in enumerate(reversed(det)):
cls, conf = d.cls.squeeze(), d.conf.squeeze()
if self.args.save_txt: # Write to file
seg = segments[j].copy()
seg = mask.segments[len(det) - j - 1].copy() # reversed mask.segments
seg = seg.reshape(-1) # (n,2) to (n*2)
line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg) # label format
with open(f'{self.txt_path}.txt', 'a') as f:
Expand Down

0 comments on commit fa8811d

Please sign in to comment.