fix quant aware distributed train (PaddlePaddle#1206)
yghstill authored Jun 29, 2022
1 parent f275cef commit 919a9b1
Showing 4 changed files with 54 additions and 22 deletions.
16 changes: 13 additions & 3 deletions demo/quant/pact_quant_aware/README.md
@@ -157,7 +157,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(

### Training commands

-Plain quantization:
+- Plain quantization:
```
python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact False
@@ -177,14 +177,24 @@ python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/Mob
```
You can see that the loss of plain quantization is unstable, and by the time training reaches epoch 2 the loss becomes NaN. Plain quantization is very unstable.

-Quantization-aware training with PACT

+- Quantization-aware training with PACT
```
# First analyze the activation distributions of the MobileNetV3 model to initialize the PACT clipping thresholds
python train.py --analysis=True
# Launch PACT quantization-aware training
```

Launch PACT quantization-aware training on a single GPU:
```
export CUDA_VISIBLE_DEVICES=0
python train.py
```

Launch PACT quantization-aware training on multiple GPUs:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 train.py --batch_size=64
```

The output is:
```
2020-06-05 15:25:37,647-INFO: epoch[0]-batch[10] - loss: 1.60160636902; acc_top1: 0.65625; acc_top5: 0.890625; time: 1.56788897514
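A note on the `--analysis` step above: PACT (Parameterized Clipping Activation) bounds each activation with a learnable clipping threshold `alpha`, and the analysis pass estimates activation distributions so that `alpha` starts near the observed range. The sketch below illustrates the clipping function only; it is not the demo's exact `act_preprocess_func`, and the names are illustrative.

```python
import paddle

def pact_clip(x, alpha):
    # PACT clips activations to [0, alpha], where alpha is a learnable
    # scalar parameter. The identity
    #   0.5 * (|x| - |x - alpha| + alpha) == clip(x, 0, alpha)
    # keeps the expression differentiable in alpha, so the threshold
    # is trained together with the weights.
    return 0.5 * (paddle.abs(x) - paddle.abs(x - alpha) + alpha)
```

Because activations are clipped to a trained range before fake quantization, the quantization error stays bounded, which is why the PACT run above converges while plain quantization diverges to NaN.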
21 changes: 13 additions & 8 deletions demo/quant/pact_quant_aware/train.py
@@ -109,6 +109,14 @@ def create_optimizer(args):
    return cosine_decay(args)


def _prepare_envs():
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
    _logger.info(f"devices: {devices}")
    exe = paddle.static.Executor(places)
    return exe, places


def compress(args):
    num_workers = 4
    shuffle = True
@@ -158,10 +166,7 @@ def compress(args):
learning_rate, opt = create_optimizer(args)
opt.minimize(avg_cost)

-    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
-    places = paddle.static.cuda_places(
-    ) if args.use_gpu else paddle.static.cpu_places()
-    exe = paddle.static.Executor(place)
+    exe, places = _prepare_envs()
exe.run(paddle.static.default_startup_program())

train_loader = paddle.io.DataLoader(
@@ -177,7 +182,7 @@

valid_loader = paddle.io.DataLoader(
val_dataset,
-        places=place,
+        places=places,
feed_list=[image, label],
drop_last=False,
return_list=False,
@@ -290,7 +295,7 @@ def get_optimizer():

val_program = quant_aware(
val_program,
-        place,
+        places,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
@@ -299,7 +304,7 @@
for_test=True)
compiled_train_prog = quant_aware(
train_prog,
-        place,
+        places,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
@@ -420,7 +425,7 @@ def train(epoch, compiled_train_prog, lr):
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
-    float_program, int8_program = convert(val_program, place, quant_config, \
+    float_program, int8_program = convert(val_program, places, quant_config, \
scope=None, \
save_int8=True)
_logger.info("eval best_model after convert")
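The `_prepare_envs` helper added above is the core of this fix. The old code pinned every process to `paddle.CUDAPlace(0)`, so all workers started by `paddle.distributed.launch` collided on GPU 0; deriving the place from the process's current device lets each worker bind to its own GPU. Below is a minimal sketch of the mechanism, assuming a 4-GPU launch (`_convert_to_place` is the private Paddle helper used in the diff, and the printed values are examples):

```python
import paddle

paddle.enable_static()

# Under `python -m paddle.distributed.launch --gpus 0,1,2,3 train.py`,
# each worker process sees its own device here, e.g. "gpu:1" on worker 1.
device = paddle.device.get_device()
dev_type = device.split(':')[0]  # "gpu" or "cpu"

# _convert_to_place("gpu") resolves to the current worker's GPU,
# unlike the hard-coded paddle.CUDAPlace(0) it replaces.
place = paddle.device._convert_to_place(dev_type)
exe = paddle.static.Executor(place)
```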
13 changes: 13 additions & 0 deletions demo/quant/quant_aware/README.md
@@ -68,7 +68,20 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(

### Training commands

- Single-GPU launch:

```
export CUDA_VISIBLE_DEVICES=0
python train.py --model MobileNet --pretrained_model ./pretrain/MobileNetV1_pretrained --checkpoint_dir ./output/mobilenetv1 --num_epochs 30
```

- Multi-GPU launch:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 train.py \
--model MobileNet \
--pretrained_model ./pretrain/MobileNetV1_pretrained \
--checkpoint_dir ./output/mobilenetv1 \
--num_epochs 30
```

After running, you can see the final evaluation results of ``best_model``, which are very close to MobileNet's accuracy before quantization (top1 = 70.99%, top5 = 89.68%).
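A related detail in the train.py diff below: the default `batch_size` drops from `64 * 4` to `64`. Presumably each process started by `paddle.distributed.launch` now loads its own batches, so 64 per worker across 4 GPUs preserves the old effective batch size of 256.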
26 changes: 15 additions & 11 deletions demo/quant/quant_aware/train.py
@@ -25,10 +25,10 @@
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
-add_arg('batch_size', int, 64 * 4, "Minibatch size.")
+add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
-add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
+add_arg('pretrained_model', str, "./pretrain/MobileNetV1_pretrained", "Whether to use pretrained model.")
add_arg('lr', float, 0.0001, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
@@ -84,6 +84,14 @@ def create_optimizer(args):
    return cosine_decay(args)


def _prepare_envs():
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
    _logger.info(f"devices: {devices}")
    exe = paddle.static.Executor(places)
    return exe, places


def compress(args):
    num_workers = 4
    shuffle = True
@@ -161,21 +169,20 @@ def compress(args):
train_prog = paddle.static.default_main_program()
val_program = paddle.static.default_main_program().clone(for_test=True)

-    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
+    exe, places = _prepare_envs()
############################################################################################################
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
############################################################################################################
val_program = quant_aware(
-        val_program, place, quant_config, scope=None, for_test=True)
+        val_program, places, quant_config, scope=None, for_test=True)
compiled_train_prog = quant_aware(
-        train_prog, place, quant_config, scope=None, for_test=False)
+        train_prog, places, quant_config, scope=None, for_test=False)
opt = create_optimizer(args)
opt.minimize(avg_cost)

-    exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

if pretrain:
@@ -185,9 +192,6 @@
if args.pretrained_model:
paddle.static.load(train_prog, args.pretrained_model, exe)

-    places = paddle.static.cuda_places(
-    ) if args.use_gpu else paddle.static.cpu_places()

train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
@@ -200,7 +204,7 @@
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
-        places=place,
+        places=places,
feed_list=[image, label],
drop_last=False,
return_list=False,
@@ -290,7 +294,7 @@ def train(epoch, compiled_train_prog):
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
############################################################################################################
-    float_program, int8_program = convert(val_program, place, quant_config, \
+    float_program, int8_program = convert(val_program, places, quant_config, \
scope=None, \
save_int8=True,
onnx_format=args.onnx_format)
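After `convert`, the demo evaluates the frozen programs; if you also want to export them for deployment, something like the following sketch works (the output paths are made up here, and `image`/`acc_top1` stand in for the demo's actual feed and fetch variables):

```python
# float_program keeps float32 weights already constrained to the int8
# grid; int8_program stores real int8 weights.
paddle.static.save_inference_model(
    path_prefix="./quant_model/float",  # assumed output location
    feed_vars=[image],
    fetch_vars=[acc_top1],
    executor=exe,
    program=float_program)

paddle.static.save_inference_model(
    path_prefix="./quant_model/int8",  # assumed output location
    feed_vars=[image],
    fetch_vars=[acc_top1],
    executor=exe,
    program=int8_program)
```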
