Commit 3530688

Make Emotion2vec support onnx (#2359)

* Make emotion2vec exportable to onnx
* Make export_meta of emotion2vec consistent with other models
* Include layer norm in the exported onnx model
1 parent d4f13c2 commit 3530688

File tree: 4 files changed, +86 −2 lines

.gitignore (+1)

@@ -27,3 +27,4 @@ GPT-SoVITS*
 modelscope_models
 examples/aishell/llm_asr_nar/*
 *egg-info
+env/

export.py (+3 −2)

@@ -2,8 +2,9 @@
 from funasr import AutoModel
 
 model = AutoModel(
-    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    model="iic/emotion2vec_base",
+    hub="ms"
 )
 
-res = model.export(type="onnx", quantize=False, opset_version=13, device='cuda')  # fp32 onnx-gpu
+res = model.export(type="onnx", quantize=False, opset_version=13, device='cpu')  # fp32 onnx-cpu
 # res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda')  # fp16 onnx-gpu
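After exporting, the resulting graph can be sanity-checked with onnxruntime. A minimal sketch, assuming the exporter writes a model.onnx file for the downloaded model (the file path and the ~1 s, 16 kHz mono test input are assumptions; the tensor names come from export_meta.py below):

# Sketch: run the exported emotion2vec ONNX graph with onnxruntime.
# Assumptions: the output path "model.onnx" and a 16 kHz mono waveform of shape (1, 16000);
# the "input"/"output" names match export_input_names/export_output_names below.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
waveform = np.random.randn(1, 16000).astype(np.float32)  # (batch, samples)
frame_embeddings = sess.run(["output"], {"input": waveform})[0]
print(frame_embeddings.shape)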
funasr/models/emotion2vec/export_meta.py (new file, +76)

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import types
import torch
import torch.nn.functional as F


def export_rebuild_model(model, **kwargs):
    model.device = kwargs.get("device")

    # store the original forward, since self.extract_features still calls it
    model._original_forward = model.forward

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)

    model.export_name = "emotion2vec"
    return model


def export_forward(
    self, x: torch.Tensor
):
    with torch.no_grad():
        if self.cfg.normalize:
            mean = torch.mean(x, dim=1, keepdim=True)
            var = torch.var(x, dim=1, keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + 1e-5)
        x = x.view(x.shape[0], -1)

        # Call the original forward directly, just like extract_features does; extract_features
        # itself cannot be used because it calls self.forward, which export_forward has replaced.
        res = self._original_forward(
            source=x,
            padding_mask=None,
            mask=False,
            features_only=True,
            remove_extra_tokens=True
        )

        x = res["x"]

        return x


def export_dummy_inputs(self):
    return (torch.randn(1, 16000),)


def export_input_names(self):
    return ["input"]


def export_output_names(self):
    return ["output"]


def export_dynamic_axes(self):
    return {
        "input": {
            0: "batch_size",
            1: "sequence_length",
        },
        "output": {0: "batch_size", 1: "sequence_length"},
    }


def export_name(self):
    return "emotion2vec"
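The explicit mean/variance normalization in export_forward is what the commit message's third bullet refers to: it folds the utterance-level layer norm into the exported graph. A minimal sketch of that equivalence, assuming the non-exported inference path normalizes the waveform with F.layer_norm over the whole utterance (batch size 1):

# Sketch: check that the exported normalization matches an utterance-level layer norm.
# Assumption: the original inference path applies F.layer_norm(x, x.shape) to a single utterance.
import torch
import torch.nn.functional as F

x = torch.randn(1, 16000)
mean = torch.mean(x, dim=1, keepdim=True)
var = torch.var(x, dim=1, keepdim=True, unbiased=False)
exported = (x - mean) / torch.sqrt(var + 1e-5)   # as in export_forward above
reference = F.layer_norm(x, x.shape)             # layer norm over the whole utterance
print(torch.allclose(exported, reference, atol=1e-5))  # True for batch size 1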

funasr/models/emotion2vec/model.py (+6)

@@ -265,3 +265,9 @@ def inference(
             results.append(result_i)
 
         return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
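The rebuilt model exposes export_dummy_inputs, export_input_names, export_output_names and export_dynamic_axes, which FunASR's exporter reads when tracing the graph. An illustrative sketch of how those hooks map onto torch.onnx.export (this mirrors the hook contract only; it is not FunASR's actual export utility):

# Sketch: consume the export_meta hooks with torch.onnx.export.
# Assumption: export_emotion2vec_onnx is a hypothetical stand-in, not FunASR's real exporter.
import torch

def export_emotion2vec_onnx(model, output_path="emotion2vec.onnx", opset_version=13):
    model = model.export(device="cpu")            # rebind forward and attach the hooks
    dummy_inputs = model.export_dummy_inputs()    # (torch.randn(1, 16000),)
    torch.onnx.export(
        model,
        dummy_inputs,
        output_path,
        input_names=model.export_input_names(),    # ["input"]
        output_names=model.export_output_names(),  # ["output"]
        dynamic_axes=model.export_dynamic_axes(),  # batch and sequence axes are dynamic
        opset_version=opset_version,
    )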
