# asr_transcript.py
# Forked from fishaudio/Bert-VITS2.
import argparse
import concurrent.futures
import os
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from tqdm import tqdm
from tools.log import logger
# Set the MODELSCOPE_CACHE environment variable to "./" for this process.
# NOTE(review): presumably this redirects ModelScope's model download/cache
# directory to the current working directory. The modelscope imports above run
# before this assignment, so confirm the variable is read lazily (at pipeline
# creation) rather than at import time.
os.environ["MODELSCOPE_CACHE"] = "./"
def transcribe_worker(file_path: str, inference_pipeline, language):
    """
    Transcribe one audio file, reusing an existing ``.lab`` transcript if present.

    Args:
        file_path: Path of the audio file to transcribe.
        inference_pipeline: Callable invoked as
            ``inference_pipeline(audio_in=file_path)``; its result must
            support ``.get("text", "")``.
        language: Language code. Spaces are kept in the recognized text only
            for "EN" (compared case-insensitively); for every other code all
            spaces are removed.

    Returns:
        The content of the sibling ``.lab`` file when it already exists,
        otherwise the recognized text (stripped of surrounding whitespace).
    """
    lab_path = os.path.splitext(file_path)[0] + ".lab"
    # isfile() is already False for a missing path, so no separate exists() check.
    if os.path.isfile(lab_path):
        logger.info(f'{lab_path}为已转写的文本,跳过~')
        with open(lab_path, "r", encoding="utf-8") as f:
            return f.read()

    rec_result = inference_pipeline(audio_in=file_path)
    text = str(rec_result.get("text", "")).strip()
    logger.info(file_path)

    # Compare case-insensitively so that "en" is not mistaken for a language
    # whose transcript should have its spaces removed.
    if str(language).upper() != "EN":
        text = text.replace(" ", "")

    logger.info("text: " + text)
    return text
def transcribe_folder_parallel(folder_path, language, max_workers=4):
    """
    Transcribe every .wav file under a folder and write sibling .lab files.

    The folder is walked recursively. Files are processed in batches of at most
    ``max_workers``; inside a batch each file is handed to its own ASR pipeline
    instance, and the batch is fully finished before the next one starts, so a
    pipeline instance is never used by two threads at once.

    Args:
        folder_path: Root folder searched recursively for ``.wav`` files
            (extension matched case-insensitively).
        language: Language code, matched case-insensitively. "JP" and "ZH"
            select their dedicated models; any other value uses the English
            model.
        max_workers: Maximum number of files transcribed concurrently. Must be
            at least 1.

    Raises:
        ValueError: If ``max_workers`` is smaller than 1.
    """
    if max_workers < 1:
        # Fail early with a clear message instead of a ZeroDivisionError or an
        # executor error further down.
        raise ValueError(f"max_workers must be >= 1, got {max_workers}")

    logger.info(f"parallel transcribe: {folder_path}|{language}|{max_workers}")

    # Normalize once so that "zh"/"jp"/"en" behave like their upper-case forms.
    language = str(language).upper()

    # Collect the inputs first: loading ASR models is expensive and pointless
    # when there is nothing to transcribe.
    file_paths = []
    for root, _, files in os.walk(folder_path):
        file_paths.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(".wav")
        )
    if not file_paths:
        logger.warning(f"no .wav files found under {folder_path}")
        return

    if language == "JP":
        pipeline_kwargs = {
            "model": "damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline",
        }
    elif language == "ZH":
        pipeline_kwargs = {
            "model": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "model_revision": "v1.2.4",
        }
    else:
        if language != "EN":
            logger.warning(
                f"unknown language {language!r}, falling back to the English model"
            )
        pipeline_kwargs = {
            "model": "damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
        }

    # Never build more pipelines than there are files to process.
    batch_size = min(max_workers, len(file_paths))
    workers = [
        pipeline(task=Tasks.auto_speech_recognition, **pipeline_kwargs)
        for _ in range(batch_size)
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        for start in tqdm(range(0, len(file_paths), batch_size), desc="转写进度: "):
            batch = file_paths[start : start + batch_size]
            # map() stops at the shortest iterable, so a short final batch only
            # uses the first len(batch) pipelines.
            transcriptions = list(
                executor.map(
                    transcribe_worker, batch, workers, [language] * len(batch)
                )
            )
            for file_path, transcription in zip(batch, transcriptions):
                # Empty transcriptions are skipped so no empty .lab is created.
                if transcription:
                    lab_file_path = os.path.splitext(file_path)[0] + ".lab"
                    with open(lab_file_path, "w", encoding="utf-8") as lab_file:
                        lab_file.write(transcription)

    logger.info("已经将wav文件转写为同名的.lab文件")
if __name__ == "__main__":
    # Command-line entry point: parse the options, transcribe the folder, then
    # print a completion message.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--filepath",
        default="./raw/lzy_zh",
        help="folder searched recursively for .wav files to transcribe",
    )
    parser.add_argument(
        "-l",
        "--language",
        default="ZH",
        help="language of the audio: ZH, JP or EN",
    )
    # type=int lets argparse reject non-numeric input with a usage error
    # instead of an uncaught ValueError traceback.
    parser.add_argument(
        "-w", "--workers", type=int, default=1, help="trans workers"
    )
    args = parser.parse_args()
    transcribe_folder_parallel(args.filepath, args.language, args.workers)
    print("转写结束!")