Skip to content

Commit

Permalink
add decoupled head
Browse files Browse the repository at this point in the history
add decoupled head
  • Loading branch information
HuKai97 authored May 18, 2022
1 parent 3875a82 commit b03a835
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,33 @@ def parse_model(d, ch): # model_dict, input_channels(3)

return nn.Sequential(*layers), sorted(save)

class DecoupledHead(nn.Module):
def __init__(self, ch=256, nc=80, width=1.0, anchors=()):
super().__init__()
self.nc = nc # number of classes
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.merge = Conv(ch, 256 * width, 1, 1)
self.cls_convs1 = Conv(256 * width, 256 * width, 3, 1, 1)
self.cls_convs2 = Conv(256 * width, 256 * width, 3, 1, 1)
self.reg_convs1 = Conv(256 * width, 256 * width, 3, 1, 1)
self.reg_convs2 = Conv(256 * width, 256 * width, 3, 1, 1)
self.cls_preds = nn.Conv2d(256 * width, self.nc * self.na, 1)
self.reg_preds = nn.Conv2d(256 * width, 4 * self.na, 1)
self.obj_preds = nn.Conv2d(256 * width, 1 * self.na, 1)

def forward(self, x):
x = self.merge(x)
x1 = self.cls_convs1(x)
x1 = self.cls_convs2(x1)
x1 = self.cls_preds(x1)
x2 = self.reg_convs1(x)
x2 = self.reg_convs2(x2)
x21 = self.reg_preds(x2)
x22 = self.obj_preds(x2)
out = torch.cat([x21, x22, x1], 1)
return out

class Detect(nn.Module):
"""Detect模块是用来构建Detect层的,将输入feature map 通过一个卷积操作和公式计算到我们想要的shape, 为后面的计算损失或者NMS作准备"""
stride = None # strides computed during build
Expand Down Expand Up @@ -162,6 +189,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True):
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))
# output conv 对每个输出的feature map都要调用一次conv1x1
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)
# 调用解耦头 DecoupledHead
# self.m = nn.ModuleList(DecoupledHead(x, nc, 1, anchors) for x in ch)
# use in-place ops (e.g. slice assignment) 一般都是True 默认不使用AWS Inferentia加速
self.inplace = inplace

Expand Down Expand Up @@ -270,6 +299,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):
# 检查anchor顺序与stride顺序是否一致
check_anchor_order(m)
self.stride = m.stride
# 调用解耦头这里要注释掉
self._initialize_biases() # only run once 初始化偏置
# logger.info('Strides: %s' % m.stride.tolist())

Expand Down

0 comments on commit b03a835

Please sign in to comment.