Skip to content

Commit

Permalink
objectness thresholding is done. Todo: nms
Browse files Browse the repository at this point in the history
  • Loading branch information
v-iashin committed Feb 23, 2019
1 parent 1872084 commit 449d693
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
35 changes: 16 additions & 19 deletions face_detector/experiments/yolo/Untitled.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 55,
"metadata": {},
"outputs": [
{
Expand All @@ -126,32 +126,29 @@
"# todo nms\n",
"x = torch.randn((2, 3, 416, 416))\n",
"predictions = darknet.forward(x, 'cpu')\n",
"objectness_thres = 0.8\n",
"predictions[:, :, 4] *= 1000 # to be removed. just for testing\n",
"objectness_thres = 0.2\n",
"nms_thres = 0.4\n",
"classes = 80 # darknet.layers_list[-1][0].classes"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"objectness_mask = predictions[:, :, 4] > objectness_thres\n",
"(predictions[:, :, 4] == 0).sum()\n",
"# todo nms. check whether > < 0 in the previous line is \n",
"# expectable behaviour"
"# Non-max Suppression:\n",
"# 1. Filter out predictions with low objectness score\n",
"# 2. \n",
"# objectiveness filtering:\n",
"# replace all boxes with predicted probability lower than objectness_thres \n",
"# with zeros: calculate 0/1 mask and apply it to the prediction tensor\n",
"# Note: '>' returns Byte but '*' needs Float unsqueezed back to 3D\n",
"objectness_mask = (predictions[:, :, 4] > objectness_thres).float().unsqueeze(2)\n",
"predictions = predictions * objectness_mask\n",
"\n",
"# "
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion face_detector/experiments/yolo/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def forward(self, x, device):
pwh = anchors_tens.repeat(w * w, 1).unsqueeze(0)
pwh = pwh.to(device)

# transform the predictions
# transform the predictions (center, size, objectness, class scores)
x[:, :, 0:2] = (torch.sigmoid(x[:, :, 0:2]) + cxy) * stride
x[:, :, 2:4] = (pwh * torch.exp(x[:, :, 2:4])) * stride
x[:, :, 4] = torch.sigmoid(x[:, :, 4])
Expand Down

0 comments on commit 449d693

Please sign in to comment.