Skip to content

Commit

Permalink
Merge pull request s3prl#513 from s3prl/include-representative-entrie…
Browse files Browse the repository at this point in the history
…s-to-test

Include representative entries to test
  • Loading branch information
leo19941227 authored Oct 28, 2023
2 parents 68c05ac + 866ef59 commit 8d7e757
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 22 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ jobs:
build-and-test:

strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04]
python-version: [3.8]
Expand Down Expand Up @@ -61,7 +62,7 @@ jobs:

- name: Run tox for common upstream
run: |
tox -e common_upstream-audio${{ matrix.torchaudio-version }}
tox -e common_upstream-audio${{ matrix.torchaudio-version }} -- -k test_common_upstream
# Not all upstreams will be tested on CI since this will take too much time
# To test all upstreams, run tox locally with '-e all_upstream'
Expand Down
60 changes: 60 additions & 0 deletions .github/workflows/espnet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
name: CI

on:
# Trigger the workflow on push to main or any pull request
push:
branches:
- main
pull_request:

jobs:
build-and-test:

strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04]
python-version: [3.8]
torchaudio-version: [0.12.1, 0.13.1, 2.0.2, 2.1.0]
# espnet uses 'torchaudio.models.hubert_pretrain_model' with the 'feature_grad_mult' option,
# which is available after torchaudio==0.12.0

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install Linux dependencies
run: |
sudo apt-get update
sudo apt-get install -y libsndfile1-dev sox git git-lfs
- name: Upgrade pip and wheel
run: pip3 install --upgrade pip wheel

- name: Install dependencies for tests
run: pip3 install -r requirements/dev.txt

# This can be very helpful for debugging
# The action can create an SSH server for you to connect to. After you
# log into the machine hosted by GitHub, it becomes easy to debug
# why the CI fails on a specific machine.

# - name: Setup upterm session
# uses: lhotari/action-upterm@v1

- name: Run tox for common upstream
run: |
tox -e common_upstream-audio${{ matrix.torchaudio-version }}-espnet -- -k test_specific_upstream --upstream_names espnet_hubert_base_iter1
# Not all upstreams will be tested on CI since this will take too much time
# To test all upstreams, run tox locally with '-e all_upstream'

# - name: Run tox for all other functionalities
# run: |
# tox -e all_others-deps_all-audio${{ matrix.torchaudio-version }}
8 changes: 6 additions & 2 deletions s3prl/upstream/apc/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def apc_360hr(refresh=False, *args, **kwargs):
The apc standard model on 360hr
refresh (bool): whether to download ckpt/config again if existed
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/mcq82c0x62h9004/apc_default.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/apc_360hr.ckpt"
return apc_url(refresh=refresh, *args, **kwargs)


Expand All @@ -53,5 +55,7 @@ def apc_960hr(refresh=False, *args, **kwargs):
The apc standard model on 960hr
refresh (bool): whether to download ckpt/config again if existed
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/mmfx3opdr4lz25n/apc_960hr.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/apc_960hr.ckpt"
return apc_url(refresh=refresh, *args, **kwargs)
2 changes: 1 addition & 1 deletion s3prl/upstream/audio_albert/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ def audio_albert_logMelBase_T_share_AdamW_b32_1m_960hr_drop1(
"""
kwargs[
"ckpt"
] = "https://www.dropbox.com/s/3wgynxmod77ha1z/states-1000000.ckpt?dl=1"
] = "https://huggingface.co/s3prl/audio_albert/resolve/main/audio_albert_logMelBase_T_share_AdamW_b32_1m_960hr_drop1/states-1000000.ckpt"
return audio_albert_url(refresh=refresh, *args, **kwargs)
4 changes: 3 additions & 1 deletion s3prl/upstream/mockingjay/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def mockingjay_logMelLinearLarge_T_AdamW_b32_500k_360hr_drop1(
Total steps: 500k
Unlabeled Speech: 360hr
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/zwsfa6w2iy2cc68/states-500000.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/s3prl/mockingjay/resolve/main/mockingjay_logMelLinearLarge_T_AdamW_b32_500k_360hr_drop1/states-500000.ckpt"
return mockingjay_url(refresh=refresh, *args, **kwargs)


Expand Down
8 changes: 6 additions & 2 deletions s3prl/upstream/npc/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def npc_360hr(refresh=False, *args, **kwargs):
The npc standard model on 360hr
refresh (bool): whether to download ckpt/config again if existed
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/o4zpjz6xncbij8p/npc_default.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/npc_360hr.ckpt"
return npc_url(refresh=refresh, *args, **kwargs)


Expand All @@ -53,5 +55,7 @@ def npc_960hr(refresh=False, *args, **kwargs):
The npc standard model on 960hr
refresh (bool): whether to download ckpt/config again if existed
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/7ep0v60ym136bpb/npc_960hr.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/npc_960hr.ckpt"
return npc_url(refresh=refresh, *args, **kwargs)
2 changes: 1 addition & 1 deletion s3prl/upstream/tera/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def tera_logMelBase_T_F_M_AdamW_b32_1m_960hr_drop1(refresh=False, *args, **kwarg
"""
kwargs[
"ckpt"
] = "https://www.dropbox.com/s/xdoj9wdo87lztv1/states-1000000.ckpt?dl=1"
] = "https://huggingface.co/s3prl/tera/resolve/main/tera_logMelBase_T_F_M_AdamW_b32_1m_960hr_drop1/states-1000000.ckpt"
return tera_url(refresh=refresh, *args, **kwargs)


