Skip to content

Commit

Permalink
allow test_transcribe to run on CPU when CUDA is not available
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Jan 17, 2023
1 parent 493dfff commit b1d213c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
pytorch-version: 1.10.2
steps:
- uses: conda-incubator/setup-miniconda@v2
- run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest
Expand Down
6 changes: 4 additions & 2 deletions tests/test_transcribe.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os

import pytest
import torch

import whisper


@pytest.mark.parametrize('model_name', whisper.available_models())
@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
model = whisper.load_model(model_name).cuda()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name).to(device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

language = "en" if model_name.endswith(".en") else None
Expand Down

0 comments on commit b1d213c

Please sign in to comment.