forked from X-LANCE/SLAM-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request X-LANCE#67 from ddlBoJack/main
sync
- Loading branch information
Showing
33 changed files
with
6,125 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#!/bin/bash | ||
#export PYTHONPATH=/root/whisper:$PYTHONPATH | ||
export PYTHONPATH=/root/fairseq:$PYTHONPATH | ||
export CUDA_VISIBLE_DEVICES=0 | ||
export TOKENIZERS_PARALLELISM=false | ||
# export CUDA_LAUNCH_BLOCKING=1 | ||
|
||
run_dir=/root/SLAM-LLM | ||
cd $run_dir | ||
code_dir=examples/asr_librispeech | ||
|
||
speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/hubert_ckpt/hubert_xtralarge_ll60k_finetune_ls960.pt | ||
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 | ||
|
||
output_dir=/nfs/yangguanrou.ygr/experiments_hubert/vicuna-7b-v1.5-hubert_xtralarge_ll60k_finetune_ls960 | ||
ckpt_path=$output_dir/asr_epoch_1_step_1000 | ||
split=librispeech_test_clean | ||
val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl | ||
decode_log=$ckpt_path/decode_${split}_beam4 | ||
|
||
# -m debugpy --listen 5678 --wait-for-client | ||
python $code_dir/inference_asr_batch.py \ | ||
--config-path "conf" \ | ||
--config-name "prompt.yaml" \ | ||
hydra.run.dir=$ckpt_path \ | ||
++model_config.llm_name="vicuna-7b-v1.5" \ | ||
++model_config.llm_path=$llm_path \ | ||
++model_config.llm_dim=4096 \ | ||
++model_config.encoder_name=hubert \ | ||
++model_config.normalize=true \ | ||
++dataset_config.normalize=true \ | ||
++model_config.encoder_projector_ds_rate=5 \ | ||
++model_config.encoder_path=$speech_encoder_path \ | ||
++model_config.encoder_dim=1280 \ | ||
++model_config.encoder_type=finetune \ | ||
++model_config.encoder_projector=linear \ | ||
++dataset_config.dataset=speech_dataset \ | ||
++dataset_config.val_data_path=$val_data_path \ | ||
++dataset_config.input_type=raw \ | ||
++dataset_config.inference_mode=true \ | ||
++dataset_config.prompt="Transcribe speech to text. " \ | ||
++train_config.model_name=asr \ | ||
++train_config.freeze_encoder=true \ | ||
++train_config.freeze_llm=true \ | ||
++train_config.batching_strategy=custom \ | ||
++train_config.num_epochs=1 \ | ||
++train_config.val_batch_size=1 \ | ||
++train_config.num_workers_dataloader=0 \ | ||
++train_config.output_dir=$output_dir \ | ||
++decode_log=$decode_log \ | ||
++ckpt_path=$ckpt_path/model.pt \ | ||
# ++peft_ckpt=$ckpt_path \ | ||
# ++train_config.use_peft=true \ | ||
# ++train_config.peft_config.r=32 \ | ||
# ++dataset_config.normalize=true \ | ||
# ++model_config.encoder_projector=q-former \ | ||
# ++dataset_config.fix_length_audio=64 \ | ||
|
||
python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc | ||
python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc | ||
python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer |
76 changes: 76 additions & 0 deletions
76
examples/asr_librispeech/scripts/finetune_hubert_xtralarge_linear_vicuna_7b.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#!/bin/bash | ||
# export PYTHONPATH=/root/whisper:$PYTHONPATH | ||
export PYTHONPATH=/root/fairseq:$PYTHONPATH | ||
export CUDA_VISIBLE_DEVICES=2,3 | ||
export TOKENIZERS_PARALLELISM=false | ||
# export CUDA_LAUNCH_BLOCKING=1 | ||
export OMP_NUM_THREADS=1 | ||
|
||
# debug setting for multiple gpus | ||
# export NCCL_DEBUG=INFO | ||
# export NCCL_DEBUG_SUBSYS=ALL | ||
# export TORCH_DISTRIBUTED_DEBUG=INFO | ||
|
||
run_dir=/root/SLAM-LLM | ||
cd $run_dir | ||
code_dir=examples/asr_librispeech | ||
|
||
speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/hubert_ckpt/hubert_xtralarge_ll60k_finetune_ls960.pt | ||
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 | ||
train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl | ||
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl | ||
|
||
output_dir=/root/tmp/vicuna-7b-v1.5-librispeech-linear-steplrwarmupkeep1e-4-hubert-xtralarge-$(date +"%Y%m%d") | ||
|
||
hydra_args=" | ||
hydra.run.dir=$output_dir \ | ||
++model_config.llm_name=vicuna-7b-v1.5 \ | ||
++model_config.llm_path=$llm_path \ | ||
++model_config.llm_dim=4096 \ | ||
++model_config.encoder_name=hubert \ | ||
++model_config.normalize=true \ | ||
++dataset_config.normalize=true \ | ||
++model_config.encoder_projector_ds_rate=5 \ | ||
++model_config.encoder_path=$speech_encoder_path \ | ||
++model_config.encoder_dim=1280 \ | ||
++model_config.encoder_type=finetune \ | ||
++model_config.encoder_projector=linear \ | ||
++dataset_config.dataset=speech_dataset \ | ||
++dataset_config.train_data_path=$train_data_path \ | ||
++dataset_config.val_data_path=$val_data_path \ | ||
++dataset_config.input_type=raw \ | ||
++train_config.model_name=asr \ | ||
++train_config.num_epochs=3 \ | ||
++train_config.freeze_encoder=true \ | ||
++train_config.freeze_llm=true \ | ||
++train_config.batching_strategy=custom \ | ||
++train_config.warmup_steps=1000 \ | ||
++train_config.total_steps=100000 \ | ||
++train_config.lr=1e-4 \ | ||
++train_config.validation_interval=2000 \ | ||
++train_config.batch_size_training=6 \ | ||
++train_config.val_batch_size=6 \ | ||
++train_config.num_workers_dataloader=0 \ | ||
++train_config.output_dir=$output_dir \ | ||
++metric=acc \ | ||
" | ||
|
||
# -m debugpy --listen 5678 --wait-for-client | ||
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then | ||
python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_asr.py \ | ||
--config-path "conf" \ | ||
--config-name "prompt.yaml" \ | ||
$hydra_args | ||
else | ||
torchrun \ | ||
--nnodes 1 \ | ||
--nproc_per_node 2 \ | ||
--master_port=29503 \ | ||
$code_dir/finetune_asr.py \ | ||
--config-path "conf" \ | ||
--config-name "prompt.yaml" \ | ||
++train_config.enable_fsdp=false \ | ||
++train_config.enable_ddp=true \ | ||
++train_config.use_fp16=true \ | ||
$hydra_args | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# VSR_LRS3 | ||
|
||
## Performance and checkpoints | ||
We only train the linear projector in this recipe. | ||
Encoder | Projector | LLM | test | ||
|---|---|---|---| | ||
[AV-HuBERT Large + Self-Training](https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/vsr/self_large_vox_433h.pt) | [Linear](https://drive.google.com/file/d/1DNfJgyeLx9xet4DT5xZXyx8ZOcNoawL8/view?usp=drive_link)(~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 29.47 | ||
|
||
|
||
## Data preparation | ||
Follow the steps in [preparation](https://github.com/facebookresearch/av_hubert/tree/main/avhubert/preparation) of av_hubert to pre-process LRS3 dataset | ||
|
||
## Environment | ||
Use the specific fairseq version of [av_hubert](https://github.com/facebookresearch/av_hubert), which is compatible with hydra-core versions below 1.0.7 and omegaconf versions below 2.0.6. | ||
|
||
|
||
## Decode with checkpoints | ||
``` | ||
bash decode_avhubert_vo_vicuna_7b.sh | ||
``` | ||
Modify the path including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path` and `decode_log` in the script when you run the shell script. | ||
|
||
## Train a new model | ||
|
||
### Use the visual part of AV-HuBERT Large as the encoder | ||
``` | ||
bash finetune_avhubert_vo_vicuna_7b.sh | ||
``` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"train_micro_batch_size_per_gpu": 4, | ||
"gradient_accumulation_steps": 1, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 1e-4 | ||
} | ||
}, | ||
"fp16": { | ||
"enabled": true | ||
}, | ||
"zero_optimization": { | ||
"stage": 3, | ||
"offload_optimizer": { | ||
"device": "cpu" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
dataset_config: | ||
# we put prompt here, because the hydra override in shell script only support a small subset of chars | ||
prompt: "Transcribe the silent speech in this video to text by lip-reading the speaker's clear and visible lip movements." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from slam_llm.pipeline.finetune import main as train | ||
|
||
import hydra | ||
import logging | ||
from dataclasses import dataclass, field | ||
from omegaconf import DictConfig, ListConfig, OmegaConf | ||
from vsr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig | ||
|
||
@dataclass | ||
class RunConfig: | ||
dataset_config: DataConfig = field(default_factory=DataConfig) | ||
model_config: ModelConfig = field(default_factory=ModelConfig) | ||
train_config: TrainConfig = field(default_factory=TrainConfig) | ||
log_config: LogConfig = field(default_factory=LogConfig) | ||
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) | ||
debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) | ||
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) | ||
|
||
@hydra.main(config_name=None) | ||
def main_hydra(cfg: DictConfig): | ||
run_config = RunConfig() | ||
cfg = OmegaConf.merge(run_config, cfg) | ||
def to_plain_list(cfg_item): | ||
if isinstance(cfg_item, ListConfig): | ||
return OmegaConf.to_container(cfg_item, resolve=True) | ||
elif isinstance(cfg_item, DictConfig): | ||
return {k: to_plain_list(v) for k, v in cfg_item.items()} | ||
else: | ||
return cfg_item | ||
|
||
# kwargs = to_plain_list(cfg) | ||
kwargs = cfg | ||
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) | ||
|
||
logging.basicConfig(level=log_level) | ||
|
||
if kwargs.get("debug", False): | ||
import pdb; | ||
pdb.set_trace() | ||
|
||
train(kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
main_hydra() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from slam_llm.pipeline.inference_batch import main as inference | ||
|
||
import hydra | ||
import logging | ||
from dataclasses import dataclass, field | ||
from omegaconf import DictConfig, ListConfig, OmegaConf | ||
from typing import Optional | ||
from vsr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig | ||
|
||
|
||
@dataclass | ||
class RunConfig: | ||
dataset_config: DataConfig = field(default_factory=DataConfig) | ||
model_config: ModelConfig = field(default_factory=ModelConfig) | ||
train_config: TrainConfig = field(default_factory=TrainConfig) | ||
log_config: LogConfig = field(default_factory=LogConfig) | ||
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) | ||
debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) | ||
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) | ||
decode_log: str = field( | ||
default="output/decode_log", | ||
metadata={"help": "The prefix for the decode output"}, | ||
) | ||
ckpt_path: str = field( | ||
default="output/model.pt", metadata={"help": "The path to projector checkpoint"} | ||
) | ||
peft_ckpt: Optional[str] = field( | ||
default=None, | ||
metadata={ | ||
"help": "The path to peft checkpoint, should be a directory including adapter_config.json" | ||
}, | ||
) | ||
|
||
|
||
@hydra.main(config_name=None) | ||
def main_hydra(cfg: DictConfig): | ||
run_config = RunConfig() | ||
cfg = OmegaConf.merge(run_config, cfg) | ||
# kwargs = to_plain_list(cfg) | ||
log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) | ||
|
||
logging.basicConfig(level=log_level) | ||
|
||
if cfg.get("debug", False): | ||
import pdb | ||
|
||
pdb.set_trace() | ||
|
||
inference(cfg) | ||
|
||
|
||
if __name__ == "__main__": | ||
main_hydra() |
Oops, something went wrong.