Skip to content

Commit

Permalink
Adapt to api updates. (PaddlePaddle#1629)
Browse files Browse the repository at this point in the history
* Adapt to api updates.

* fix bugs.
  • Loading branch information
zzjjay authored Jan 16, 2023
1 parent 98900b3 commit d7efdad
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion example/auto_compression/detection/paddle_inference_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def eval(predictor, val_loader, metric, rerun_flag=False):
input_names = predictor.get_input_names()
output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0])
boxes_num = predictor.get_output_handle(output_names[1])
if FLAGS.include_nms:
boxes_num = predictor.get_output_handle(output_names[1])
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
for i, _ in enumerate(input_names):
Expand Down
7 changes: 4 additions & 3 deletions example/auto_compression/detection/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ def main():
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
if global_config.get('input_list') is None:
global_config['input_list'] = get_feed_vars(
global_config['model_dir'], global_config['model_filename'],
global_config['params_filename'])
train_loader = reader_wrapper(train_loader, global_config['input_list'])

if 'Evaluation' in global_config.keys() and global_config[
Expand Down
3 changes: 2 additions & 1 deletion paddleslim/analysis/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def save_cls_model(model, input_shape, save_dir, data_type):
batch_nums=1,
weight_bits=8,
activation_bits=8,
quantizable_op_type=["conv2d", "depthwise_conv2d"])
quantizable_op_type=["conv2d", "depthwise_conv2d"],
onnx_format=False)

model_file = os.path.join(quantize_model_path, 'model.pdmodel')
param_file = os.path.join(quantize_model_path, 'model.pdiparams')
Expand Down

0 comments on commit d7efdad

Please sign in to comment.