forked from prophesier/diff-svc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
IceKyrin
authored and
IceKyrin
committed
Nov 11, 2022
1 parent
360c479
commit 675c726
Showing
7 changed files
with
183 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from pathlib import Path | ||
|
||
import librosa | ||
import numpy as np | ||
import torch | ||
from fairseq import checkpoint_utils | ||
|
||
|
||
def load_model(vec_path): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
print("load model(s) from {}".format(vec_path)) | ||
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( | ||
[vec_path], | ||
suffix="", | ||
) | ||
model = models[0] | ||
model = model.to(device) | ||
model.eval() | ||
return model | ||
|
||
|
||
def get_vec_units(con_model, audio_path, dev): | ||
audio, sampling_rate = librosa.load(audio_path) | ||
if len(audio.shape) > 1: | ||
audio = librosa.to_mono(audio.transpose(1, 0)) | ||
if sampling_rate != 16000: | ||
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000) | ||
|
||
feats = torch.from_numpy(audio).float() | ||
if feats.dim() == 2: # double channels | ||
feats = feats.mean(-1) | ||
assert feats.dim() == 1, feats.dim() | ||
feats = feats.view(1, -1) | ||
padding_mask = torch.BoolTensor(feats.shape).fill_(False) | ||
inputs = { | ||
"source": feats.to(dev), | ||
"padding_mask": padding_mask.to(dev), | ||
"output_layer": 9, # layer 9 | ||
} | ||
with torch.no_grad(): | ||
logits = con_model.extract_features(**inputs) | ||
feats = con_model.final_proj(logits[0]) | ||
return feats | ||
|
||
|
||
if __name__ == '__main__': | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
model_path = "../../checkpoints/checkpoint_best_legacy_500.pt" # checkpoint_best_legacy_500.pt | ||
vec_model = load_model(model_path) | ||
# 这个不用改,自动在根目录下所有wav的同文件夹生成其对应的npy | ||
file_lists = list(Path("../../data/vecfox").rglob('*.wav')) | ||
nums = len(file_lists) | ||
count = 0 | ||
for wav_path in file_lists: | ||
npy_path = wav_path.with_suffix(".npy") | ||
npy_content = get_vec_units(vec_model, str(wav_path), device).cpu().numpy()[0] | ||
np.save(str(npy_path), npy_content) | ||
count += 1 | ||
print(f"hubert process:{round(count * 100 / nums, 2)}%") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from argparse import ArgumentParser | ||
|
||
import torch | ||
|
||
|
||
def simplify_pth(pth_name, project_name): | ||
model_path = f'./checkpoints/{project_name}' | ||
checkpoint_dict = torch.load(f'{model_path}/{pth_name}') | ||
torch.save({'epoch': checkpoint_dict['epoch'], | ||
'state_dict': checkpoint_dict['state_dict'], | ||
'global_step': None, | ||
'checkpoint_callback_best': None, | ||
'optimizer_states': None, | ||
'lr_schedulers': None | ||
}, f'./clean_{pth_name}') | ||
|
||
|
||
def main(): | ||
parser = ArgumentParser() | ||
parser.add_argument('--proj', type=str) | ||
parser.add_argument('--steps', type=str) | ||
args = parser.parse_args() | ||
model_name = f"model_ckpt_steps_{args.steps}.ckpt" | ||
simplify_pth(model_name, args.proj) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
head_list = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] | ||
|
||
|
||
def trans_f0_seq(feature_pit, transform): | ||
feature_pit = feature_pit * 2 ** (transform / 12) | ||
return round(feature_pit, 1) | ||
|
||
|
||
def move_key(raw_data, mv_key): | ||
head = raw_data[:-1] | ||
body = int(raw_data[-1]) | ||
new_head_index = head_list.index(head) + mv_key | ||
while new_head_index < 0: | ||
body -= 1 | ||
new_head_index += 12 | ||
while new_head_index > 11: | ||
body += 1 | ||
new_head_index -= 12 | ||
result_data = head_list[new_head_index] + str(body) | ||
return result_data | ||
|
||
|
||
def trans_key(raw_data, key): | ||
for i in raw_data: | ||
note_seq_list = i["note_seq"].split(" ") | ||
new_note_seq_list = [] | ||
for note_seq in note_seq_list: | ||
if note_seq != "rest": | ||
new_note_seq = move_key(note_seq, key) | ||
new_note_seq_list.append(new_note_seq) | ||
else: | ||
new_note_seq_list.append(note_seq) | ||
i["note_seq"] = " ".join(new_note_seq_list) | ||
|
||
f0_seq_list = i["f0_seq"].split(" ") | ||
f0_seq_list = [float(x) for x in f0_seq_list] | ||
new_f0_seq_list = [] | ||
for f0_seq in f0_seq_list: | ||
new_f0_seq = trans_f0_seq(f0_seq, key) | ||
new_f0_seq_list.append(str(new_f0_seq)) | ||
i["f0_seq"] = " ".join(new_f0_seq_list) | ||
return raw_data | ||
|
||
|
||
key = -6 | ||
f_w = open("raw.txt", "w", encoding='utf-8') | ||
with open("result.txt", "r", encoding='utf-8') as f: | ||
raw_data = f.readlines() | ||
for raw in raw_data: | ||
raw_list = raw.split("|") | ||
new_note_seq_list = [] | ||
for note_seq in raw_list[3].split(" "): | ||
if note_seq != "rest": | ||
note_seq = note_seq.split("/")[0] if "/" in note_seq else note_seq | ||
new_note_seq = move_key(note_seq, key) | ||
new_note_seq_list.append(new_note_seq) | ||
else: | ||
new_note_seq_list.append(note_seq) | ||
raw_list[3] = " ".join(new_note_seq_list) | ||
f_w.write("|".join(raw_list)) | ||
f_w.close() |