Skip to content

Commit

Permalink
Implement max line width and max line count, and make word highlighti…
Browse files Browse the repository at this point in the history
…ng optional (openai#1184)

* Add highlight_words, max_line_width, max_line_count

* Refactor subtitle generator

---------

Co-authored-by: Jong Wook Kim <[email protected]>
  • Loading branch information
ryanheise and jongwook authored Apr 11, 2023
1 parent 255887f commit 43940fc
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 42 deletions.
13 changes: 12 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ def cli():
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on

Expand Down Expand Up @@ -433,9 +436,17 @@ def cli():
model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)
word_options = ["highlight_words", "max_line_count", "max_line_width"]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
writer(result, audio_path, writer_args)


if __name__ == "__main__":
Expand Down
135 changes: 94 additions & 41 deletions whisper/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, TextIO
from typing import Callable, Optional, TextIO

system_encoding = sys.getdefaultencoding()

Expand Down Expand Up @@ -73,24 +74,24 @@ class ResultWriter:
def __init__(self, output_dir: str):
self.output_dir = output_dir

def __call__(self, result: dict, audio_path: str):
def __call__(self, result: dict, audio_path: str, options: dict):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)

with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
self.write_result(result, file=f, options=options)

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
raise NotImplementedError


class WriteTXT(ResultWriter):
extension: str = "txt"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)

Expand All @@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str

def iterate_result(self, result: dict):
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")

if word_timings := segment.get("words", None):
all_words = [timing["word"] for timing in word_timings]
all_words[0] = all_words[0].strip() # remove the leading space, if any
last = segment_start
for i, this_word in enumerate(word_timings):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, segment_text

yield start, end, "".join(
[
f"<u>{word}</u>" if j == i else word
for j, word in enumerate(all_words)
]
)
last = end

if last != segment_end:
yield last, segment_end, segment_text
else:
def iterate_result(self, result: dict, options: dict):
raw_max_line_width: Optional[int] = options["max_line_width"]
max_line_count: Optional[int] = options["max_line_count"]
highlight_words: bool = options["highlight_words"]
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
preserve_segments = max_line_count is None or raw_max_line_width is None

def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: list[dict] = []
last = result["segments"][0]["words"][0]["start"]
for segment in result["segments"]:
for i, original_timing in enumerate(segment["words"]):
timing = original_timing.copy()
long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
if len(subtitle) > 0:
yield subtitle

if "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text

yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text

def format_timestamp(self, seconds: float):
Expand All @@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
always_include_hours: bool = False
decimal_marker: str = "."

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result):
for start, end, text in self.iterate_result(result, options):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


Expand All @@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
always_include_hours: bool = True
decimal_marker: str = ","

def write_result(self, result: dict, file: TextIO):
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
def write_result(self, result: dict, file: TextIO, options: dict):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


Expand All @@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):

extension: str = "tsv"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
Expand All @@ -180,11 +231,13 @@ def write_result(self, result: dict, file: TextIO):
class WriteJSON(ResultWriter):
extension: str = "json"

def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
json.dump(result, file)


def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
Expand All @@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]

def write_all(result: dict, file: TextIO):
def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
writer(result, file)
writer(result, file, options)

return write_all

Expand Down

0 comments on commit 43940fc

Please sign in to comment.