Skip to content

Commit

Permalink
rename --distill to --distill_range
Browse files Browse the repository at this point in the history
  • Loading branch information
jakc4103 committed Feb 5, 2020
1 parent 8ba47e5 commit 87ed0fb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
18 changes: 9 additions & 9 deletions main_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_argument():
parser = argparse.ArgumentParser()
parser.add_argument("--quantize", action='store_true')
parser.add_argument("--equalize", action='store_true')
parser.add_argument("--distill", action='store_true')
parser.add_argument("--distill_range", action='store_true')
parser.add_argument("--correction", action='store_true')
parser.add_argument("--absorption", action='store_true')
parser.add_argument("--relu", action='store_true')
Expand Down Expand Up @@ -83,7 +83,7 @@ def main():
model = mobilenet_v2('modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
model.eval()

if args.distill:
if args.distill_range:
import copy
# define FP32 model
model_original = copy.deepcopy(model)
Expand Down Expand Up @@ -116,7 +116,7 @@ def main():
transformer = TorchTransformer()
module_dict = {}
if args.quantize:
if args.distill:
if args.distill_range:
module_dict[1] = [(nn.Conv2d, QConv2d), (nn.Linear, QLinear)]
elif args.trainable:
module_dict[1] = [(nn.Conv2d, QuantConv2d), (nn.Linear, QuantLinear)]
Expand All @@ -134,7 +134,7 @@ def main():
graph = transformer.log.getGraph()
bottoms = transformer.log.getBottoms()
if args.quantize:
if args.distill:
if args.distill_range:
targ_layer = [QConv2d, QLinear]
elif args.trainable:
targ_layer = [QuantConv2d, QuantLinear]
Expand All @@ -149,7 +149,7 @@ def main():
model = merge_batchnorm(model, graph, bottoms, targ_layer)

#create relations
if args.equalize or args.distill:
if args.equalize or args.distill_range:
res = create_relation(graph, bottoms, targ_layer, delete_single=False)
if args.equalize:
cross_layer_equalization(graph, res, targ_layer, visualize_state=False, converge_thres=2e-7)
Expand Down Expand Up @@ -177,10 +177,10 @@ def main():
bias_correction(graph, bottoms, targ_layer, bits_weight=args.bits_weight)

if args.quantize:
if not args.trainable and not args.distill:
if not args.trainable and not args.distill_range:
graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer)

if args.distill:
if args.distill_range:
set_update_stat(model, [QuantMeasure], True)
model = update_quant_range(model.cuda(), data_distill, graph, bottoms)
set_update_stat(model, [QuantMeasure], False)
Expand All @@ -204,8 +204,8 @@ def main():
restore_op()
if args.log:
with open("cls_result.txt", 'a+') as ww:
ww.write("resnet: {}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill: {}\n".format(
args.resnet, args.quantize, args.relu, args.equalize, args.absorption, args.correction, args.clip_weight, args.distill
ww.write("resnet: {}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format(
args.resnet, args.quantize, args.relu, args.equalize, args.absorption, args.correction, args.clip_weight, args.distill_range
))
ww.write("Acc: {}\n\n".format(acc))

Expand Down
16 changes: 8 additions & 8 deletions main_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_argument():
parser.add_argument("--equalize", action='store_true')
parser.add_argument("--correction", action='store_true')
parser.add_argument("--absorption", action='store_true')
parser.add_argument("--distill", action='store_true')
parser.add_argument("--distill_range", action='store_true')
parser.add_argument("--log", action='store_true')
parser.add_argument("--relu", action='store_true')
parser.add_argument("--clip_weight", action='store_true')
Expand Down Expand Up @@ -107,7 +107,7 @@ def main():
state_dict = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')['state_dict']
model.load_state_dict(state_dict)
model.eval()
if args.distill:
if args.distill_range:
import copy
# define FP32 model
model_original = copy.deepcopy(model)
Expand All @@ -124,7 +124,7 @@ def main():

module_dict = {}
if args.quantize:
if args.distill:
if args.distill_range:
module_dict[1] = [(nn.Conv2d, QConv2d)]
elif args.trainable:
module_dict[1] = [(nn.Conv2d, QuantConv2d)]
Expand All @@ -142,7 +142,7 @@ def main():
bottoms = transformer.log.getBottoms()

if args.quantize:
if args.distill:
if args.distill_range:
targ_layer = [QConv2d]
elif args.trainable:
targ_layer = [QuantConv2d]
Expand All @@ -155,7 +155,7 @@ def main():
model = merge_batchnorm(model, graph, bottoms, targ_layer)

#create relations
if args.equalize or args.distill:
if args.equalize or args.distill_range:
res = create_relation(graph, bottoms, targ_layer)
if args.equalize:
cross_layer_equalization(graph, res, targ_layer, visualize_state=False)
Expand All @@ -173,10 +173,10 @@ def main():
bias_correction(graph, bottoms, targ_layer)

if args.quantize:
if not args.trainable and not args.distill:
graph = quantize_targ_layer(graph, 8, 16, targ_layer)
if not args.trainable and not args.distill_range:
graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer)

