Skip to content

Commit

Permalink
Fix dataloader2 (ultralytics#35)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2022
1 parent 523eff9 commit c617ee1
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 138 deletions.
2 changes: 1 addition & 1 deletion ultralytics/tests/data/dataloader/hyp_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
copy_paste: 0.5 # segment copy-paste (probability)
102 changes: 57 additions & 45 deletions ultralytics/tests/data/dataloader/yolopose.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,48 +67,60 @@ def plot_keypoint(img, keypoints, color, tl):
with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
hyp = OmegaConf.load(f)

dataloader, dataset = build_dataloader(
img_path="/d/dataset/COCO/images/val2017",
img_size=640,
label_path=None,
cache=False,
hyp=hyp,
augment=False,
prefix="",
rect=False,
batch_size=4,
stride=32,
pad=0.5,
use_segments=False,
use_keypoints=True,
)

for d in dataloader:
idx = 1 # show which image inside one batch
img = d["img"][idx].numpy()
img = np.ascontiguousarray(img.transpose(1, 2, 0))
ih, iw = img.shape[:2]
# print(img.shape)
bidx = d["batch_idx"]
cls = d["cls"][bidx == idx].numpy()
bboxes = d["bboxes"][bidx == idx].numpy()
bboxes[:, [0, 2]] *= iw
bboxes[:, [1, 3]] *= ih
keypoints = d["keypoints"][bidx == idx]
keypoints[..., 0] *= iw
keypoints[..., 1] *= ih
# print(keypoints, keypoints.shape)
# print(d["im_file"])

for i, b in enumerate(bboxes):
x, y, w, h = b
x1 = x - w / 2
x2 = x + w / 2
y1 = y - h / 2
y2 = y + h / 2
c = int(cls[i][0])
# print(x1, y1, x2, y2)
plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, keypoints=keypoints[i], label=f"{c}", color=colors(c))
cv2.imshow("p", img)
if cv2.waitKey(0) == ord("q"):
break

def test(augment, rect):
dataloader, _ = build_dataloader(
img_path="/d/dataset/COCO/images/val2017",
img_size=640,
label_path=None,
cache=False,
hyp=hyp,
augment=augment,
prefix="",
rect=rect,
batch_size=4,
stride=32,
pad=0.5,
use_segments=False,
use_keypoints=True,
)

for d in dataloader:
idx = 1 # show which image inside one batch
img = d["img"][idx].numpy()
img = np.ascontiguousarray(img.transpose(1, 2, 0))
ih, iw = img.shape[:2]
# print(img.shape)
bidx = d["batch_idx"]
cls = d["cls"][bidx == idx].numpy()
bboxes = d["bboxes"][bidx == idx].numpy()
bboxes[:, [0, 2]] *= iw
bboxes[:, [1, 3]] *= ih
keypoints = d["keypoints"][bidx == idx]
keypoints[..., 0] *= iw
keypoints[..., 1] *= ih
# print(keypoints, keypoints.shape)
# print(d["im_file"])

for i, b in enumerate(bboxes):
x, y, w, h = b
x1 = x - w / 2
x2 = x + w / 2
y1 = y - h / 2
y2 = y + h / 2
c = int(cls[i][0])
# print(x1, y1, x2, y2)
plot_one_box([int(x1), int(y1), int(x2), int(y2)],
img,
keypoints=keypoints[i],
label=f"{c}",
color=colors(c))
cv2.imshow("p", img)
if cv2.waitKey(0) == ord("q"):
break


if __name__ == "__main__":
test(augment=True, rect=False)
test(augment=False, rect=True)
test(augment=False, rect=False)
126 changes: 71 additions & 55 deletions ultralytics/tests/data/dataloader/yolosegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,58 +55,74 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
hyp = OmegaConf.load(f)

dataloader, dataset = build_dataloader(
img_path="/d/dataset/COCO/coco128-seg/images",
img_size=640,
label_path=None,
cache=False,
hyp=hyp,
augment=False,
prefix="",
rect=False,
batch_size=4,
stride=32,
pad=0.5,
use_segments=True,
use_keypoints=False,
)