Expand Down
6 changes: 4 additions & 2 deletions s3prl/upstream/vq_apc/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def vq_apc_360hr(refresh=False, *args, **kwargs):
"""
kwargs[
"ckpt"
] = "https://www.dropbox.com/s/6auicz4ovl0nwlq/vq_apc_default.ckpt?dl=1"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/vq_apc_360hr.ckpt"
return vq_apc_url(refresh=refresh, *args, **kwargs)


Expand All @@ -34,5 +34,7 @@ def vq_apc_960hr(refresh=False, *args, **kwargs):
The vq-apc standard model on 960hr
refresh (bool): whether to download ckpt/config again if existed
"""
kwargs["ckpt"] = "https://www.dropbox.com/s/xduhcr3y8c0qpc2/vq_apc_960hr.ckpt?dl=1"
kwargs[
"ckpt"
] = "https://huggingface.co/leo19941227/apc_series/resolve/main/vq_apc_960hr.ckpt"
return vq_apc_url(refresh=refresh, *args, **kwargs)
82 changes: 70 additions & 12 deletions test/test_upstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
# Module-level settings shared by the upstream tests in this file.
# NOTE(review): presumably the waveform sample rate in Hz -- confirm against the helpers that use it.
SAMPLE_RATE = 16000
# NOTE(review): name suggests an absolute tolerance for numeric comparison -- confirm at the use site.
ATOL = 0.01
# NOTE(review): name suggests the allowed length mismatch when comparing outputs -- confirm at the use site.
MAX_LENGTH_DIFF = 3
# This first value is immediately overridden by the assignment on the next line.
EXTRA_SHORT_SEC = 0.001
# Effective value; passed as min_secs / max_secs to _test_forward_backward.
EXTRA_SHORT_SEC = 0.05
# Directory named "sample_hidden_states" located next to this file's parent directory.
EXTRACTED_GT_DIR = Path(__file__).parent.parent / "sample_hidden_states"
# Repository URL that _prepare_sample_hidden_states clones and pulls from.
S3PRL_HF_SAMPLE_HS = "https://huggingface.co/datasets/s3prl/sample_hidden_states"

# Expect the following directory structure:
#
Expand All @@ -49,7 +50,7 @@ def _prepare_sample_hidden_states():
logger.info("Downloading extracted sample hidden states...")
check_call("git lfs install".split(), cwd=tempdir, env=env)
check_call(
"git clone https://huggingface.co/datasets/s3prl/sample_hidden_states".split(),
f"git clone {S3PRL_HF_SAMPLE_HS}".split(),
cwd=tempdir,
env=env,
)
Expand All @@ -58,7 +59,11 @@ def _prepare_sample_hidden_states():
)
else:
logger.info(f"{EXTRACTED_GT_DIR} exists. Perform git pull...")
check_call("git pull".split(), cwd=EXTRACTED_GT_DIR, env=env)
check_call(
f"git pull {S3PRL_HF_SAMPLE_HS} main".split(),
cwd=EXTRACTED_GT_DIR,
env=env,
)

