Skip to content

Commit

Permalink
Add torch.Tensor results image saving (ultralytics#3475)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
3 people authored Jul 3, 2023
1 parent cd0bf05 commit 586c95b
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
4 changes: 0 additions & 4 deletions ultralytics/yolo/engine/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ def stream_inference(self, source=None, model=None):
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
self.done_warmup = True

# Checks
if self.source_type.tensor and (self.args.save or self.args.save_txt or self.args.show):
LOGGER.warning("WARNING ⚠️ 'save', 'save_txt' and 'show' arguments not enabled for torch.Tensor inference.")

self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
self.run_callbacks('on_predict_start')
for batch in self.dataset:
Expand Down
3 changes: 1 addition & 2 deletions ultralytics/yolo/engine/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ def plot(
(numpy.ndarray): A numpy array of the annotated image.
"""
if img is None and isinstance(self.orig_img, torch.Tensor):
LOGGER.warning('WARNING ⚠️ Results plotting is not supported for torch.Tensor image types.')
return
img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255

# Deprecation warn TODO: remove in 8.2
if 'show_conf' in kwargs:
Expand Down

0 comments on commit 586c95b

Please sign in to comment.