From b735a396bcb3b51222f6e5bf8ef93ddef001602a Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 24 Mar 2023 13:41:07 +0800 Subject: [PATCH] Fix pruning to support reshape2 for bias (#1700) --- paddleslim/nas/ofa/layers.py | 81 ++++++++++++++---------------- paddleslim/prune/prune_worker.py | 6 ++- paddleslim/prune/pruner.py | 17 +++++-- tests/dygraph/test_prune_walker.py | 1 + tests/test_dy2prog.py | 16 +++--- 5 files changed, 66 insertions(+), 55 deletions(-) diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py index 1162e921d..f12054aec 100644 --- a/paddleslim/nas/ofa/layers.py +++ b/paddleslim/nas/ofa/layers.py @@ -208,8 +208,7 @@ def get_active_filter(self, in_nc, out_nc, kernel_size): filters = self.weight else: filters = self.weight[:out_nc, :in_nc, start:end, start:end] - if self.transform_kernel != False and kernel_size < self._kernel_size[ - 0]: + if self.transform_kernel != False and kernel_size < self._kernel_size[0]: ### if transform kernel, then use matrix to transform start_filter = self.weight[:out_nc, :in_nc, :, :] for i in range(len(self.ks_set) - 1, 0, -1): @@ -223,10 +222,11 @@ def get_active_filter(self, in_nc, out_nc, kernel_size): _input_filter, shape=[(_input_filter.shape[0] * _input_filter.shape[1]), -1]) - _input_filter = paddle.matmul( - _input_filter, - self.__getattr__('%dto%d_matrix' % - (src_ks, target_ks)), False, False) + _input_filter = paddle.matmul(_input_filter, + self.__getattr__( + '%dto%d_matrix' % + (src_ks, target_ks)), False, + False) _input_filter = paddle.reshape( _input_filter, shape=[ @@ -279,11 +279,11 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None): out_nc = int(channel) else: out_nc = self._out_channels - ks = int(self._kernel_size[0]) if kernel_size == None else int( - kernel_size) + ks = int( + self._kernel_size[0]) if kernel_size == None else int(kernel_size) - groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, - out_nc) + groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc( + in_nc, out_nc) weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks) @@ -293,7 +293,7 @@ def forward(self, input, kernel_size=None, expand_ratio=None, channel=None): padding = self._padding if self.bias is not None: - ### if conv is depthwise conv, expand_ratio=0, but conv' expand + ### if conv is depthwise conv, expand_ratio=0, but conv' expand ### ratio before depthwise conv is not equal to 1.0, the shape of the weight ### about this depthwise conv is changed, but out_nc is not change, ### so need to change bias shape according to the weight_out_nc. @@ -513,8 +513,7 @@ def __init__(self, def get_active_filter(self, in_nc, out_nc, kernel_size): start, end = compute_start_end(self._kernel_size[0], kernel_size) filters = self.weight[:in_nc, :out_nc, start:end, start:end] - if self.transform_kernel != False and kernel_size < self._kernel_size[ - 0]: + if self.transform_kernel != False and kernel_size < self._kernel_size[0]: start_filter = self.weight[:in_nc, :out_nc, :, :] for i in range(len(self.ks_set) - 1, 0, -1): src_ks = self.ks_set[i] @@ -527,10 +526,11 @@ def get_active_filter(self, in_nc, out_nc, kernel_size): _input_filter, shape=[(_input_filter.shape[0] * _input_filter.shape[1]), -1]) - _input_filter = paddle.matmul( - _input_filter, - self.__getattr__('%dto%d_matrix' % - (src_ks, target_ks)), False, False) + _input_filter = paddle.matmul(_input_filter, + self.__getattr__( + '%dto%d_matrix' % + (src_ks, target_ks)), False, + False) _input_filter = paddle.reshape( _input_filter, shape=[ @@ -590,11 +590,11 @@ def forward(self, else: out_nc = self._out_channels - ks = int(self._kernel_size[0]) if kernel_size == None else int( - kernel_size) + ks = int( + self._kernel_size[0]) if kernel_size == None else int(kernel_size) - groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, - out_nc) + groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc( + in_nc, out_nc) weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks) @@ -731,8 +731,8 @@ def __init__(self, 'expand_ratio'] if 'expand_ratio' in candidate_config else None self.base_output_dim = self.conv[0]._out_channels if self.expand_ratio != None: - self.base_output_dim = int(self.conv[0]._out_channels / - max(self.expand_ratio)) + self.base_output_dim = int( + self.conv[0]._out_channels / max(self.expand_ratio)) def forward(self, input, expand_ratio=None, channel=None): """ @@ -863,8 +863,8 @@ def __init__(self, 'expand_ratio'] if 'expand_ratio' in candidate_config else None self.base_output_dim = self._out_features if self.expand_ratio != None: - self.base_output_dim = int(self._out_features / - max(self.expand_ratio)) + self.base_output_dim = int( + self._out_features / max(self.expand_ratio)) def forward(self, input, expand_ratio=None, channel=None): """ @@ -941,9 +941,9 @@ def __init__(self, data_format='NCHW', use_global_stats=None, name=None): - super(SuperBatchNorm2D, self).__init__( - num_features, momentum, epsilon, weight_attr, bias_attr, - data_format, use_global_stats, name) + super(SuperBatchNorm2D, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, use_global_stats, name) self.cur_config = None def forward(self, input): @@ -1047,8 +1047,7 @@ def forward(self, input): "Variance": [variance] } - helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper( - 'batch_norm') + helper = paddle.fluid.layer_helper.LayerHelper('batch_norm') param_dtype = input.dtype if input.dtype != 'float16' else 'float32' saved_mean = helper.create_variable_for_type_inference( @@ -1150,8 +1149,7 @@ def forward(self, input): "Variance": [self._variance] } - helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper( - 'sync_batch_norm') + helper = paddle.fluid.layer_helper.LayerHelper('sync_batch_norm') saved_mean = helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) @@ -1211,9 +1209,9 @@ def __init__(self, bias_attr=None, data_format='NCHW', name=None): - super(SuperInstanceNorm2D, self).__init__(num_features, epsilon, - momentum, weight_attr, - bias_attr, data_format, name) + super(SuperInstanceNorm2D, + self).__init__(num_features, epsilon, momentum, weight_attr, + bias_attr, data_format, name) self.cur_config = None def forward(self, input): @@ -1319,8 +1317,7 @@ def forward(self, input): "begin_norm_axis": begin_norm_axis } - helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper( - 'layer_norm') + helper = paddle.fluid.layer_helper.LayerHelper('layer_norm') dtype = input.dtype mean_out = helper.create_variable_for_type_inference( @@ -1399,17 +1396,17 @@ def __init__(self, sparse=False, weight_attr=None, name=None): - super(SuperEmbedding, self).__init__(num_embeddings, embedding_dim, - padding_idx, sparse, weight_attr, - name) + super(SuperEmbedding, + self).__init__(num_embeddings, embedding_dim, padding_idx, sparse, + weight_attr, name) self.candidate_config = candidate_config self.cur_config = None self.expand_ratio = candidate_config[ 'expand_ratio'] if 'expand_ratio' in candidate_config else None self.base_output_dim = self._embedding_dim if self.expand_ratio != None: - self.base_output_dim = int(self._embedding_dim / - max(self.expand_ratio)) + self.base_output_dim = int( + self._embedding_dim / max(self.expand_ratio)) def forward(self, input, expand_ratio=None, channel=None): """ diff --git a/paddleslim/prune/prune_worker.py b/paddleslim/prune/prune_worker.py index 0b28e2b38..ce1a99893 100644 --- a/paddleslim/prune/prune_worker.py +++ b/paddleslim/prune/prune_worker.py @@ -233,7 +233,7 @@ def _prune(self, var, pruned_axis, transforms): assert self._valid_reshape2( shape), "we don't support the shape {} in pruning".format(shape) # assert self._valid_pruned_axis(shape, pruned_axis), "we don't support pruned axis is {} when shape is changing from {} to {}".format(pruned_axis, in_shape, out_shape) - self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms) + # self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms) if var in self.op.inputs("X"): if (len(out_shape) > len(in_shape)): #self.op.set_attr('shape', @@ -254,6 +254,10 @@ def _prune(self, var, pruned_axis, transforms): #self.op.set_attr('shape', # [0, 0, int(shape[2] * 0.875), shape[3]]) transform = {"repeat": out_shape[pruned_axis + 1]} + elif len(in_shape) == 1 and len( + out_shape) == 4 and out_shape[pruned_axis] == in_shape[0]: + transform = {} + self.append_pruned_vars(in_var, 0, transforms) else: transform = {} self._visit_and_search(in_var, pruned_axis, diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index d8242f17a..592f277df 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -50,6 +50,14 @@ def __init__(self, criterion="l1_norm", self.pruned_weights = False + def _update_reshape_op(self, param: VarWrapper, op: OpWrapper, new_shape): + if op.type() == 'reshape2': + _param_shape = param.shape() + _shape_attr = op.attr('shape') + if len(_param_shape) == 1 and _param_shape[0] == _shape_attr[1]: + _shape_attr[1] = new_shape[0] + op.set_attr("shape", _shape_attr) + def prune(self, program, scope, @@ -111,8 +119,8 @@ def prune(self, merge_pruned_params[param][pruned_axis].append(pruned_idx) for param_name in merge_pruned_params: for pruned_axis in merge_pruned_params[param_name]: - pruned_idx = np.concatenate(merge_pruned_params[param_name][ - pruned_axis]) + pruned_idx = np.concatenate( + merge_pruned_params[param_name][pruned_axis]) param = graph.var(param_name) _groups = 1 if not lazy: @@ -138,6 +146,7 @@ def prune(self, param_shape_backup[param.name()] = origin_shape new_shape = list(param.shape()) new_shape[pruned_axis] -= len(pruned_idx) + self._update_reshape_op(param, op, new_shape) param.set_shape(new_shape) if not only_graph and (_groups == 1 or pruned_axis != 1): @@ -159,8 +168,8 @@ def prune(self, except IndexError as e: _logger.error( "Pruning {} with shape {} on axis {}, but get [{}]; ". - format(param.name(), - param_t.shape(), pruned_axis, e)) + format(param.name(), param_t.shape(), pruned_axis, + e)) graph.infer_shape() self.pruned_weights = (not only_graph) diff --git a/tests/dygraph/test_prune_walker.py b/tests/dygraph/test_prune_walker.py index 03741e8bb..c73cb6968 100644 --- a/tests/dygraph/test_prune_walker.py +++ b/tests/dygraph/test_prune_walker.py @@ -25,6 +25,7 @@ def runTest(self): x = np.random.uniform(-1, 1, x_shape).astype('float32') pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)]) pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0) + net(paddle.to_tensor(x)) self.assertTrue(net.linear.weight.shape == [5400, 5]) diff --git a/tests/test_dy2prog.py b/tests/test_dy2prog.py index 14902bd79..986d8e298 100644 --- a/tests/test_dy2prog.py +++ b/tests/test_dy2prog.py @@ -32,8 +32,8 @@ def setUp(self): def prepare_inputs(self): self.inputs = [3, 28, 28] self.ops = [ - 'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d', - 'reshape2', 'matmul_v2', 'elementwise_add' + 'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add', + 'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add' ] def prepare_layer(self): @@ -51,8 +51,8 @@ class TestEagerDygraph2Program2(TestEagerDygraph2Program): def prepare_inputs(self): self.inputs = [[3, 28, 28]] self.ops = [ - 'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d', - 'reshape2', 'matmul_v2', 'elementwise_add' + 'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add', + 'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add' ] @@ -60,8 +60,8 @@ class TestEagerDygraph2Program3(TestEagerDygraph2Program): def prepare_inputs(self): self.inputs = paddle.randn([3, 28, 28]) self.ops = [ - 'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2', - 'matmul_v2', 'elementwise_add' + 'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d', + 'reshape2', 'matmul_v2', 'elementwise_add' ] @@ -69,8 +69,8 @@ class TestEagerDygraph2Program4(TestEagerDygraph2Program): def prepare_inputs(self): self.inputs = [paddle.randn([3, 28, 28])] self.ops = [ - 'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2', - 'matmul_v2', 'elementwise_add' + 'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d', + 'reshape2', 'matmul_v2', 'elementwise_add' ]