try:
lock_file.unlink()
Expand Down Expand Up @@ -180,25 +185,78 @@ def _filter_options(options: list):
"""


def _test_specific_upstream(name: str):
    """Run the standard set of checks against a single upstream entry.

    The upstream is first compared with the extracted reference via
    ``_compare_with_extracted``, then exercised with
    ``_test_forward_backward`` three times: twice with both ``min_secs`` and
    ``max_secs`` pinned to ``EXTRA_SHORT_SEC`` (``n=1`` and ``n=2``), and once
    with lengths between ``EXTRA_SHORT_SEC`` and 1 second (``n=3``).

    Args:
        name: the upstream entry name forwarded to every helper.
    """
    _compare_with_extracted(name)

    # Fixed-length, extra-short inputs: same arguments, only ``n`` differs.
    for count in (1, 2):
        _test_forward_backward(
            name, min_secs=EXTRA_SHORT_SEC, max_secs=EXTRA_SHORT_SEC, n=count
        )

    # Variable-length inputs, from extra-short up to one second.
    _test_forward_backward(name, min_secs=EXTRA_SHORT_SEC, max_secs=1, n=3)


@pytest.mark.upstream
@pytest.mark.parametrize(
    "name",
    [
        "apc",
        "audio_albert",
        "fbank",
        "mel",
        "modified_cpc",
        "data2vec",
        "decoar_layers",
        "decoar2",
        "distilhubert",
        # "espnet_hubert_base_iter1",  # espnet will be tested separately due to complex dependency
        "hubert",
        "lighthubert_base",
        "mockingjay",
        "npc",
        "discretebert",
        "tera",
        "unispeech_sat_base",
        "vq_apc",
        "vq_wav2vec",
        "wav2vec",
        "wav2vec2",
        "wavlm",
    ],
)
def test_common_upstream(name):
    """Check one representative upstream entry per parametrized case.

    Each entry is listed exactly once so that no upstream is tested twice.
    The actual checks are delegated to ``_test_specific_upstream`` so that
    this test and ``test_specific_upstream`` run the same sequence.

    Args:
        name: the upstream entry name supplied by ``pytest.mark.parametrize``.
    """
    if "espnet" in name:
        try:
            import espnet  # noqa: F401  (imported only to probe availability)
        except ImportError:
            # Only a missing package is a reason to skip; anything else
            # (including KeyboardInterrupt) must propagate.
            logger.info("Skip ESPNet upstream test cases if espnet is not installed")
            return

    _prepare_sample_hidden_states()
    _test_specific_upstream(name)


@pytest.mark.upstream
def test_specific_upstream(upstream_names: str):
    """Test an explicit, comma-separated list of upstream entries.

    Every requested entry is tested even if an earlier one fails; failures
    are collected and reported together at the end.

    Args:
        upstream_names: comma-separated upstream entry names, or ``None``
            when nothing was requested.
            NOTE(review): presumably supplied by a fixture backed by the
            ``--upstream_names`` command line option -- confirm in conftest.

    Raises:
        AssertionError: if at least one of the requested upstreams failed.
    """
    _prepare_sample_hidden_states()
    if upstream_names is None:
        # Nothing was requested, so there is nothing to test.
        return

    # Ignore empty entries caused by stray commas or surrounding whitespace.
    requested = [item.strip() for item in upstream_names.split(",") if item.strip()]

    failures = []
    for name in requested:
        logger.info(f"Testing upstream: '{name}'")
        try:
            _test_specific_upstream(name)
        except Exception:
            # Keep going so that a single run reports every broken upstream.
            tb = traceback.format_exc()
            logger.error(f"{name}\n{tb}")
            failures.append((name, tb))

    if failures:
        for name, tb in failures:
            logger.error(f"Error in {name}:\n{tb}")
        failed_names = [name for name, _ in failures]
        logger.error(f"All failed models:\n{failed_names}")
        # Explicit raise instead of ``assert False`` so the failure carries a
        # message and is not stripped when Python runs with -O.
        raise AssertionError(f"Upstream tests failed for: {failed_names}")


@pytest.mark.upstream
Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ commands =
audio2.0.2: {envpython} -m pip install torchaudio==2.0.2
audio2.1.0: {envpython} -m pip install torchaudio==2.1.0

espnet: {envpython} -m pip install espnet>="202308"

# test import s3prl before installing dependencies for testing
{envpython} -c "import s3prl; from s3prl.nn import S3PRLUpstream;"

Expand Down

0 comments on commit 8d7e757

Please sign in to comment.