if args.distill:
if args.distill_range:
set_update_stat(model, [QuantMeasure], True)
model = update_quant_range(model.cuda(), data_distill, graph, bottoms)
set_update_stat(model, [QuantMeasure], False)
Expand Down
22 changes: 11 additions & 11 deletions main_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
parser.add_argument("--equalize", action='store_true')
parser.add_argument("--correction", action='store_true')
parser.add_argument("--absorption", action='store_true')
parser.add_argument("--distill", action='store_true')
parser.add_argument("--distill_range", action='store_true')
parser.add_argument("--log", action='store_true')
parser.add_argument("--relu", action='store_true')
parser.add_argument("--clip_weight", action='store_true')
Expand Down Expand Up @@ -178,7 +178,7 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case

data = torch.ones((4, 3, 300, 300))

if args.distill:
if args.distill_range:
import copy
# define FP32 model
model_original = create_mobilenetv2_ssd_lite(len(class_names), width_mult=args.mb2_width_mult, is_test=True, quantize=args.quantize)
Expand All @@ -196,7 +196,7 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case
transformer = TorchTransformer()
module_dict = {}
if args.quantize:
if args.distill:
if args.distill_range:
module_dict[1] = [(torch.nn.Conv2d, QConv2d), (torch.nn.Linear, QLinear)]
elif args.trainable:
module_dict[1] = [(torch.nn.Conv2d, QuantConv2d), (torch.nn.Linear, QuantLinear)]
Expand All @@ -215,7 +215,7 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case
bottoms = transformer.log.getBottoms()
output_shape = transformer.log.getOutShapes()
if args.quantize:
if args.distill:
if args.distill_range:
targ_layer = [QConv2d, QLinear]
elif args.trainable:
targ_layer = [QuantConv2d, QuantLinear]
Expand All @@ -230,8 +230,8 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case
net = merge_batchnorm(net, graph, bottoms, targ_layer)

#create relations
if args.equalize or args.distill:
res = create_relation(graph, bottoms, targ_layer, delete_single=not args.distill)
if args.equalize or args.distill_range:
res = create_relation(graph, bottoms, targ_layer, delete_single=not args.distill_range)
if args.equalize:
cross_layer_equalization(graph, res, targ_layer, visualize_state=False, converge_thres=2e-7, s_range=(1/args.equal_range, args.equal_range))

Expand All @@ -248,10 +248,10 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case
bias_correction(graph, bottoms, targ_layer)

if args.quantize:
if not args.trainable and not args.distill:
graph = quantize_targ_layer(graph, 8, 16, targ_layer)
if not args.trainable and not args.distill_range:
graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer)

if args.distill:
if args.distill_range:
set_update_stat(net, [QuantMeasure], True)
net = update_quant_range(net.cuda(), data_distill, graph, bottoms, is_detection=True)
set_update_stat(net, [QuantMeasure], False)
Expand Down Expand Up @@ -335,7 +335,7 @@ def compute_average_precision_per_class(num_true_cases, gt_boxes, difficult_case
print(f"\nAverage Precision Across All Classes:{sum(aps)/len(aps)}")
if args.log:
with open("ssd_result.txt", 'a+') as ww:
ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill: {}\n".format(
args.dataset_type, args.quantize, args.relu, args.equalize, args.absorption, args.correction, args.clip_weight, args.distill
ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format(
args.dataset_type, args.quantize, args.relu, args.equalize, args.absorption, args.correction, args.clip_weight, args.distill_range
))
ww.write("mAP: {}\n\n".format(sum(aps)/len(aps)))
4 changes: 2 additions & 2 deletions utils/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def forward_all(net_inference, dataloader, visualize=False, opt=None):
print("FWIoU: {}".format(FWIoU))
if opt is not None:
with open("seg_result.txt", 'a+') as ww:
ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill: {}\n".format(
opt.dataset, opt.quantize, opt.relu, opt.equalize, opt.absorption, opt.correction, opt.clip_weight, opt.distill
ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format(
opt.dataset, opt.quantize, opt.relu, opt.equalize, opt.absorption, opt.correction, opt.clip_weight, opt.distill_range
))
ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(Acc, Acc_class, mIoU, FWIoU))

Expand Down

0 comments on commit 87ed0fb

Please sign in to comment.