forked from SkyTNT/midi-model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path check_checkpoint.py
28 lines (23 loc) · 916 Bytes
/
check_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# %%
import torch
from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizer
from collections import OrderedDict
# Map every tensor to CPU so a checkpoint that was saved from a GPU can also be
# opened on a CPU-only machine; load_state_dict later copies the values into
# the model's own parameters.
ckpt = torch.load('checkpoints/model.ckpt', map_location='cpu')
# Print the top-level keys to inspect how the checkpoint is laid out.
for k in ckpt:
    print(k)
# %%
def load_pretrained_weights(model, pretrained_weights_dict, model_name):
    """Load the checkpoint entries that live under the given top-level prefixes.

    Args:
        model: Object with a ``load_state_dict(state_dict, strict=...)`` method,
            e.g. a ``torch.nn.Module``.
        pretrained_weights_dict: Mapping from parameter name to value. If it
            holds a nested ``"state_dict"`` dict (the layout of e.g. PyTorch
            Lightning ``.ckpt`` files), that inner dict is used instead.
        model_name: One prefix (``"net"``) or an iterable of prefixes
            (``["net", "lm_head"]``). A key is selected when it equals a
            prefix or starts with ``prefix + "."``. Selected keys are passed
            through unchanged, so they must already match the names in
            ``model``.

    Returns:
        ``model`` itself, after ``load_state_dict(..., strict=False)``.

    Warns:
        UserWarning: if no key matched any prefix (nothing was loaded).
    """
    import warnings

    # Checkpoint files often wrap the weights next to metadata (epoch, optimizer
    # state, ...); descend into the weights mapping when it is present.
    nested = pretrained_weights_dict.get("state_dict")
    if isinstance(nested, dict):
        pretrained_weights_dict = nested

    prefixes = (model_name,) if isinstance(model_name, str) else tuple(model_name)
    # Match whole dotted components: 'net' must select 'net.x' but not
    # 'net_token.x'. The previous `key in model_name` test compared full keys
    # against the bare prefixes and therefore never matched anything.
    new_state_dict = OrderedDict(
        (key, value)
        for key, value in pretrained_weights_dict.items()
        if any(key == p or key.startswith(p + ".") for p in prefixes)
    )
    if not new_state_dict:
        warnings.warn(
            f"No checkpoint keys matched prefixes {list(prefixes)}; nothing was loaded.",
            stacklevel=2,
        )
    # strict=False: only a subset of the model's parameters is supplied here.
    model.load_state_dict(new_state_dict, strict=False)
    return model
# Build the tokenizer, then restore the checkpoint entries selected by the
# ['net_token'] prefix list into it. load_pretrained_weights hands back the
# same object it was given.
# NOTE(review): this requires MIDITokenizer to provide load_state_dict (i.e. to
# be a torch.nn.Module) — confirm. If the 'net_token' weights actually belong
# to a submodule of MIDIModel, they should be loaded into net_model instead.
net_token_model = MIDITokenizer()
net_token_model = load_pretrained_weights(net_token_model, ckpt, ['net_token'])

# Build the model around the tokenizer and restore the checkpoint entries
# selected by the ['net', 'lm_head'] prefix list; the call mutates net_model in
# place and returns it, so rebinding keeps the very same object.
net_model = MIDIModel(net_token_model)
net_model = load_pretrained_weights(net_model, ckpt, ['net', 'lm_head'])