Clear fluid api and fix tests (PaddlePaddle#1641)
* remove fluid apis.

* fix hpo.

* fix asp.
zzjjay authored Jan 31, 2023
1 parent b248f20 commit 92874cc
Showing 8 changed files with 91 additions and 83 deletions.
18 changes: 12 additions & 6 deletions demo/models/pvanet.py
@@ -2,10 +2,7 @@
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.nn.initializer import KaimingUniform
import os, sys, time, math
import numpy as np
from collections import namedtuple

BLOCK_TYPE_MCRELU = 'BLOCK_TYPE_MCRELU'
@@ -458,15 +455,24 @@ def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1):
abs_geo_diff = paddle.abs(geo_diff)
l_flag = l_score >= 1
l_flag = paddle.cast(x=l_flag, dtype="float32")
l_flag = fluid.layers.expand(x=l_flag, expand_times=[1, channels, 1, 1])
l_flag = paddle.expand(
x=l_flag,
shape=[
l_flag.shape[0], l_flag.shape[1] * channels, l_flag.shape[2],
l_flag.shape[3]
])

smooth_l1_sign = abs_geo_diff < l_flag
smooth_l1_sign = paddle.cast(x=smooth_l1_sign, dtype="float32")

in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + (
abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
l_short_edge = fluid.layers.expand(
x=l_short_edge, expand_times=[1, channels, 1, 1])
l_short_edge = paddle.expand(
x=l_short_edge,
shape=[
l_short_edge.shape[0], l_short_edge.shape[1] * channels,
l_short_edge.shape[2], l_short_edge.shape[3]
])
out_loss = l_short_edge * in_loss * l_flag
out_loss = out_loss * l_flag
smooth_l1_loss = paddle.mean(out_loss)
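Note on the migration above: fluid.layers.expand multiplied each dimension by the factors in expand_times, while paddle.expand takes the full target shape, so the channel factor is folded into shape[1]. A minimal sketch of the equivalence (assuming, as the original expand_times=[1, channels, 1, 1] call effectively does, that the broadcast map has a single channel):

import paddle

channels = 4
l_flag = paddle.ones([2, 1, 8, 8])  # hypothetical score-map mask

# expand_times=[1, channels, 1, 1] tiled dim 1 by `channels`; paddle.expand
# needs the resulting shape instead, and requires the source dim to be 1.
expanded = paddle.expand(
    l_flag,
    shape=[l_flag.shape[0], l_flag.shape[1] * channels,
           l_flag.shape[2], l_flag.shape[3]])
print(expanded.shape)  # [2, 4, 8, 8]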
6 changes: 3 additions & 3 deletions demo/quant/pact_quant_aware/train.py
@@ -18,7 +18,7 @@
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
from paddle.fluid.layer_helper import LayerHelper
from paddle.common_ops_import import LayerHelper
quantization_model_save_dir = './quantization_models/'

_logger = get_logger(__name__, level=logging.INFO)
@@ -146,8 +146,8 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data))

image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
if args.use_pact:
58 changes: 28 additions & 30 deletions demo/quant/quant_embedding/net.py
@@ -15,30 +15,31 @@
neural network for word2vec
"""
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F


def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
def skip_gram_word2vec(dict_size,
embedding_size,
batch_size,
is_sparse=False,
neg_num=5):

datas = []
words = []
input_word = paddle.static.data(
name="input_word", shape=[None, 1], dtype='int64')
true_word = paddle.static.data(
name='true_label', shape=[None, 1], dtype='int64')
neg_word = paddle.static.data(
name="neg_label", shape=[None, neg_num], dtype='int64')

datas.append(input_word)
datas.append(true_word)
datas.append(neg_word)
words.append(input_word)
words.append(true_word)
words.append(neg_word)

py_reader = fluid.layers.create_py_reader_by_data(
capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True)
py_reader = paddle.io.DataLoader.from_generator(
capacity=64, feed_list=words, use_double_buffer=True, iterable=False)

words = fluid.layers.read_file(py_reader)
words[0] = paddle.reshape(words[0], [-1])
words[1] = paddle.reshape(words[1], [-1])
init_width = 0.5 / embedding_size
@@ -72,40 +73,37 @@ def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
input=neg_word_reshape,
is_sparse=is_sparse,
size=[dict_size, embedding_size],
param_attr=paddle.ParamAttr(
name='emb_w', learning_rate=1.0))
param_attr=paddle.ParamAttr(name='emb_w', learning_rate=1.0))

neg_emb_w_re = paddle.reshape(
neg_emb_w, shape=[-1, neg_num, embedding_size])
neg_emb_b = paddle.static.nn.embedding(
input=neg_word_reshape,
is_sparse=is_sparse,
size=[dict_size, 1],
param_attr=paddle.ParamAttr(
name='emb_b', learning_rate=1.0))
param_attr=paddle.ParamAttr(name='emb_b', learning_rate=1.0))

neg_emb_b_vec = paddle.reshape(neg_emb_b, shape=[-1, neg_num])
true_logits = paddle.add(paddle.mean(
paddle.multiply(input_emb, true_emb_w), keepdim=True),
true_emb_b)
true_logits = paddle.add(
paddle.mean(paddle.multiply(input_emb, true_emb_w), keepdim=True),
true_emb_b)
input_emb_re = paddle.reshape(input_emb, shape=[-1, 1, embedding_size])
neg_matmul = paddle.matmul(input_emb_re, neg_emb_w_re, transpose_y=True)
neg_matmul_re = paddle.reshape(neg_matmul, shape=[-1, neg_num])
neg_logits = paddle.add(neg_matmul_re, neg_emb_b_vec)
#nce loss

# TODO: replaced by paddle.tensor.creation.fill_constant_batch_size_like
label_ones = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, 1], value=1.0, dtype='float32')
label_zeros = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')

true_xent = paddle.nn.functional.binary_cross_entropy(true_logits,
label_ones)
neg_xent = paddle.nn.functional.binary_cross_entropy(neg_logits,
label_zeros)
cost = paddle.add(paddle.sum(true_xent, axis=1),
paddle.sum(neg_xent, axis=1))
label_ones = paddle.full(
shape=[batch_size, 1], fill_value=1.0, dtype='float32')
label_zeros = paddle.full(
shape=[batch_size, neg_num], fill_value=0.0, dtype='float32')

true_xent = F.binary_cross_entropy_with_logits(
true_logits, label_ones, reduction='none')
neg_xent = F.binary_cross_entropy_with_logits(
neg_logits, label_zeros, reduction='none')
cost = paddle.add(
paddle.sum(true_xent, axis=1), paddle.sum(neg_xent, axis=1))
avg_cost = paddle.mean(cost)
return avg_cost, py_reader
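The two replacements above change semantics slightly: fluid.layers.fill_constant_batch_size_like inferred the batch dimension from true_logits at run time, whereas paddle.full needs a concrete shape, which is why batch_size is now threaded through skip_gram_word2vec; and binary_cross_entropy expects probabilities, so the raw NCE logits now go through binary_cross_entropy_with_logits, which applies the sigmoid internally. A standalone dygraph sketch of the same label and loss construction (tensor values are made up):

import paddle
import paddle.nn.functional as F

batch_size, neg_num = 8, 5
true_logits = paddle.randn([batch_size, 1])
neg_logits = paddle.randn([batch_size, neg_num])

# Fixed-shape labels replace fill_constant_batch_size_like.
label_ones = paddle.full([batch_size, 1], 1.0, dtype='float32')
label_zeros = paddle.full([batch_size, neg_num], 0.0, dtype='float32')

# *_with_logits applies the sigmoid itself, so raw logits are fine.
true_xent = F.binary_cross_entropy_with_logits(true_logits, label_ones, reduction='none')
neg_xent = F.binary_cross_entropy_with_logits(neg_logits, label_zeros, reduction='none')
avg_cost = paddle.mean(paddle.sum(true_xent, axis=1) + paddle.sum(neg_xent, axis=1))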

5 changes: 3 additions & 2 deletions demo/quant/quant_embedding/train.py
@@ -121,7 +121,7 @@ def __reader__():
def train_loop(args, train_program, reader, py_reader, loss, trainer_id, weight,
lr):

py_reader.decorate_tensor_provider(
py_reader.set_batch_generator(
convert_python_to_tensor(weight, args.batch_size, reader.train()))

place = paddle.CPUPlace()
@@ -213,6 +213,7 @@ def train(args):
loss, py_reader = skip_gram_word2vec(
word2vec_reader.dict_size,
args.embedding_size,
args.batch_size,
is_sparse=args.is_sparse,
neg_num=args.nce_num)

@@ -223,7 +224,7 @@

optimizer.minimize(loss)

# do local training
# do local training
logger.info("run local training")
main_program = paddle.static.default_main_program()
train_loop(args, main_program, word2vec_reader, py_reader, loss, 0,
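decorate_tensor_provider was the fluid-era PyReader method; its paddle.io.DataLoader counterpart is set_batch_generator, which is what the loader built with iterable=False in net.py expects. A minimal sketch of the pattern (the generator below is a made-up stand-in for convert_python_to_tensor):

import numpy as np
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[None, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(
    feed_list=[x], capacity=64, use_double_buffer=True, iterable=False)

def batch_generator():  # hypothetical stand-in for the real batch reader
    for _ in range(4):
        yield [np.random.randint(0, 100, size=(8, 1)).astype('int64')]

# Equivalent of the old decorate_tensor_provider call.
loader.set_batch_generator(batch_generator)

# With iterable=False, training then uses loader.start(), repeated exe.run(...)
# until the reader signals end of data, and finally loader.reset().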
24 changes: 10 additions & 14 deletions paddleslim/auto_compression/create_compressed_program.py
@@ -78,10 +78,8 @@ def _create_optimizer(train_config):
### build optimizer
optim_params = optimizer_builder['optimizer']
optim_type = optim_params.pop('type')
opt = getattr(optimizer, optim_type)(learning_rate=lr,
grad_clip=grad_clip,
weight_decay=reg,
**optim_params)
opt = getattr(optimizer, optim_type)(
learning_rate=lr, grad_clip=grad_clip, weight_decay=reg, **optim_params)
return opt, lr
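For context, the reflowed call above builds the optimizer by name: getattr(optimizer, optim_type) resolves a class in paddle.optimizer from the string in the config, so no if/else ladder is needed. A small sketch with made-up config values (the original runs in static mode; here a tiny layer supplies parameters for dygraph):

import paddle
import paddle.optimizer as optimizer

linear = paddle.nn.Linear(4, 2)  # hypothetical model whose parameters are optimized
optim_params = {'type': 'AdamW', 'weight_decay': 0.01}  # hypothetical config entry
optim_type = optim_params.pop('type')

# Look the optimizer class up by name and pass the remaining keys straight through.
opt = getattr(optimizer, optim_type)(
    learning_rate=0.001, parameters=linear.parameters(), **optim_params)
print(type(opt).__name__)  # AdamW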


@@ -160,8 +158,8 @@ def _parse_distill_loss(distill_node_pair,
for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
distill_lambda):
tmp_loss = losses.get(loss_clas, 0.0)
_logger.info("train config.distill_node_pair: {}".format(
node, loss_clas, lam))
_logger.info(
"train config.distill_node_pair: {}".format(node, loss_clas, lam))
assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2):
@@ -529,9 +527,7 @@ def build_prune_program(executor,
original_shapes = {}
for param in train_program_info.program.global_block(
).all_parameters():
if config[
'prune_params_name'] is not None and param.name in config[
'prune_params_name']:
if config['prune_params_name'] is not None and param.name in config['prune_params_name']:
params.append(param.name)
original_shapes[param.name] = param.shape

@@ -541,9 +537,8 @@
train_program_info.program,
paddle.static.global_scope(),
params=params,
ratios=[config['pruned_ratio']] * len(params)
if isinstance(config['pruned_ratio'], float) else
config['pruned_ratio'],
ratios=[config['pruned_ratio']] * len(params) if isinstance(
config['pruned_ratio'], float) else config['pruned_ratio'],
place=place)
_logger.info(
"####################channel pruning##########################")
@@ -577,8 +572,9 @@ def build_prune_program(executor,
pruner.add_supported_layer(param.name)
if "teacher_" in param.name:
excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program,
excluded_params_name)
pruner.set_excluded_layers(
main_program=train_program_info.program,
param_names=excluded_params_name)
elif strategy.startswith('transformer_prune'):
from .transformer_pruner import TransformerPruner
assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
2 changes: 1 addition & 1 deletion paddleslim/quant/post_quant_hpo.py
@@ -83,7 +83,7 @@ def __init__(self,
"""QuantConfig init"""
self.executor = executor
self.place = place
self.float_infer_model_path = float_infer_model_path
self.float_infer_model_path = float_infer_model_path.rstrip('/')
self.quantize_model_path = quantize_model_path
self.algo = algo,
self.hist_percent = hist_percent,
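The .rstrip('/') added above normalizes the model directory before it is used to build file paths; with a trailing slash, basename-style splits come back empty, which is the sort of breakage this guards against (the exact downstream use in post_quant_hpo.py is an assumption here). For example:

import os

path = './float_model/'
print(os.path.split(path))              # ('./float_model', '')
print(os.path.split(path.rstrip('/')))  # ('.', 'float_model')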
49 changes: 29 additions & 20 deletions paddleslim/quant/reconstruction_quantization.py
@@ -25,7 +25,9 @@
from ..core.graph_wrapper import GraphWrapper
from ..common import get_logger

__all__ = ['ReconstructionQuantization', ]
__all__ = [
'ReconstructionQuantization',
]

_logger = get_logger(
__name__,
@@ -91,7 +93,8 @@ def _preparation(self):
batch_id = 0
with utils.tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
bar_format=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
for data in self._data_loader():
self._executor.run(
@@ -111,7 +114,8 @@ def _sampling_threshold(self):
batch_id = 0
with utils.tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
bar_format=
'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80, ) as t:
for data in self._data_loader():
self._executor.run(
@@ -237,7 +241,7 @@ def __init__(self,
return a batch every time.
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
scope(fluid.Scope, optional): The scope of the program, use it to load
scope(static.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
place(CPUPlace()|CUDAPlace(N)): This parameter represents
paddle run on which device.
@@ -385,8 +389,8 @@ def _run(self):

with paddle.static.program_guard(tmp_program, startup_program):
student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" +
quant_op_out_name)
teacher_var = tmp_program.global_block().var(
"teacher_" + quant_op_out_name)
total_loss, recon_loss, round_loss = loss_function.get_loss(
student_var,
teacher_var, )
@@ -471,7 +475,8 @@ def _dequant(x, scale):
shape=weight.shape,
dtype=weight.dtype,
name=weight.name + ".alpha",
default_initializer=paddle.nn.initializer.Assign(self._alpha, ), )
default_initializer=paddle.nn.initializer.Assign(
self._alpha, ), )

h_v = paddle.clip(
paddle.nn.functional.sigmoid(v) * (ZETA - GAMMA) + GAMMA,
@@ -483,13 +488,14 @@
dtype=weight.dtype,
shape=weight.shape,
name=weight.name + '.scale',
default_initializer=paddle.nn.initializer.Assign(scale, ))
default_initializer=paddle.nn.initializer.Assign(
scale, ))
else:
scale_var = scale

quantized_weight = _quant(weight_copy, scale_var)
floor_weight = (paddle.floor(quantized_weight) - quantized_weight
).detach() + quantized_weight
floor_weight = (paddle.floor(quantized_weight) -
quantized_weight).detach() + quantized_weight
clip_weight = paddle.clip(floor_weight + h_v, -bnt, bnt)
w = _dequant(clip_weight, scale_var)
return w
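The floor_weight expression reflowed above is a straight-through estimator: the forward value equals paddle.floor(quantized_weight), but because the (floor - x) difference is detached, gradients reach the weight as if the rounding were the identity. A minimal dygraph sketch of the trick:

import paddle

x = paddle.to_tensor([1.3, 2.7], stop_gradient=False)
y = (paddle.floor(x) - x).detach() + x  # forward: floor(x); backward: identity
y.sum().backward()
print(y.numpy())       # [1. 2.]
print(x.grad.numpy())  # [1. 1.]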
@@ -525,8 +531,9 @@ def _drop_quant_dequant(self, inputs, scale):

def _insert_drop_quant_dequant(self):
for op in self._graph.ops():
if op.type(
) in ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2']:
if op.type() in [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
if op.type() in ['conv2d', 'depthwise_conv2d']:
if op.inputs("Filter")[0].name().startswith("teacher"):
break
@@ -670,8 +677,8 @@ def _insert_func(self, var, scale, func):
'X': var._var,
'Y': op.input('Y')[0] + '.qdrop',
}
elif _type == 'scale' and op.input('X')[
0] == inputs.name + '.tmp':
elif _type == 'scale' and op.input(
'X')[0] == inputs.name + '.tmp':
_inputs = {'X': var._var}
else:
_inputs = {'X': op.input('X')[0] + '.qdrop'}
@@ -687,11 +694,13 @@
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]:
continue
if op.type() in ['conv2d', 'depthwise_conv2d'] and op.inputs(
'Filter')[0].name().startswith('teacher'):
if op.type() in [
'conv2d', 'depthwise_conv2d'
] and op.inputs('Filter')[0].name().startswith('teacher'):
continue
if op.type() in ['mul', 'matmul', 'matmul_v2'] and op.inputs('Y')[
0].name().startswith('teacher'):
if op.type() in [
'mul', 'matmul', 'matmul_v2'
] and op.inputs('Y')[0].name().startswith('teacher'):
continue
if func == '_soft_rounding':
op._op._rename_input(inputs.name, out.name + '.rounding')
@@ -964,8 +973,8 @@ def _find_coherent_ep(op):
else:
future_ep = _find_multi_input_ep(ep)

if future_ep is None or self._depth[future_ep.idx()] - self._depth[
sp.idx()] >= limit:
if future_ep is None or self._depth[future_ep.idx(
)] - self._depth[sp.idx()] >= limit:
return self._create_region(sp, ep)
ep = future_ep
