Yxdu #136 (Closed): wants to merge 10 commits

68 changes: 68 additions & 0 deletions examples/st_covost2/README.md
@@ -0,0 +1,68 @@
# ST_covost2

## Download Model
We only train the Q-Former projector in this recipe; the Whisper encoder and the LLM stay frozen (a wiring sketch follows the download commands below).

| Encoder | Projector | LLM |
|---|---|---|
| [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | [q-former](https://huggingface.co/yxdu/cotst) | [Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B) |
```
git lfs clone https://huggingface.co/openai/whisper-large-v3
git lfs clone https://huggingface.co/yxdu/cotst
git lfs clone https://huggingface.co/Qwen/Qwen2-7B
```
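
The sketch below illustrates that setup with the Hugging Face `transformers` API only. It is not the recipe's actual model assembly (which lives in `examples/st_covost2/model/slam_model_st.py`), and the local paths assume the clones above.
```python
# Illustrative sketch only: load the encoder and LLM and freeze them; the recipe's
# real model assembly is in examples/st_covost2/model/slam_model_st.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperModel

encoder = WhisperModel.from_pretrained("./whisper-large-v3").encoder   # speech encoder
llm = AutoModelForCausalLM.from_pretrained("./Qwen2-7B", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("./Qwen2-7B")

# Only the Q-Former projector (checkpoint in ./cotst) receives gradients.
for module in (encoder, llm):
    for p in module.parameters():
        p.requires_grad = False
```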


## Data
You need to download the CoVoST 2 dataset:
```
https://github.com/facebookresearch/covost
```



## Data preparation
Prepare the data as a JSONL file in the following format.
A sample test set is provided in `test_st.jsonl`; a small generation sketch follows the examples below.
```
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|en|>", "gt": "\"She'll be all right.\"", "source": "covost_en"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|de|>", "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.", "source": "covost_ende"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|ja|>", "gt": "\"She'll be all right.\"<|ja|>彼女は大丈夫だろう。", "source": "covost_enja"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|zh|>", "gt": "\"She'll be all right.\"<|zh|>她会没事的。", "source": "covost_enzh"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|de|>", "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.", "source": "covost_enende"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|ja|>", "gt": "\"She'll be all right.\"<|ja|>彼女は大丈夫だろう。", "source": "covost_enenja"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|zh|>", "gt": "\"She'll be all right.\"<|zh|>她会没事的。", "source": "covost_enenzh"}
```
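
Each line is one JSON object with `audio` (path to the clip), `prompt` (the target-language tag, optionally preceded by the transcript), `gt` (the training target), and `source` (a task identifier). A minimal sketch for writing such a file with the standard library; the paths and output filename below are placeholders:
```python
import json

# Placeholder entry mirroring the format above; replace with your real data.
entries = [
    {
        "audio": "/path/to/clips/common_voice_en_699711.mp3",
        "prompt": "<|de|>",
        "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.",
        "source": "covost_ende",
    },
]

with open("train_st.jsonl", "w", encoding="utf-8") as f:
    for entry in entries:
        # ensure_ascii=False keeps Chinese/Japanese targets human-readable
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```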
## Train Stage
We designed a four-step training process in which each stage is initialized from the checkpoint produced by the previous stage. Run the scripts in order (a wrapper sketch follows the block below).
```
# Step 1: ASR pretraining to acquire speech recognition capability.
bash asr_pretrain.sh

# Step 2: multimodal machine translation training to improve the final performance.
bash mmt.sh

# Step 3: monolingual SRT training.
bash srt.sh

# Step 4: multilingual multitask training.
bash zsrt.sh
```
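
To run all four stages unattended, a small wrapper such as the one below works. This is a convenience sketch only; it assumes each script picks up the previous stage's checkpoint through its own configuration, as described above, and that it is run from this recipe's directory.
```python
import subprocess

# Run the four training stages in order; each script is assumed to load the
# checkpoint written by the previous stage via its own config.
STAGES = ["asr_pretrain.sh", "mmt.sh", "srt.sh", "zsrt.sh"]

for script in STAGES:
    print(f"=== running {script} ===")
    subprocess.run(["bash", script], check=True)  # abort the chain if a stage fails
```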


## Infer Stage
You can try our pre-trained model.

```
bash infer.sh
```
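
CoVoST 2 speech translation is usually scored with BLEU. The snippet below is a generic scoring sketch using `sacrebleu`; it assumes you have already extracted hypothesis and reference strings from the decode log written by `infer.sh` (whose exact format is not shown here).
```python
# Generic BLEU scoring sketch with sacrebleu (pip install sacrebleu).
# hypotheses/references must be parsed out of the decode log produced by infer.sh.
import sacrebleu

hypotheses = ["Sie wird schon in Ordnung sein."]      # model translations
references = [["Sie wird schon in Ordnung sein."]]    # one reference stream

bleu = sacrebleu.corpus_bleu(hypotheses, references)
print(f"BLEU = {bleu.score:.2f}")
```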

## Citation
You can refer to the paper for more results.
```
@article{ma2024embarrassingly,
title={An Embarrassingly Simple Approach for LLM with Strong ASR Capacity},
author={Ma, Ziyang and Yang, Guanrou and Yang, Yifan and Gao, Zhifu and Wang, Jiaming and Du, Zhihao and Yu, Fan and Chen, Qian and Zheng, Siqi and Zhang, Shiliang and others},
journal={arXiv preprint arXiv:2402.08846},
year={2024}
}
```
135 changes: 135 additions & 0 deletions examples/st_covost2/asr_config.py
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import Optional, List
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class ModelConfig:
    file: str = "examples/st_covost2/model/slam_model_st.py"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 3584
    encoder_path_hf: Optional[str] = None
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether model is only pretrained or finetuned, used for models such as hubert"
    })
    ckpt_path: Optional[str] = None
    query_len: Optional[str] = None


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # alternatives: None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.01
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used only with FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # used only with FSDP
    save_optimizer: bool = False  # used only with FSDP
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers (Flash Attention and xFormers memory-efficient kernels)
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the LLM when finetuning; should be true when using PEFT finetuning"
    })
    freeze_encoder: bool = False


@dataclass
class DataConfig:
    dataset: str = "st_dataset"
    file: str = "examples/st_covost2/dataset/st_dataset.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    input_type: str = field(default="raw", metadata={
        "help": "use 'raw' when the input is a wav file, 'mel' for whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    bf16: bool = True
    source: Optional[str] = None


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # HYBRID_SHARD: full shard within a node, DDP across nodes; SHARD_GRAD_OP: shard only gradients and optimizer states; NO_SHARD: similar to DDP
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size on load
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "test_wandb"
    wandb_entity_name: str = "SLAM"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "./test.log"
    log_interval: int = 50
    decode_log: str = "./test.log"
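
These dataclasses hold the recipe's hyperparameters, so they can also be built and inspected directly in Python. The sketch below is illustrative only: the bare import path and the `whisper` / `q-former` values are assumptions based on this file's location and the README, not something the training scripts require you to write.
```python
# Illustrative only: construct the configs with recipe-style overrides.
# The import path assumes asr_config.py is on PYTHONPATH; "whisper" and
# "q-former" are assumed values taken from the README, not from this file.
from asr_config import DataConfig, ModelConfig, TrainConfig

model_config = ModelConfig(
    llm_name="Qwen2-7B",
    llm_path="./Qwen2-7B",
    llm_dim=3584,
    encoder_name="whisper",
    encoder_path="./whisper-large-v3",
    encoder_projector="q-former",
    ckpt_path="./cotst",  # pre-trained projector checkpoint
)
train_config = TrainConfig(freeze_encoder=True, freeze_llm=True, use_peft=False)
data_config = DataConfig(
    train_data_path="train_st.jsonl",
    val_data_path="test_st.jsonl",
    input_type="mel",
    mel_size=128,  # whisper-large-v3 uses 128 mel bins
)

print(model_config.encoder_dim, train_config.lr, data_config.mel_size)
```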
19 changes: 19 additions & 0 deletions examples/st_covost2/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
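
This DeepSpeed config enables ZeRO stage 3 with CPU optimizer offload and fp16, presumably consumed when `enable_deepspeed` is set in `TrainConfig`. As a rough illustration of how such a JSON file is typically used (a generic DeepSpeed sketch with a placeholder model, not this recipe's actual launch code):
```python
# Generic sketch: hand the JSON config to deepspeed.initialize.
# The toy model is a placeholder; the recipe's own training loop does this internally.
import deepspeed
import torch.nn as nn

model = nn.Linear(1280, 3584)  # placeholder module
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/st_covost2/conf/ds_config.json",
)
```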
4 changes: 4 additions & 0 deletions examples/st_covost2/conf/prompt.yaml
@@ -0,0 +1,4 @@
dataset_config:
  # The prompt lives here because the hydra override in the shell script only supports a small subset of characters.
  # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
  prompt: "<en>"