for d in dataloader:
idx = 1 # show which image inside one batch
img = d["img"][idx].numpy()
img = np.ascontiguousarray(img.transpose(1, 2, 0))
ih, iw = img.shape[:2]
# print(img.shape)
bidx = d["batch_idx"]
cls = d["cls"][bidx == idx].numpy()
bboxes = d["bboxes"][bidx == idx].numpy()
masks = d["masks"][idx]
print(bboxes.shape)
bboxes[:, [0, 2]] *= iw
bboxes[:, [1, 3]] *= ih
nl = len(cls)

index = torch.arange(nl).view(nl, 1, 1) + 1
masks = masks.repeat(nl, 1, 1)
# print(masks.shape, index.shape)
masks = torch.where(masks == index, 1, 0)
masks = masks.numpy().astype(np.uint8)
print(masks.shape)
# keypoints = d["keypoints"]

for i, b in enumerate(bboxes):
x, y, w, h = b
x1 = x - w / 2
x2 = x + w / 2
y1 = y - h / 2
y2 = y + h / 2
c = int(cls[i][0])
# print(x1, y1, x2, y2)
plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c))
mask = masks[i]
mask = cv2.resize(mask, (iw, ih))
mask = mask.astype(bool)
img[mask] = img[mask] * 0.5 + np.array(colors(c)) * 0.5
cv2.imshow("p", img)
if cv2.waitKey(0) == ord("q"):
break

def test(augment, rect):
dataloader, _ = build_dataloader(
img_path="/d/dataset/COCO/coco128-seg/images",
img_size=640,
label_path=None,
cache=False,
hyp=hyp,
augment=augment,
prefix="",
rect=rect,
batch_size=4,
stride=32,
pad=0.5,
use_segments=True,
use_keypoints=False,
)

for d in dataloader:
# info
im_file = d["im_file"]
ori_shape = d["ori_shape"]
resize_shape = d["resized_shape"]
print(ori_shape, resize_shape)
print(im_file)

# labels
idx = 1 # show which image inside one batch
img = d["img"][idx].numpy()
img = np.ascontiguousarray(img.transpose(1, 2, 0))
ih, iw = img.shape[:2]
# print(img.shape)
bidx = d["batch_idx"]
cls = d["cls"][bidx == idx].numpy()
bboxes = d["bboxes"][bidx == idx].numpy()
masks = d["masks"][idx]
print(bboxes.shape)
bboxes[:, [0, 2]] *= iw
bboxes[:, [1, 3]] *= ih
nl = len(cls)

index = torch.arange(nl).view(nl, 1, 1) + 1
masks = masks.repeat(nl, 1, 1)
# print(masks.shape, index.shape)
masks = torch.where(masks == index, 1, 0)
masks = masks.numpy().astype(np.uint8)
print(masks.shape)
# keypoints = d["keypoints"]

for i, b in enumerate(bboxes):
x, y, w, h = b
x1 = x - w / 2
x2 = x + w / 2
y1 = y - h / 2
y2 = y + h / 2
c = int(cls[i][0])
# print(x1, y1, x2, y2)
plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c))
mask = masks[i]
mask = cv2.resize(mask, (iw, ih))
mask = mask.astype(bool)
img[mask] = img[mask] * 0.5 + np.array(colors(c)) * 0.5
cv2.imshow("p", img)
if cv2.waitKey(0) == ord("q"):
break


