Skip to content

Commit 2149695

Browse files
committed
fix gt_num 0D tensor
1 parent 0b5e790 commit 2149695

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

ppdet/modeling/losses/yolov5_loss.py

+36-22
Original file line numberDiff line numberDiff line change
@@ -198,18 +198,20 @@ def yolov5_loss(self, pi, t_cls, t_box, t_indices, t_anchor, balance):
198198
def forward(self, inputs, targets, anchors):
199199
yolo_losses = dict()
200200
#tcls, tbox, indices, anch = self.build_targets(inputs, targets, anchors)
201-
tcls, tbox, indices, anch = self.build_targets_paddle(inputs, targets, anchors)
202-
201+
tcls, tbox, indices, anch = self.build_targets_paddle(inputs, targets,
202+
anchors)
203+
203204
for i, (p_det, balance) in enumerate(zip(inputs, self.balance)):
204205
t_cls = tcls[i]
205206
t_box = tbox[i]
206207
t_anchor = anch[i]
207208
t_indices = indices[i]
208209

209210
bs, ch, h, w = p_det.shape
210-
pi = p_det.reshape((bs, self.na, int(ch/self.na), h, w)).transpose(
211-
(0, 1, 3, 4, 2))
212-
211+
pi = p_det.reshape(
212+
(bs, self.na, int(ch / self.na), h, w)).transpose(
213+
(0, 1, 3, 4, 2))
214+
213215
yolo_loss = self.yolov5_loss(pi, t_cls, t_box, t_indices, t_anchor,
214216
balance)
215217

@@ -222,13 +224,12 @@ def forward(self, inputs, targets, anchors):
222224
loss = 0
223225
for k, v in yolo_losses.items():
224226
loss += v
225-
227+
226228
batch_size = inputs[0].shape[0]
227229
num_gpus = targets.get('num_gpus', 8)
228230
yolo_losses['loss'] = loss * batch_size * num_gpus
229231
return yolo_losses
230232

231-
232233
def build_targets_paddle(self, outputs, targets, anchors):
233234
# targets['gt_class'] [bs, max_gt_nums, 1]
234235
# targets['gt_bbox'] [bs, max_gt_nums, 4]
@@ -239,19 +240,28 @@ def build_targets_paddle(self, outputs, targets, anchors):
239240
na = anchors.shape[1] # not len(anchors)
240241
tcls, tbox, indices, anch = [], [], [], []
241242

242-
gain = paddle.ones([7], dtype=np.float32) # normalized to gridspace gain
243-
ai = paddle.tile(paddle.arange(na,dtype=np.float32).reshape([na, 1]), [1, nt])
243+
gain = paddle.ones(
244+
[7], dtype=np.float32) # normalized to gridspace gain
245+
ai = paddle.tile(
246+
paddle.arange(
247+
na, dtype=np.float32).reshape([na, 1]), [1, nt])
244248

245249
batch_size = outputs[0].shape[0]
246250
gt_labels = []
247251
for idx in range(batch_size):
248-
gt_num = gt_nums[idx].astype("int32")
252+
gt_num = gt_nums[idx:idx + 1].astype("int32")
249253
if gt_num == 0:
250254
continue
255+
251256
gt_bbox = targets['gt_bbox'][idx][:gt_num]
252257
gt_class = targets['gt_class'][idx][:gt_num] * 1.0
253-
img_idx = paddle.repeat_interleave(paddle.to_tensor(idx),gt_num,axis=0)[None,:].astype(paddle.float32).T
254-
gt_labels.append(paddle.concat((img_idx, gt_class, gt_bbox),axis=-1))
258+
img_idx = paddle.repeat_interleave(
259+
paddle.to_tensor([idx]), gt_num,
260+
axis=0)[None, :].astype(paddle.float32).T
261+
gt_labels.append(
262+
paddle.concat(
263+
(img_idx, gt_class, gt_bbox), axis=-1))
264+
255265
if (len(gt_labels)):
256266
gt_labels = paddle.concat(gt_labels)
257267
else:
@@ -264,29 +274,33 @@ def build_targets_paddle(self, outputs, targets, anchors):
264274
for i in range(len(anchors)):
265275
anchor = anchors[i] / self.downsample_ratios[i]
266276
gain[2:6] = paddle.to_tensor(
267-
outputs[i].shape, dtype=paddle.float32)[[3, 2, 3, 2]] # xyxy gain
277+
outputs[i].shape,
278+
dtype=paddle.float32)[[3, 2, 3, 2]] # xyxy gain
268279

269280
# Match targets_labels to
270281
t = targets_labels * gain
271282
if nt:
272283
# Matches
273284
r = t[:, :, 4:6] / anchor[:, None]
274285
j = paddle.maximum(r, 1 / r).max(2) < self.anchor_t
275-
t = paddle.flatten(t,0,1)
276-
j = paddle.flatten(j.astype(paddle.int32),0,1).astype(paddle.bool)
286+
t = paddle.flatten(t, 0, 1)
287+
j = paddle.flatten(j.astype(paddle.int32), 0,
288+
1).astype(paddle.bool)
277289
t = t[j] # filter
278-
279290

280291
# Offsets
281292
gxy = t[:, 2:4] # grid xy
282293
gxi = gain[[2, 3]] - gxy # inverse
283294
j, k = ((gxy % 1 < g) & (gxy > 1)).T.astype(paddle.int64)
284295
l, m = ((gxi % 1 < g) & (gxi > 1)).T.astype(paddle.int64)
285-
j = paddle.flatten(paddle.stack((paddle.ones_like(j), j, k, l, m)),0,1).astype(paddle.bool)
286-
t = paddle.flatten(paddle.tile(t, [5, 1, 1]),0,1)
296+
j = paddle.flatten(
297+
paddle.stack((paddle.ones_like(j), j, k, l, m)), 0,
298+
1).astype(paddle.bool)
299+
t = paddle.flatten(paddle.tile(t, [5, 1, 1]), 0, 1)
287300
t = t[j]
288-
offsets = paddle.zeros_like(gxy)[None,:] + paddle.to_tensor(self.off)[:,None]
289-
offsets = paddle.flatten(offsets,0,1)[j]
301+
offsets = paddle.zeros_like(gxy)[None, :] + paddle.to_tensor(
302+
self.off)[:, None]
303+
offsets = paddle.flatten(offsets, 0, 1)[j]
290304
else:
291305
t = targets_labels[0]
292306
offsets = 0
@@ -297,14 +311,14 @@ def build_targets_paddle(self, outputs, targets, anchors):
297311
gwh = t[:, 4:6] # grid wh
298312
gij = (gxy - offsets).astype(paddle.int64)
299313
gi, gj = gij.T # grid xy indices
300-
314+
301315
# Append
302316
a = t[:, 6].astype(paddle.int64) # anchor indices
303317
gj, gi = gj.clip(0, gain[3] - 1), gi.clip(0, gain[2] - 1)
304318
indices.append(
305319
(b, a, gj.astype(paddle.int64), gi.astype(paddle.int64)))
306320
tbox.append(
307-
paddle.concat((gxy - gij, gwh), 1).astype(paddle.float32))
321+
paddle.concat((gxy - gij, gwh), 1).astype(paddle.float32))
308322
anch.append(anchor[a])
309323
tcls.append(c)
310324
return tcls, tbox, indices, anch

0 commit comments

Comments
 (0)