forked from playdasegunda/band-split-rope-transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
114 lines (94 loc) · 2.96 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import typing as tp
from pathlib import Path
import torch
import soundfile as sf
from omegaconf import OmegaConf
from data import EvalSourceSeparationDataset
from separator import Separator
class InferenceProgram:
    """Run a trained separator on one input and write the result to disk.

    Construction resolves the checkpoint and ``hparams.yaml`` for ``target``
    under ``SAVED_MODELS_DIR / target``, builds an
    ``EvalSourceSeparationDataset`` in ``'inference'`` mode and moves the
    ``Separator`` to the selected device. ``run`` then loads the input,
    applies the separator and saves the output with ``soundfile``.
    """

    # Root directory that holds one sub-directory per target source.
    SAVED_MODELS_DIR = Path("./saved_models")
    # Sample rate (Hz) passed to ``soundfile.write`` for the output file.
    SAMPLE_RATE = 44100

    def __init__(
            self,
            in_path: str,
            out_path: str,
            target: str,
            ckpt_path: tp.Optional[str] = None,
            device: str = 'cuda'
    ):
        """Resolve paths, load the config and build the dataset and model.

        Args:
            in_path: Input path; stored in the config as
                ``audio_params['in_fp']``.
            out_path: Output path; stored in the config as
                ``audio_params['out_fp']``.
            target: Name of the target source. Selects the sub-directory of
                ``SAVED_MODELS_DIR`` from which ``hparams.yaml`` (and the
                default checkpoint) are read.
            ckpt_path: Checkpoint to load. When ``None``,
                ``SAVED_MODELS_DIR / target / f"{target}.pt"`` is used.
            device: ``'cuda'`` selects the GPU when one is available; any
                other value, or an unavailable GPU, selects the CPU.

        Raises:
            ValueError: If ``ckpt_path`` is ``None`` and the default
                checkpoint file does not exist.
        """
        self.tgt_dir = self.SAVED_MODELS_DIR / target

        # Path to checkpoint. Only the default location is validated here;
        # an explicitly supplied path is passed through unchanged.
        if ckpt_path is None:
            ckpt_path = self.tgt_dir / f"{target}.pt"
            if not ckpt_path.is_file():
                # f-string so the message names the path that was tried.
                raise ValueError(
                    f"{ckpt_path} is missing. Please provide 'ckpt_path' explicitly."
                )
        self.ckpt_path = ckpt_path

        # Config params. The config always comes from the target directory,
        # even when a custom checkpoint path is supplied.
        self.cfg_path = self.tgt_dir / 'hparams.yaml'
        self.cfg = OmegaConf.load(self.cfg_path)
        self.cfg.audio_params['in_fp'] = in_path
        self.cfg.audio_params['out_fp'] = out_path

        # Use the GPU only when it was requested AND is actually available.
        use_cuda = device == 'cuda' and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # Initialize the dataset.
        self.dataset = EvalSourceSeparationDataset(mode='inference', **self.cfg.audio_params)

        # Initialize the separator in eval mode on the chosen device.
        self.sep = Separator(self.cfg, self.ckpt_path)
        _ = self.sep.eval()
        _ = self.sep.to(self.device)

    def run(self) -> None:
        """Separate the configured input and save the estimate as .wav.

        The input is loaded through the dataset's ``load_file``, and the
        output is written to the dataset's ``out_fp`` at ``SAMPLE_RATE``.
        """
        mixture = self.dataset.load_file(self.cfg.audio_params['in_fp'])
        out_fp = self.dataset.out_fp
        mixture = mixture.to(self.device)

        # Apply separator to the mixture file. Inference needs no autograd
        # graph; this is harmless if Separator already disables gradients.
        with torch.no_grad():
            estimate = self.sep(mixture).cpu()

        # Save file as .wav; the estimate is transposed before writing.
        sf.write(out_fp, estimate.T, samplerate=self.SAMPLE_RATE)
        return None
def main(args) -> None:
    """Build an ``InferenceProgram`` from ``args`` and run it.

    ``args`` is a mapping whose items are forwarded as keyword arguments to
    ``InferenceProgram``.
    """
    InferenceProgram(**args).run()
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script.

    The option destinations (``in_path``, ``out_path``, ``target``,
    ``ckpt_path``, ``device``) match the keyword arguments that ``main``
    forwards to ``InferenceProgram``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i',
        '--in-path',
        type=str,
        required=True,
        help="Path to the input directory/file with .wav/.mp3 extensions."
    )
    parser.add_argument(
        '-o',
        '--out-path',
        type=str,
        required=True,
        help="Path to the output directory. Files will be saved in .wav format with sr=44100."
    )
    parser.add_argument(
        '-t',
        '--target',
        type=str,
        required=False,
        default='vocals',
        help="Name of the target source to extract. "
    )
    parser.add_argument(
        '-c',
        '--ckpt-path',
        type=str,
        required=False,
        default=None,
        # The default checkpoint is '<target>.pt' inside the target's
        # directory, so the help text names that file explicitly.
        help="Path to model's checkpoint. If not specified, "
             "SAVED_MODELS_DIR/{target}/{target}.pt is used."
    )
    parser.add_argument(
        '-d',
        '--device',
        type=str,
        required=False,
        default='cuda',
        help="Device name - either 'cuda', or 'cpu'."
    )
    return parser


if __name__ == '__main__':
    # Parse the command line into a plain dict and hand it to main().
    args = vars(build_parser().parse_args())
    main(args)