Commit e8a1be4

Merge pull request teticio#61 from nnyj/onnx
Add tf_to_onnx script
2 parents ae119cc + 6af6c86

2 files changed (+133, −0)

train/requirements.txt (+3 lines)

@@ -4,6 +4,9 @@ gensim
 huggingface-hub
 pandas
 lightning
+onnx
+onnxruntime
+protobuf==3.19.6
 pyyaml
 requests
 spotipy
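
A quick sanity check, not part of this commit, that the three newly added packages import together in one environment; with the pin above, protobuf should report 3.19.6:

# Sanity check (not part of this commit): confirm the new dependencies resolve together.
import onnx
import onnxruntime
import google.protobuf

print("onnx", onnx.__version__)
print("onnxruntime", onnxruntime.__version__)
print("protobuf", google.protobuf.__version__)  # expected 3.19.6 with the pin above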

train/tf_to_onnx.py (new file, +130 lines)

import argparse
import os
from collections import OrderedDict
from typing import Optional

# Force CPU: the conversion and the numerical check below do not need a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import numpy as np
import tensorflow as tf
import torch
import onnxruntime as ort
from audiodiffusion.audio_encoder import AudioEncoder
from keras.models import load_model
from torch import Tensor

if __name__ == "__main__":
    """
    Entry point for the tf_to_onnx script.

    Converts a TensorFlow MP3ToVec model to an ONNX MP3ToVec model.

    Args:
        --onnx_model_file (str): Path to the ONNX model file. Default is "models/speccy_model.onnx".
        --tf_model_file (str): Path to the TensorFlow model file. Default is "models/speccy_model".

    Returns:
        None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--onnx_model_file",
        type=str,
        default="models/speccy_model.onnx",
        help="ONNX model path",
    )
    parser.add_argument(
        "--tf_model_file",
        type=str,
        default="models/speccy_model",
        help="TensorFlow model path",
    )
    args = parser.parse_args()

    model: Optional[tf.keras.Model] = load_model(
        args.tf_model_file,
        custom_objects={"cosine_proximity": tf.compat.v1.keras.losses.cosine_proximity},
    )
    if model is None:
        raise ValueError("Model did not load correctly.")

    # Copy the Keras weights into the PyTorch AudioEncoder, reordering kernel
    # axes from the Keras layout to the PyTorch layout where necessary.
    pytorch_model = AudioEncoder()
    new_state_dict = OrderedDict()
    for conv_block in range(3):
        # Depthwise kernel: (kH, kW, in, mult) -> (in, mult, kH, kW)
        new_state_dict[f"conv_blocks.{conv_block}.sep_conv.depthwise.weight"] = Tensor(
            model.get_layer(
                f"separable_conv2d_{conv_block + 1}"
            ).depthwise_kernel.numpy()
        ).permute(2, 3, 0, 1)
        # Pointwise kernel: (1, 1, in, out) -> (out, in, 1, 1)
        new_state_dict[f"conv_blocks.{conv_block}.sep_conv.pointwise.weight"] = Tensor(
            model.get_layer(
                f"separable_conv2d_{conv_block + 1}"
            ).pointwise_kernel.numpy()
        ).permute(3, 2, 0, 1)
        new_state_dict[f"conv_blocks.{conv_block}.sep_conv.pointwise.bias"] = Tensor(
            model.get_layer(f"separable_conv2d_{conv_block + 1}").bias.numpy()
        )
        new_state_dict[f"conv_blocks.{conv_block}.batch_norm.weight"] = Tensor(
            model.get_layer(f"batch_normalization_{conv_block + 1}").gamma.numpy()
        )
        new_state_dict[f"conv_blocks.{conv_block}.batch_norm.running_mean"] = Tensor(
            model.get_layer(f"batch_normalization_{conv_block + 1}").moving_mean.numpy()
        )
        new_state_dict[f"conv_blocks.{conv_block}.batch_norm.running_var"] = Tensor(
            model.get_layer(
                f"batch_normalization_{conv_block + 1}"
            ).moving_variance.numpy()
        )
        new_state_dict[f"conv_blocks.{conv_block}.batch_norm.bias"] = Tensor(
            model.get_layer(f"batch_normalization_{conv_block + 1}").beta.numpy()
        )

    new_state_dict["dense_block.batch_norm.weight"] = Tensor(
        model.get_layer(f"batch_normalization_{conv_block + 2}").gamma.numpy()  # type: ignore
    )
    new_state_dict["dense_block.batch_norm.running_mean"] = Tensor(
        model.get_layer(f"batch_normalization_{conv_block + 2}").moving_mean.numpy()  # type: ignore
    )
    new_state_dict["dense_block.batch_norm.running_var"] = Tensor(
        model.get_layer(f"batch_normalization_{conv_block + 2}").moving_variance.numpy()  # type: ignore
    )
    new_state_dict["dense_block.batch_norm.bias"] = Tensor(
        model.get_layer(f"batch_normalization_{conv_block + 2}").beta.numpy()  # type: ignore
    )

    # Dense kernels: Keras stores (in, out); PyTorch Linear expects (out, in).
    new_state_dict["dense_block.dense.weight"] = Tensor(
        model.get_layer("dense_1").kernel.numpy()
    ).permute(1, 0)
    new_state_dict["dense_block.dense.bias"] = Tensor(
        model.get_layer("dense_1").bias.numpy()
    )
    new_state_dict["embedding.weight"] = Tensor(
        model.get_layer("dense_2").kernel.numpy()
    ).permute(1, 0)
    new_state_dict["embedding.bias"] = Tensor(model.get_layer("dense_2").bias.numpy())

    pytorch_model.eval()
    pytorch_model.load_state_dict(new_state_dict, strict=False)

    # Export to ONNX with a dynamic batch dimension.
    dummy_input = torch.randn(1, 1, 96, 216)
    dynamic_axes = {
        "input": {0: "batch_size"},  # variable-length batch axis
        "output": {0: "batch_size"},
    }
    torch.onnx.export(
        pytorch_model,
        dummy_input,
        args.onnx_model_file,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )

    # Check that the exported ONNX model matches the original TensorFlow model
    # on a random input (NHWC for Keras, NCHW for ONNX).
    np.random.seed(42)
    ort_session = ort.InferenceSession(
        args.onnx_model_file, providers=["CPUExecutionProvider"]
    )
    example = np.random.random_sample((1, 96, 216, 1))
    with torch.no_grad():
        assert (
            np.abs(
                ort_session.run(
                    None, {"input": Tensor(example).permute(0, 3, 1, 2).numpy()}
                )
                - model(example).numpy()
            ).max()
            < 2e-3
        )
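
For reference, a minimal sketch (not part of this commit) of how a downstream consumer might load the exported file with onnxruntime. The "input" name, the NCHW shape (batch, 1, 96, 216) and the default output path come from the export above; the random patch is a stand-in for a real mel-spectrogram.

# Hypothetical consumer of the exported model (not part of this commit).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "models/speccy_model.onnx", providers=["CPUExecutionProvider"]
)
# One (96 x 216) mel-spectrogram patch in NCHW layout; random data as a stand-in.
patch = np.random.random_sample((1, 1, 96, 216)).astype(np.float32)
(embedding,) = session.run(None, {"input": patch})
print(embedding.shape)  # (1, embedding_dim)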
