Compiler for transformer shape_bucket failed #1270

Open
xiangweizeng opened this issue Nov 18, 2024 · 1 comment
@xiangweizeng

Converting a Transformer model fails when ShapeBucket is enabled; with ShapeBucket disabled, the model converts normally. nncase 2.4 and 2.9 both produce the same result.

The error:

```sh
Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230_Conv2D
Unhandled exception. System.AggregateException: One or more errors occurred. (Value cannot be null. (Parameter 'key'))
 ---> System.ArgumentNullException: Value cannot be null. (Parameter 'key')
   at System.Collections.Generic.Dictionary`2.TryInsert(TKey key, TValue value, InsertionBehavior behavior)
   at System.Linq.Enumerable.ToDictionary[TSource,TKey,TElement](IEnumerable`1 source, Func`2 keySelector, Func`2 elementSelector, IEqualityComparer`1 comparer)
   at Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper.MakeVarValuesForAllSegment(ShapeBucketOptions options, Int32 segmentCount, Boolean staticShape)
   at Nncase.Passes.Rules.ShapeBucket.RecordFusionShape.RunCoreAsync(BaseFunction main, RunPassContext context)
   at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
   at Nncase.Passes.PassManager.FunctionPassGroup.Runner.RunAsync()
   at Nncase.Passes.PassManager.FunctionPassGroup.RunAsync(IRModule module)
   at Nncase.Passes.PassManager.RunAsync(IRModule module)
   at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
   at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
   --- End of inner exception stack trace ---
   at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
   at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)
```

Conversion script:

```python
import os
import shutil

import nncase
import numpy as np
import onnx
import onnxsim


def generate_data_encoder(data_dir, input_shapes, data_count):
    data = [[]]
    for i in range(data_count):
        x_batch = np.fromfile(os.path.join(data_dir, 'X_{}.bin'.format(i)), dtype='int64').reshape(input_shapes[0])
        data[0].append(x_batch)
    return data


def parse_model_input_output(model_file, input_shapes_):
    onnx_model = onnx.load(model_file)
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [node.name for node in onnx_model.graph.initializer]
    input_names = list(set(input_all) - set(input_initializer))
    input_tensors = [
        node for node in onnx_model.graph.input if node.name in input_names]

    # input
    inputs = []
    for i, e in enumerate(input_tensors):
        onnx_type = e.type.tensor_type
        input_dict = {
            'name': e.name,
            'dtype': onnx.helper.tensor_dtype_to_np_dtype(onnx_type.elem_type),
            'shape': input_shapes_[i]
        }
        inputs.append(input_dict)
    return onnx_model, inputs


def onnx_simplify(model_file, dump_dir, input_shapes_):
    onnx_model, inputs = parse_model_input_output(model_file, input_shapes_)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    input_shapes = {}
    for input in inputs:
        input_shapes[input['name']] = input['shape']

    onnx_model, check = onnxsim.simplify(onnx_model, overwrite_input_shapes=input_shapes, )
    print(onnx.helper.printable_graph(onnx_model.graph))
    assert check, "Simplified ONNX model could not be validated"

    model_file = os.path.join(dump_dir, 'simplified.onnx')
    onnx.save_model(onnx_model, model_file)
    return model_file


def read_model_file(model_file):
    with open(model_file, 'rb') as f:
        model_content = f.read()
    return model_content


def encoder_tokmodel(onnx_model_path, kmodel_path, data_dir, ptq_option, input_shapes, data_count, tmp_path,
                     target='k230'):
    if not os.path.exists(tmp_path):
        os.makedirs(tmp_path)

    # onnx simplify
    model_file = onnx_simplify(onnx_model_path, tmp_path, input_shapes)

    # compile_options
    compile_options = nncase.CompileOptions()
    compile_options.target = target
    compile_options.preprocess = False
    compile_options.dump_ir = True
    compile_options.dump_asm = True
    compile_options.dump_dir = tmp_path

    compile_options.shape_bucket_enable = True
    compile_options.shape_bucket_range_info = {"seq_len": [1, 64]}
    compile_options.shape_bucket_segments_count = 64
    compile_options.shape_bucket_fix_var_map = {"batch_size": 1}

    # compiler
    compiler = nncase.Compiler(compile_options)

    # import
    model_content = read_model_file(model_file)
    import_options = nncase.ImportOptions()
    compiler.import_onnx(model_content, import_options)

    # ptq_options
    ptq_options = nncase.PTQTensorOptions()
    ptq_options.samples_count = data_count
    if ptq_option == 0:
        pass
    elif ptq_option == 1:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.w_quant_type = 'int16'
    elif ptq_option == 2:
        ptq_options.calibrate_method = 'NoClip'
        ptq_options.quant_type = 'int16'
    elif ptq_option == 3:
        ptq_options.w_quant_type = 'int16'
    elif ptq_option == 4:
        ptq_options.quant_type = 'int16'
    ptq_options.set_tensor_data(generate_data_encoder(data_dir, input_shapes, data_count))
    compiler.use_ptq(ptq_options)
    # compile
    compiler.compile()

    # model
    kmodel = compiler.gencode_tobytes()
    with open(kmodel_path, 'wb') as f:
        f.write(kmodel)
    if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)


if __name__ == "__main__":
    encoder_tokmodel(onnx_model_path="onnx/example.onnx",
                     kmodel_path="onnx/example.kmodel",
                     data_dir="generate_data",
                     ptq_option=0,
                     input_shapes=[[1, 64]],
                     data_count=30,
                     tmp_path='./tmp')
```


Run log:
```sh
Merge Binary_106_Unary_74_Binary_105_Unary_73
Binary_108_Binary_107
Merge Binary_106_Unary_74_Binary_105_Unary_73_Binary_108_Binary_107
Conv2D_16
Conv2D_17
Conv2D_18
Merge Reshape_233
Binary_111_Binary_109_Binary_110
Binary_114_Binary_112_Binary_113
268
Merge Reshape_233_Binary_111_Binary_109_Binary_110_Binary_114_Binary_112_Binary_113
Reshape_236_Concat_235
Merge Reshape_234
Reshape_233_Binary_111_Binary_109_Binary_110_Binary_114_Binary_112_Binary_113_Reshape_236_Concat_235
Merge Conv2D_19_MatMul_1
Binary_115
Merge Binary_117_Unary_76_Binary_116_Unary_75
Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118
Merge Binary_117_Unary_76_Binary_116_Unary_75_Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118
Binary_120
Merge Conv2D_19_MatMul_1_Binary_115
Binary_117_Unary_76_Binary_116_Unary_75_Conv2D_22_Conv2D_20_Conv2D_21_Binary_119_Binary_118_Binary_1
Merge Binary_122_Unary_78_Binary_121_Unary_77
Binary_124_Binary_123
Merge Binary_122_Unary_78_Binary_121_Unary_77_Binary_124_Binary_123
Conv2D_23
Conv2D_24
Conv2D_25
Merge Reshape_237
Binary_127_Binary_125_Binary_126
Binary_130_Binary_128_Binary_129
277
Merge Reshape_237_Binary_127_Binary_125_Binary_126_Binary_130_Binary_128_Binary_129
Reshape_240_Concat_239
Merge Reshape_238
Reshape_237_Binary_127_Binary_125_Binary_126_Binary_130_Binary_128_Binary_129_Reshape_240_Concat_239
Merge Conv2D_26_MatMul_3
Binary_131
Merge Binary_133_Unary_80_Binary_132_Unary_79
Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134
Merge Binary_133_Unary_80_Binary_132_Unary_79_Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134
Binary_136
Merge Conv2D_26_MatMul_3_Binary_131
Binary_133_Unary_80_Binary_132_Unary_79_Conv2D_29_Conv2D_27_Conv2D_28_Binary_135_Binary_134_Binary_1
Merge Binary_138_Unary_82_Binary_137_Unary_81
Binary_140_Binary_139
Merge Binary_138_Unary_82_Binary_137_Unary_81_Binary_140_Binary_139
Conv2D_30
Conv2D_31
Conv2D_32
Merge Reshape_241
Binary_143_Binary_141_Binary_142
Binary_146_Binary_144_Binary_145
286
Merge Reshape_241_Binary_143_Binary_141_Binary_142_Binary_146_Binary_144_Binary_145
Reshape_244_Concat_243
Merge Reshape_242
Reshape_241_Binary_143_Binary_141_Binary_142_Binary_146_Binary_144_Binary_145_Reshape_244_Concat_243
Merge Conv2D_33_MatMul_5
Binary_147
Merge Binary_149_Unary_84_Binary_148_Unary_83
Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150
Merge Binary_149_Unary_84_Binary_148_Unary_83_Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150
Binary_152
Merge Conv2D_33_MatMul_5_Binary_147
Binary_149_Unary_84_Binary_148_Unary_83_Conv2D_36_Conv2D_34_Conv2D_35_Binary_151_Binary_150_Binary_1
Merge Binary_154_Unary_86_Binary_153_Unary_85
Binary_156_Binary_155
Merge Binary_154_Unary_86_Binary_153_Unary_85_Binary_156_Binary_155
Conv2D_37
Conv2D_38
Conv2D_39
Merge Reshape_245
Binary_159_Binary_157_Binary_158
Binary_162_Binary_160_Binary_161
295
Merge Reshape_245_Binary_159_Binary_157_Binary_158_Binary_162_Binary_160_Binary_161
Reshape_248_Concat_247
Merge Reshape_246
Reshape_245_Binary_159_Binary_157_Binary_158_Binary_162_Binary_160_Binary_161_Reshape_248_Concat_247
Merge Conv2D_40_MatMul_7
Binary_163
Merge Binary_165_Unary_88_Binary_164_Unary_87
Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166
Merge Binary_165_Unary_88_Binary_164_Unary_87_Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166
Binary_168
Merge Conv2D_40_MatMul_7_Binary_163
Binary_165_Unary_88_Binary_164_Unary_87_Conv2D_43_Conv2D_41_Conv2D_42_Binary_167_Binary_166_Binary_1
Merge Binary_170_Unary_90_Binary_169_Unary_89
Binary_172_Binary_171
Merge Binary_170_Unary_90_Binary_169_Unary_89_Binary_172_Binary_171
Conv2D_44
Conv2D_45
Conv2D_46
Merge Reshape_249
Binary_175_Binary_173_Binary_174
Binary_178_Binary_176_Binary_177
304
Merge Reshape_249_Binary_175_Binary_173_Binary_174_Binary_178_Binary_176_Binary_177
Reshape_252_Concat_251
Merge Reshape_250
Reshape_249_Binary_175_Binary_173_Binary_174_Binary_178_Binary_176_Binary_177_Reshape_252_Concat_251
Merge Conv2D_47_MatMul_9
Binary_179
Merge Binary_181_Unary_92_Binary_180_Unary_91
Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182
Merge Binary_181_Unary_92_Binary_180_Unary_91_Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182
Binary_184
Merge Conv2D_47_MatMul_9_Binary_179
Binary_181_Unary_92_Binary_180_Unary_91_Conv2D_50_Conv2D_48_Conv2D_49_Binary_183_Binary_182_Binary_1
Merge Binary_186_Unary_94_Binary_185_Unary_93
Binary_188_Binary_187
Merge Binary_186_Unary_94_Binary_185_Unary_93_Binary_188_Binary_187
Conv2D_51
Conv2D_52
Conv2D_53
Merge Reshape_253
Binary_191_Binary_189_Binary_190
Binary_194_Binary_192_Binary_193
313
Merge Reshape_253_Binary_191_Binary_189_Binary_190_Binary_194_Binary_192_Binary_193
Reshape_256_Concat_255
Merge Reshape_254
Reshape_253_Binary_191_Binary_189_Binary_190_Binary_194_Binary_192_Binary_193_Reshape_256_Concat_255
Merge Conv2D_54_MatMul_11
Binary_195
Merge Binary_197_Unary_96_Binary_196_Unary_95
Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198
Merge Binary_197_Unary_96_Binary_196_Unary_95_Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198
Binary_200
Merge Conv2D_54_MatMul_11_Binary_195
Binary_197_Unary_96_Binary_196_Unary_95_Conv2D_57_Conv2D_55_Conv2D_56_Binary_199_Binary_198_Binary_2
Merge Binary_202_Unary_98_Binary_201_Unary_97
Binary_204_Binary_203
Merge Binary_202_Unary_98_Binary_201_Unary_97_Binary_204_Binary_203
Conv2D_58
Conv2D_59
Conv2D_60
Merge Reshape_257
Binary_207_Binary_205_Binary_206
Binary_210_Binary_208_Binary_209
322
Merge Reshape_257_Binary_207_Binary_205_Binary_206_Binary_210_Binary_208_Binary_209
Reshape_260_Concat_259
Merge Reshape_258
Reshape_257_Binary_207_Binary_205_Binary_206_Binary_210_Binary_208_Binary_209_Reshape_260_Concat_259
Merge Conv2D_61_MatMul_13
Binary_211
Merge Binary_213_Unary_100_Binary_212_Unary_99
Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214
Merge Binary_213_Unary_100_Binary_212_Unary_99_Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214
Binary_216
Merge Conv2D_61_MatMul_13_Binary_211
Binary_213_Unary_100_Binary_212_Unary_99_Conv2D_64_Conv2D_62_Conv2D_63_Binary_215_Binary_214_Binary_
Merge Binary_218_Unary_102_Binary_217_Unary_101
Binary_220_Binary_219
Merge Binary_218_Unary_102_Binary_217_Unary_101_Binary_220_Binary_219
Conv2D_65
Conv2D_66
Conv2D_67
Merge Reshape_261
Binary_223_Binary_221_Binary_222
Binary_226_Binary_224_Binary_225
331
Merge Reshape_261_Binary_223_Binary_221_Binary_222_Binary_226_Binary_224_Binary_225
Reshape_264_Concat_263
Merge Reshape_262
Reshape_261_Binary_223_Binary_221_Binary_222_Binary_226_Binary_224_Binary_225_Reshape_264_Concat_263
Merge Conv2D_68_MatMul_15
Binary_227
Merge Binary_229_Unary_104_Binary_228_Unary_103
Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230
Merge Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230
Conv2D_72
Merge Conv2D_68_MatMul_15_Binary_227
Binary_229_Unary_104_Binary_228_Unary_103_Conv2D_71_Conv2D_69_Conv2D_70_Binary_231_Binary_230_Conv2D
Unhandled exception. System.AggregateException: One or more errors occurred. (Value cannot be null. (Parameter 'key'))
 ---> System.ArgumentNullException: Value cannot be null. (Parameter 'key')
   at System.Collections.Generic.Dictionary`2.TryInsert(TKey key, TValue value, InsertionBehavior behavior)
   at System.Linq.Enumerable.ToDictionary[TSource,TKey,TElement](IEnumerable`1 source, Func`2 keySelector, Func`2 elementSelector, IEqualityComparer`1 comparer)
   at Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper.MakeVarValuesForAllSegment(ShapeBucketOptions options, Int32 segmentCount, Boolean staticShape)
   at Nncase.Passes.Rules.ShapeBucket.RecordFusionShape.RunCoreAsync(BaseFunction main, RunPassContext context)
   at Nncase.Passes.Pass`2.RunAsync(TInput input, RunPassContext context)
   at Nncase.Passes.PassManager.FunctionPassGroup.Runner.RunAsync()
   at Nncase.Passes.PassManager.FunctionPassGroup.RunAsync(IRModule module)
   at Nncase.Passes.PassManager.RunAsync(IRModule module)
   at Nncase.Compiler.Compiler.RunPassAsync(Action`1 register, String name, IProgress`1 progress, CancellationToken token)
   at Nncase.Compiler.Compiler.CompileAsync(IProgress`1 progress, CancellationToken token)
   --- End of inner exception stack trace ---
   at System.Threading.Tasks.Task.Wait(Int32 millisecondsTimeout, CancellationToken cancellationToken)
   at Nncase.Compiler.Interop.CApi.CompilerCompile(IntPtr compilerHandle)

```
@curioyang (Member)

@xiangweizeng Your range is not large enough to be split into 64 segments. Try lowering the segment count.
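The maintainer's point about the range being too narrow for 64 segments can be illustrated with a minimal sketch. It assumes, purely for illustration and not taken from the nncase source, that bucket boundaries are spaced by an integer step of `(max - min) // segments_count`; `segment_bounds` and `max_usable_segments` are hypothetical helpers, not nncase APIs. Under that assumption, `seq_len` in `[1, 64]` with `shape_bucket_segments_count = 64` gives a step of `63 // 64 = 0`, so almost every boundary collapses onto the same value.

```python
from typing import List


def segment_bounds(min_val: int, max_val: int, segments_count: int) -> List[int]:
    # Hypothetical helper, NOT the nncase implementation: assumes bucket
    # boundaries are spaced by an integer step of (max - min) // segments_count.
    step = (max_val - min_val) // segments_count
    bounds = [min_val + i * step for i in range(segments_count)]
    bounds.append(max_val)  # the last bucket always ends at the range maximum
    return bounds


def max_usable_segments(min_val: int, max_val: int) -> int:
    # Under the same assumption, the step stays >= 1 (all boundaries distinct)
    # only while segments_count <= max_val - min_val.
    return max(1, max_val - min_val)


if __name__ == "__main__":
    # Settings from the issue: seq_len in [1, 64] split into 64 segments.
    bad = segment_bounds(1, 64, 64)
    print(len(bad), "boundaries,", len(set(bad)), "distinct")  # prints: 65 boundaries, 2 distinct
    # A smaller segment count keeps every boundary distinct.
    print(segment_bounds(1, 64, 8))  # prints: [1, 8, 15, 22, 29, 36, 43, 50, 64]
    print(max_usable_segments(1, 64))  # prints: 63
```

nncase's actual boundary rule may differ. The sketch only shows why a segment count close to or above the width of the `shape_bucket_range_info` range leaves buckets empty or duplicated, and why lowering `shape_bucket_segments_count` (for example to 8) is a reasonable first thing to try.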