update tinypose act demo (PaddlePaddle#1227)
yghstill authored Jul 1, 2022
1 parent b13ff64 commit fc90903
Showing 6 changed files with 86 additions and 459 deletions.
23 changes: 13 additions & 10 deletions demo/auto_compression/detection/configs/tinypose_qat_dis.yaml
@@ -1,7 +1,8 @@
 Global:
+  arch: 'keypoint'
   reader_config: configs/tinypose_reader.yml
   input_list: ['image']
-  Evaluation: False
+  Evaluation: True
   model_dir: ./tinypose_128x96
   model_filename: model.pdmodel
   params_filename: model.pdiparams
@@ -13,19 +14,21 @@ Distillation:
   - conv2d_441.tmp_0
 
 Quantization:
-  activation_quantize_type: 'range_abs_max'
-  weight_quantize_type: 'abs_max'
+  use_pact: true
+  activation_quantize_type: 'moving_average_abs_max'
+  weight_quantize_type: 'channel_wise_abs_max'  # 'abs_max' is layer-wise quant
   quantize_op_types:
   - conv2d
   - depthwise_conv2d
 
 TrainConfig:
-  epochs: 1
+  train_iter: 30000
   eval_iter: 1000
-  learning_rate: 0.0001
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.015
+    T_max: 30000
   optimizer_builder:
-    optimizer:
-      type: SGD
-    weight_decay: 4.0e-05
-  #origin_metric: 0.291
-
+    optimizer:
+      type: Momentum
+    weight_decay: 0.00002
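The new Quantization block turns on PACT (learned clipping thresholds for activations during QAT), moves activation scales from 'range_abs_max' to a moving-average estimator, and moves weight scales from one 'abs_max' value per layer to one per conv output channel. A toy numpy sketch of the layer-wise vs. channel-wise weight-scale computation (illustrative only, not PaddleSlim's implementation):

import numpy as np

def abs_max_scales(weight, channel_wise=False):
    # Toy version of the two weight-quantization modes in the config above.
    if channel_wise:
        # 'channel_wise_abs_max': one scale per output channel of a
        # conv2d weight shaped [out_c, in_c, kh, kw].
        return np.abs(weight).reshape(weight.shape[0], -1).max(axis=1)
    # 'abs_max': a single scale for the whole layer.
    return np.abs(weight).max()

w = np.random.randn(8, 3, 3, 3).astype(np.float32)
print(abs_max_scales(w))                     # one scalar for the whole layer
print(abs_max_scales(w, channel_wise=True))  # shape (8,), one per channel
# int8 weights would then be np.round(w / scale * 127) with the matching scale.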
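TrainConfig likewise moves from a fixed learning rate of 0.0001 with SGD to cosine annealing from 0.015 with Momentum over the 30000 training iterations. A minimal Paddle sketch of the schedule and optimizer the new config requests (the Linear layer is just a stand-in for the real model):

import paddle

model = paddle.nn.Linear(4, 4)
sched = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.015,
                                                 T_max=30000)
opt = paddle.optimizer.Momentum(learning_rate=sched,
                                parameters=model.parameters(),
                                weight_decay=2e-5)
for step in range(5):
    # lr(t) = 0.015 / 2 * (1 + cos(pi * t / 30000)) when eta_min == 0
    print(step, sched.get_lr())
    sched.step()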
demo/auto_compression/detection/configs/tinypose_reader.yml
@@ -45,4 +45,4 @@ EvalReader:
       std: *global_std
       is_scale: true
   - Permute: {}
-  batch_size: 4
+  batch_size: 16
23 changes: 16 additions & 7 deletions demo/auto_compression/detection/eval.py
@@ -19,8 +19,9 @@
 import paddle
 from ppdet.core.workspace import load_config, merge_config
 from ppdet.core.workspace import create
-from ppdet.metrics import COCOMetric, VOCMetric
+from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
 from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
+from keypoint_utils import keypoint_post_process
 
 
 def argsparser():
@@ -99,12 +100,16 @@ def eval():
                        fetch_list=fetch_targets,
                        return_numpy=False)
         res = {}
-        for out in outs:
-            v = np.array(out)
-            if len(v.shape) > 1:
-                res['bbox'] = v
-            else:
-                res['bbox_num'] = v
+        if 'arch' in global_config and global_config['arch'] == 'keypoint':
+            res = keypoint_post_process(data, data_input, exe, val_program,
+                                        fetch_targets, outs)
+        else:
+            for out in outs:
+                v = np.array(out)
+                if len(v.shape) > 1:
+                    res['bbox'] = v
+                else:
+                    res['bbox_num'] = v
         metric.update(data_all, res)
         if batch_id % 100 == 0:
             print('Eval iter:', batch_id)
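The shapes handed to metric.update differ between the two branches above. An illustrative sketch with dummy values (the field layout follows ppdet's conventions and is an assumption here, not something defined in this diff):

import numpy as np

# Detection path: an [N, 6] array of [class_id, score, x0, y0, x1, y1]
# plus the number of boxes per image.
det_res = {'bbox': np.zeros((100, 6), dtype=np.float32),
           'bbox_num': np.array([100], dtype=np.int32)}

# Keypoint path: keypoint_post_process returns [[keypoints, scores]], with
# keypoints shaped [num_person, num_joints, 3] as (x, y, confidence).
kpt_res = {'keypoint': [[np.zeros((1, 17, 3), dtype=np.float32),
                         np.zeros((1, 1), dtype=np.float32)]]}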
@@ -135,6 +140,10 @@ def main():
             label_list=dataset.get_label_list(),
             class_num=reader_cfg['num_classes'],
             map_type=reader_cfg['map_type'])
+    elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
+        anno_file = dataset.get_anno()
+        metric = KeyPointTopDownCOCOEval(anno_file,
+                                         len(dataset), 17, 'output_eval')
     else:
         raise ValueError("metric currently only supports COCO and VOC.")
     global_config['metric'] = metric
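KeyPointTopDownCOCOEval is built from the annotation file, the dataset length, the COCO joint count (17), and an output directory for the result json; eval() then drives it through the same update/accumulate/log/reset protocol as COCOMetric. A stub that only mimics that call pattern, as an assumption about the shared ppdet metric interface rather than its real implementation:

class StubMetric:
    # Mimics the ppdet metric protocol used by eval() above.
    def __init__(self):
        self.batches = 0
    def update(self, inputs, outputs):
        self.batches += 1  # real metrics stash per-batch predictions here
    def accumulate(self):
        pass               # real metrics run pycocotools' COCOeval here
    def log(self):
        print('batches seen:', self.batches)
    def reset(self):
        self.batches = 0

metric = StubMetric()
for _ in range(3):
    metric.update({}, {})
metric.accumulate()
metric.log()
metric.reset()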
235 changes: 36 additions & 199 deletions demo/auto_compression/detection/keypoint_utils.py
@@ -11,39 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import logging
 import os
 import json
 import numpy as np
-from pycocotools.coco import COCO
-from pycocotools.cocoeval import COCOeval
-from scipy.io import loadmat, savemat
 import cv2
 import copy
 from paddleslim.common import get_logger
 
 logger = get_logger(__name__, level=logging.INFO)
 
+__all__ = ['keypoint_post_process']
 
-def get_affine_mat_kernel(h, w, s, inv=False):
-    if w < h:
-        w_ = s
-        h_ = int(np.ceil((s / w * h) / 64.) * 64)
-        scale_w = w
-        scale_h = h_ / w_ * w
-
-    else:
-        h_ = s
-        w_ = int(np.ceil((s / h * w) / 64.) * 64)
-        scale_h = h
-        scale_w = w_ / h_ * h
-
-    center = np.array([np.round(w / 2.), np.round(h / 2.)])
-
-    size_resized = (w_, h_)
-    trans = get_affine_transform(
-        center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv)
-
-    return trans, size_resized
+
+def flip_back(output_flipped, matched_parts):
+    assert output_flipped.ndim == 4,\
+        'output_flipped should be [batch_size, num_joints, height, width]'
+
+    output_flipped = output_flipped[:, :, :, ::-1]
+
+    for pair in matched_parts:
+        tmp = output_flipped[:, pair[0], :, :].copy()
+        output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
+        output_flipped[:, pair[1], :, :] = tmp
+
+    return output_flipped
 
 
 def get_affine_transform(center,
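The new flip_back undoes a horizontal flip-test pass: it mirrors the heatmaps back along the width axis, then swaps each left/right joint channel pair. A small usage sketch with random heatmaps and COCO's 17-keypoint pairing (joint 0, the nose, has no mirror partner):

import numpy as np

heatmaps = np.random.rand(1, 17, 32, 24).astype(np.float32)  # [N, joints, H, W]
flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
             [13, 14], [15, 16]]
restored = flip_back(heatmaps.copy(), flip_perm)
assert restored.shape == heatmaps.shape
# Mirroring and pair-swapping are both involutions, so applying flip_back
# twice returns the original heatmaps:
assert np.allclose(flip_back(restored.copy(), flip_perm), heatmaps)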
@@ -101,37 +92,6 @@ def get_affine_transform(center,
     return trans
 
 
-def get_warp_matrix(theta, size_input, size_dst, size_target):
-    """This code is based on
-    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
-    Calculate the transformation matrix under the constraint of unbiased.
-    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
-    Data Processing for Human Pose Estimation (CVPR 2020).
-    Args:
-        theta (float): Rotation angle in degrees.
-        size_input (np.ndarray): Size of input image [w, h].
-        size_dst (np.ndarray): Size of output image [w, h].
-        size_target (np.ndarray): Size of ROI in input plane [w, h].
-    Returns:
-        matrix (np.ndarray): A matrix for transformation.
-    """
-    theta = np.deg2rad(theta)
-    matrix = np.zeros((2, 3), dtype=np.float32)
-    scale_x = size_dst[0] / size_target[0]
-    scale_y = size_dst[1] / size_target[1]
-    matrix[0, 0] = np.cos(theta) * scale_x
-    matrix[0, 1] = -np.sin(theta) * scale_x
-    matrix[0, 2] = scale_x * (
-        -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
-        np.sin(theta) + 0.5 * size_target[0])
-    matrix[1, 0] = np.sin(theta) * scale_y
-    matrix[1, 1] = np.cos(theta) * scale_y
-    matrix[1, 2] = scale_y * (
-        -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
-        np.cos(theta) + 0.5 * size_target[1])
-    return matrix
-
-
 def _get_3rd_point(a, b):
     """To calculate the affine matrix, three pairs of points are required. This
     function is used to get the 3rd point, given 2D points a & b.
@@ -170,29 +130,6 @@ def rotate_point(pt, angle_rad):
     return rotated_pt
 
 
-def transpred(kpts, h, w, s):
-    trans, _ = get_affine_mat_kernel(h, w, s, inv=True)
-
-    return warp_affine_joints(kpts[..., :2].copy(), trans)
-
-
-def warp_affine_joints(joints, mat):
-    """Apply affine transformation defined by the transform matrix on the
-    joints.
-    Args:
-        joints (np.ndarray[..., 2]): Origin coordinate of joints.
-        mat (np.ndarray[3, 2]): The affine matrix.
-    Returns:
-        matrix (np.ndarray[..., 2]): Result coordinate of joints.
-    """
-    joints = np.array(joints)
-    shape = joints.shape
-    joints = joints.reshape(-1, 2)
-    return np.dot(np.concatenate(
-        (joints, joints[:, 0:1] * 0 + 1), axis=1),
-                  mat.T).reshape(shape)
-
-
 def affine_transform(pt, t):
     new_pt = np.array([pt[0], pt[1], 1.]).T
     new_pt = np.dot(t, new_pt)
@@ -207,130 +144,6 @@ def transform_preds(coords, center, scale, output_size):
     return target_coords
 
 
-def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
-    if not isinstance(sigmas, np.ndarray):
-        sigmas = np.array([
-            .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07,
-            .87, .87, .89, .89
-        ]) / 10.0
-    vars = (sigmas * 2)**2
-    xg = g[0::3]
-    yg = g[1::3]
-    vg = g[2::3]
-    ious = np.zeros((d.shape[0]))
-    for n_d in range(0, d.shape[0]):
-        xd = d[n_d, 0::3]
-        yd = d[n_d, 1::3]
-        vd = d[n_d, 2::3]
-        dx = xd - xg
-        dy = yd - yg
-        e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
-        if in_vis_thre is not None:
-            ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
-            e = e[ind]
-        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
-    return ious
-
-
-def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
-    """greedily select boxes with high confidence and overlap with current maximum <= thresh
-    rule out overlap >= thresh
-    Args:
-        kpts_db (list): The predicted keypoints within the image
-        thresh (float): The threshold to select the boxes
-        sigmas (np.array): The variance to calculate the oks iou
-            Default: None
-        in_vis_thre (float): The threshold to select the high confidence boxes
-            Default: None
-    Return:
-        keep (list): indexes to keep
-    """
-
-    if len(kpts_db) == 0:
-        return []
-
-    scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
-    kpts = np.array(
-        [kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
-    areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
-
-    order = scores.argsort()[::-1]
-
-    keep = []
-    while order.size > 0:
-        i = order[0]
-        keep.append(i)
-
-        oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
-                          sigmas, in_vis_thre)
-
-        inds = np.where(oks_ovr <= thresh)[0]
-        order = order[inds + 1]
-
-    return keep
-
-
-def rescore(overlap, scores, thresh, type='gaussian'):
-    assert overlap.shape[0] == scores.shape[0]
-    if type == 'linear':
-        inds = np.where(overlap >= thresh)[0]
-        scores[inds] = scores[inds] * (1 - overlap[inds])
-    else:
-        scores = scores * np.exp(-overlap**2 / thresh)
-
-    return scores
-
-
-def soft_oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
-    """greedily select boxes with high confidence and overlap with current maximum <= thresh
-    rule out overlap >= thresh
-    Args:
-        kpts_db (list): The predicted keypoints within the image
-        thresh (float): The threshold to select the boxes
-        sigmas (np.array): The variance to calculate the oks iou
-            Default: None
-        in_vis_thre (float): The threshold to select the high confidence boxes
-            Default: None
-    Return:
-        keep (list): indexes to keep
-    """
-
-    if len(kpts_db) == 0:
-        return []
-
-    scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
-    kpts = np.array(
-        [kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
-    areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
-
-    order = scores.argsort()[::-1]
-    scores = scores[order]
-
-    # max_dets = order.size
-    max_dets = 20
-    keep = np.zeros(max_dets, dtype=np.intp)
-    keep_cnt = 0
-    while order.size > 0 and keep_cnt < max_dets:
-        i = order[0]
-
-        oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
-                          sigmas, in_vis_thre)
-
-        order = order[1:]
-        scores = rescore(oks_ovr, scores[1:], thresh)
-
-        tmp = scores.argsort()[::-1]
-        order = order[tmp]
-        scores = scores[tmp]
-
-        keep[keep_cnt] = i
-        keep_cnt += 1
-
-    keep = keep[:keep_cnt]
-
-    return keep
-
-
 class HRNetPostProcess(object):
     def __init__(self, use_dark=True):
         self.use_dark = use_dark
@@ -468,3 +281,27 @@ def __call__(self, output, center, scale):
                 maxvals, axis=1)
         ]]
         return outputs
+
+
+def keypoint_post_process(data, data_input, exe, val_program, fetch_targets,
+                          outs):
+    data_input['image'] = np.flip(data_input['image'], [3])
+    output_flipped = exe.run(val_program,
+                             feed=data_input,
+                             fetch_list=fetch_targets,
+                             return_numpy=False)
+
+    output_flipped = np.array(output_flipped[0])
+    flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14],
+                 [15, 16]]
+    output_flipped = flip_back(output_flipped, flip_perm)
+    output_flipped[:, :, :, 1:] = copy.copy(output_flipped)[:, :, :, 0:-1]
+    hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5
+    imshape = (
+        np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None
+    center = np.array(data['center']) if 'center' in data else np.round(
+        imshape / 2.)
+    scale = np.array(data['scale']) if 'scale' in data else imshape / 200.
+    post_process = HRNetPostProcess()
+    outputs = post_process(hrnet_outputs, center, scale)
+    return {'keypoint': outputs}
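keypoint_post_process implements the standard flip test: mirror the input, run the same program again, un-mirror the resulting heatmaps with flip_back, shift them one pixel along the width to compensate for the flip, and average with the original pass before decoding through HRNetPostProcess. Condensed into a framework-agnostic sketch (infer is a hypothetical callable returning [N, joints, H, W] heatmaps for NCHW images):

import numpy as np

def flip_test(infer, images, flip_perm):
    # Average a normal pass with an un-mirrored flipped pass, as
    # keypoint_post_process does above.
    heatmaps = infer(images)
    flipped = infer(images[:, :, :, ::-1])                 # mirrored input
    flipped = flip_back(flipped, flip_perm)                # un-mirror output
    flipped[:, :, :, 1:] = flipped.copy()[:, :, :, 0:-1]   # 1-px shift
    return (heatmaps + flipped) * 0.5

# Dummy "model" with TinyPose-128x96-like shapes:
dummy = lambda x: np.ones((x.shape[0], 17, 32, 24), dtype=np.float32)
out = flip_test(dummy, np.zeros((1, 3, 128, 96), dtype=np.float32),
                [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
                 [13, 14], [15, 16]])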