if __name__ == "__main__":
test(augment=True, rect=False)
test(augment=False, rect=True)
test(augment=False, rect=False)
49 changes: 22 additions & 27 deletions ultralytics/yolo/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _cat_labels(self, mosaic_labels):
cls.append(labels["cls"])
instances.append(labels["instances"])
final_labels = {
"ori_shape": (self.img_size * 2, self.img_size * 2),
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (self.img_size * 2, self.img_size * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0)}
Expand Down Expand Up @@ -351,7 +351,7 @@ def __call__(self, labels):
"""
img = labels["img"]
cls = labels["cls"]
instances = labels["instances"]
instances = labels.pop("instances")
# make sure the coord formats are right
instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])
Expand All @@ -372,6 +372,7 @@ def __call__(self, labels):
if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
# clip
new_instances.clip(*self.size)

# filter instances
Expand All @@ -381,9 +382,9 @@ def __call__(self, labels):
box2=new_instances.bboxes.T,
area_thr=0.01 if len(segments) else 0.10)
labels["instances"] = new_instances[i]
# clip
labels["cls"] = cls[i]
labels["img"] = img
labels["resized_shape"] = img.shape[:2]
return labels

def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
Expand Down Expand Up @@ -430,7 +431,7 @@ def __init__(self, p=0.5, direction="horizontal") -> None:

def __call__(self, labels):
img = labels["img"]
instances = labels["instances"]
instances = labels.pop("instances")
instances.convert_bbox(format="xywh")
h, w = img.shape[:2]
h = 1 if instances.normalized else h
Expand All @@ -439,13 +440,11 @@ def __call__(self, labels):
# Flip up-down
if self.direction == "vertical" and random.random() < self.p:
img = np.flipud(img)
img = np.ascontiguousarray(img)
instances.flipud(h)
if self.direction == "horizontal" and random.random() < self.p:
img = np.fliplr(img)
img = np.ascontiguousarray(img)
instances.fliplr(w)
labels["img"] = img
labels["img"] = np.ascontiguousarray(img)
labels["instances"] = instances
return labels

Expand All @@ -463,7 +462,7 @@ def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=Tr
def __call__(self, labels={}, image=None):
img = image or labels["img"]
shape = img.shape[:2] # current shape [height, width]
new_shape = labels.get("rect_shape", self.new_shape)
new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

Expand Down Expand Up @@ -495,6 +494,7 @@ def __call__(self, labels={}, image=None):

labels = self._update_labels(labels, ratio, dw, dh)
labels["img"] = img
labels["resized_shape"] = new_shape
return labels

def _update_labels(self, labels, ratio, padw, padh):
Expand All @@ -515,26 +515,21 @@ def __call__(self, labels):
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
im = labels["img"]
cls = labels["cls"]
bboxes = labels["instances"].bboxes
segments = labels["instances"].segments # n, 1000, 2
keypoints = labels["instances"].keypoints
if self.p and len(segments):
n = len(segments)
instances = labels.pop("instances")
instances.convert_bbox(format="xyxy")
if self.p and len(instances.segments):
n = len(instances)
h, w, _ = im.shape # height, width, channels
im_new = np.zeros(im.shape, np.uint8)
# TODO: this implement can be parallel since segments are ndarray, also might work with Instances inside
for j in random.sample(range(n), k=round(self.p * n)):
c, b, s = cls[j], bboxes[j], segments[j]
box = w - b[2], b[1], w - b[0], b[3]
ioa = bbox_ioa(box, bboxes) # intersection over area
if (ioa < 0.30).all(): # allow 30% obscuration of existing labels
bboxes = np.concatenate((bboxes, [box]), 0)
cls = np.concatenate((cls, c[None]), axis=0)
segments = np.concatenate((segments, np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)[None]), 0)
if keypoints is not None:
keypoints = np.concatenate(
(keypoints, np.concatenate((w - keypoints[j][:, 0:1], keypoints[j][:, 1:2]), 1)), 0)
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
j = random.sample(range(n), k=round(self.p * n))
c, instance = cls[j], instances[j]
instance.fliplr(w)
ioa = bbox_ioa(instance.bboxes, instances.bboxes) # intersection over area, (N, M)
i = (ioa < 0.30).all(1) # (N, )
if i.sum():
cls = np.concatenate((cls, c[i]), axis=0)
instances = Instances.concatenate((instances, instance[i]), axis=0)
cv2.drawContours(im_new, instances.segments[j][i].astype(np.int32), -1, (255, 255, 255), cv2.FILLED)

result = cv2.bitwise_and(src1=im, src2=im_new)
result = cv2.flip(result, 1) # augment segments (flip left-right)
Expand All @@ -543,7 +538,7 @@ def __call__(self, labels):
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
labels["img"] = im
labels["cls"] = cls
labels["instances"].update(bboxes, segments, keypoints)
labels["instances"] = instances
return labels


Expand Down
Loading

0 comments on commit c617ee1

Please sign in to comment.