forked from graviraja/MLOps-Basics
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_model_to_onnx.py
57 lines (47 loc) · 1.81 KB
/
convert_model_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import hydra
import logging
from omegaconf.omegaconf import OmegaConf
from model import ColaModel
from data import DataModule
logger = logging.getLogger(__name__)
@hydra.main(config_path="./configs", config_name="config")
def convert_model(cfg):
root_dir = hydra.utils.get_original_cwd()
model_path = f"{root_dir}/models/best-checkpoint.ckpt"
logger.info(f"Loading pre-trained model from: {model_path}")
cola_model = ColaModel.load_from_checkpoint(model_path)
data_model = DataModule(
cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length
)
data_model.prepare_data()
data_model.setup()
input_batch = next(iter(data_model.train_dataloader()))
input_sample = {
"input_ids": input_batch["input_ids"][0].unsqueeze(0),
"attention_mask": input_batch["attention_mask"][0].unsqueeze(0),
}
# Export the model
logger.info(f"Converting the model into ONNX format")
torch.onnx.export(
cola_model, # model being run
(
input_sample["input_ids"],
input_sample["attention_mask"],
), # model input (or a tuple for multiple inputs)
f"{root_dir}/models/model.onnx", # where to save the model (can be a file or file-like object)
export_params=True,
opset_version=10,
input_names=["input_ids", "attention_mask"], # the model's input names
output_names=["output"], # the model's output names
dynamic_axes={
"input_ids": {0: "batch_size"}, # variable length axes
"attention_mask": {0: "batch_size"},
"output": {0: "batch_size"},
},
)
logger.info(
f"Model converted successfully. ONNX format model is at: {root_dir}/models/model.onnx"
)
if __name__ == "__main__":
convert_model()