Skip to content

Commit

Permalink
Test also modified_beam_search (k2-fsa#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Feb 15, 2023
1 parent 74c41b9 commit bde310a
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 101 deletions.
211 changes: 124 additions & 87 deletions .github/scripts/run-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,41 @@ $repo/test_wavs/2.wav
)

for wave in ${waves[@]}; do
time $EXE \
$repo/v2/tokens.txt \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/v2/tokens.txt \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$wave \
4 \
$m
done
done

log "Test int8 models"

for wave in ${waves[@]}; do
time $EXE \
$repo/v2/tokens.txt \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.param \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.bin \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.param \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/v2/tokens.txt \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.param \
$repo/v2/encoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.bin \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.param \
$repo/v2/decoder_jit_trace-pnnx-epoch-15-avg-3.ncnn.bin \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.param \
$repo/v2/joiner_jit_trace-pnnx-epoch-15-avg-3.ncnn.int8.bin \
$wave \
4 \
$m
done
done

rm -rf $repo
Expand All @@ -88,15 +100,21 @@ $repo/test_wavs/2.wav
)

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/encoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$repo/decoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/decoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$repo/joiner_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/joiner_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/encoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$repo/decoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/decoder_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$repo/joiner_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.param \
$repo/joiner_jit_trace-v2-epoch-11-avg-2-pnnx.ncnn.bin \
$wave \
4 \
$m
done
done

rm -rf $repo
Expand Down Expand Up @@ -124,15 +142,21 @@ $repo/test_wavs/1221-135766-0002.wav
)

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/bar/encoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/encoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$repo/bar/decoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/decoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$repo/bar/joiner_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/joiner_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/bar/encoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/encoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$repo/bar/decoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/decoder_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$repo/bar/joiner_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.param \
$repo/bar/joiner_jit_trace-v2-iter-468000-avg-16-pnnx.ncnn.bin \
$wave \
4 \
$m
done
done

rm -rf $repo
Expand Down Expand Up @@ -162,46 +186,43 @@ $repo/test_wavs/1221-135766-0002.wav
)

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$wave
done

log "Test beam-search"
for m in greedy_search modified_beam_search; do
log "----test $m ---"

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$wave \
4 \
"modified_beam_search"
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$wave \
4 \
$m
done
done



log "Test int8 models"

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.param \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.bin \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.param \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.param \
$repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.bin \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
$repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.param \
$repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.int8.bin \
$wave \
4 \
$m
done
done

rm -rf $repo
Expand Down Expand Up @@ -232,28 +253,40 @@ $repo/test_wavs/4.wav
)

for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
$repo/encoder_jit_trace-pnnx.ncnn.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.param \
$repo/joiner_jit_trace-pnnx.ncnn.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
$repo/encoder_jit_trace-pnnx.ncnn.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.param \
$repo/joiner_jit_trace-pnnx.ncnn.bin \
$wave \
4 \
$m
done
done

log "test int8 models"
for wave in ${waves[@]}; do
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.int8.param \
$repo/encoder_jit_trace-pnnx.ncnn.int8.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.int8.param \
$repo/joiner_jit_trace-pnnx.ncnn.int8.bin \
$wave
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.int8.param \
$repo/encoder_jit_trace-pnnx.ncnn.int8.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.int8.param \
$repo/joiner_jit_trace-pnnx.ncnn.int8.bin \
$wave \
4 \
$m
done
done

rm -rf $repo
Expand Down Expand Up @@ -281,6 +314,8 @@ $repo/test_wavs/4.wav

for wave in ${waves[@]}; do
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
Expand Down Expand Up @@ -318,6 +353,8 @@ $repo/test_wavs/1221-135766-0002.wav

for wave in ${waves[@]}; do
for m in greedy_search modified_beam_search; do
log "----test $m ---"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
Expand Down
2 changes: 1 addition & 1 deletion sherpa-ncnn/csrc/conv-emformer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
joiner_ex->input(joiner_input_indexes_[1], decoder_out);

ncnn::Mat joiner_out;
joiner_ex->extract("out0", joiner_out);
joiner_ex->extract(joiner_output_indexes_[0], joiner_out);
return joiner_out;
}

Expand Down
2 changes: 2 additions & 0 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ void ModifiedBeamSearchDecoder::Decode() {
cur.Clear();

ncnn::Mat decoder_input = BuildDecoderInput(prev);

ncnn::Mat decoder_out = RunDecoder2D(model_, decoder_input);

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths

Expand Down
16 changes: 3 additions & 13 deletions sherpa-ncnn/csrc/zipformer-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ ZipformerModel::ZipformerModel(const ModelConfig &config) {
encoder_.opt.use_fp16_arithmetic = false;
encoder_.opt.use_fp16_storage = false;

decoder_.opt.use_fp16_arithmetic = false;
decoder_.opt.use_fp16_storage = false;

joiner_.opt.use_fp16_arithmetic = false;
joiner_.opt.use_fp16_storage = false;
NCNN_LOGE("Disable fp16 for zipformer");
NCNN_LOGE("Disable fp16 for Zipformer encoder");

bool has_gpu = false;
#if NCNN_VULKAN
Expand Down Expand Up @@ -64,12 +59,7 @@ ZipformerModel::ZipformerModel(AAssetManager *mgr, const ModelConfig &config) {
encoder_.opt.use_fp16_arithmetic = false;
encoder_.opt.use_fp16_storage = false;

decoder_.opt.use_fp16_arithmetic = false;
decoder_.opt.use_fp16_storage = false;

joiner_.opt.use_fp16_arithmetic = false;
joiner_.opt.use_fp16_storage = false;
NCNN_LOGE("Disable fp16 for Zipformer on Android");
NCNN_LOGE("Disable fp16 for Zipformer encoder on Android");

bool has_gpu = false;
#if NCNN_VULKAN
Expand Down Expand Up @@ -172,7 +162,7 @@ ncnn::Mat ZipformerModel::RunJoiner(ncnn::Mat &encoder_out,
joiner_ex->input(joiner_input_indexes_[1], decoder_out);

ncnn::Mat joiner_out;
joiner_ex->extract("out0", joiner_out);
joiner_ex->extract(joiner_output_indexes_[0], joiner_out);
return joiner_out;
}

Expand Down

0 comments on commit bde310a

Please sign in to comment.