Skip to content

Commit

Permalink
apply formatting with black (openai#1038)
Browse files Browse the repository at this point in the history
* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
  • Loading branch information
jongwook authored Mar 6, 2023
1 parent 500d0fe commit b80bcf6
Show file tree
Hide file tree
Showing 21 changed files with 533 additions and 227 deletions.
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
per-file-ignores =
*/__init__.py: F401

3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ jobs:
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install .["dev"]
- run: black --check --diff -t py38 --include '(\.pyi?)$' .
- run: isort --check --diff .
- run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[tool.black]

[tool.isort]
profile = "black"
include_trailing_comma = true
line_length = 88
multi_line_output = 3

12 changes: 8 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys

import pkg_resources
from setuptools import setup, find_packages
from setuptools import find_packages, setup


def read_version(fname="whisper/version.py"):
Expand All @@ -16,7 +16,10 @@ def read_version(fname="whisper/version.py"):
try:
import re
import subprocess
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]

version_line = (
subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
)
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
Expand All @@ -38,7 +41,8 @@ def read_version(fname="whisper/version.py"):
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=requirements + [
install_requires=requirements
+ [
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
Expand All @@ -48,5 +52,5 @@ def read_version(fname="whisper/version.py"):
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest", "scipy"]},
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
)
2 changes: 1 addition & 1 deletion tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram


def test_audio():
Expand Down
5 changes: 4 additions & 1 deletion tests/test_normalizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest

from whisper.normalizers import EnglishTextNormalizer
from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
from whisper.normalizers.english import (
EnglishNumberNormalizer,
EnglishSpellingNormalizer,
)


@pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])
Expand Down
21 changes: 15 additions & 6 deletions tests/test_timing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import pytest
import numpy as np
import pytest
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter


sizes = [
(10, 20), (32, 16), (123, 1500), (234, 189),
(10, 20),
(32, 16),
(123, 1500),
(234, 189),
]
shapes = [
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
(10,),
(1, 15),
(4, 5, 345),
(6, 12, 240, 512),
]


Expand Down Expand Up @@ -68,8 +73,12 @@ def test_median_filter(shape):

# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
padded_x = np.pad(
x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
)
scipy_filtered = scipy.ndimage.median_filter(
padded_x, [1] * (x.ndim - 1) + [filter_width]
)
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

assert np.allclose(filtered, scipy_filtered)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

language = "en" if model_name.endswith(".en") else None
result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"

transcription = result["text"].lower()
Expand Down
50 changes: 30 additions & 20 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__


_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
Expand All @@ -41,12 +40,11 @@
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
}



def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)

Expand All @@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)

with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
Expand All @@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes if in_memory else download_target

Expand All @@ -86,7 +94,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())


def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
Expand All @@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(
os.getenv(
"XDG_CACHE_HOME",
os.path.join(
os.path.expanduser("~"), ".cache"
)
),
"whisper"
)
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
Expand All @@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)

with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file

Expand Down
1 change: 0 additions & 1 deletion whisper/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .transcribe import cli


cli()
20 changes: 14 additions & 6 deletions whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
N_FRAMES = exact_div(
N_SAMPLES, HOP_LENGTH
) # 3000: number of frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each)
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
Expand Down Expand Up @@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
Expand Down Expand Up @@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
with np.load(
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
):
"""
Compute the log-Mel spectrogram of
Expand Down
Loading

0 comments on commit b80bcf6

Please sign in to comment.