Use Conv1d for TDNN (huggingface#25728)
* use conv for tdnn

* run make fixup

* update TDNN

* add PEFT LoRA check

* propagate tdnn warnings to others

* add missing imports

* update TDNN in wav2vec2_bert

* add missing imports
gau-nernst authored Jan 30, 2024
1 parent 866253f commit 5c8d941
Showing 6 changed files with 104 additions and 56 deletions.
31 changes: 21 additions & 10 deletions src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -35,7 +35,13 @@
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_peft_available,
+    logging,
+)
 from .configuration_data2vec_audio import Data2VecAudioConfig


@@ -1342,16 +1348,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
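
The hunk above replaces the unfold-then-Linear matmul with a direct nn.functional.conv1d over the time axis, reusing the existing nn.Linear parameters so checkpoints stay compatible. The standalone sketch below is not part of the commit; it only checks, on random tensors with made-up sizes, that viewing the Linear weight as a Conv1d kernel reproduces the old unfold-based result.

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, time_steps, in_dim, out_dim, kernel_size, dilation = 2, 50, 16, 32, 5, 2
linear = nn.Linear(in_dim * kernel_size, out_dim)
x = torch.randn(batch, time_steps, in_dim)

# Old path: unfold the time axis into overlapping frames, then apply the Linear.
frames = F.unfold(
    x.unsqueeze(1),                # (B, 1, T, C_in), treated as a 1-channel "image"
    (kernel_size, in_dim),         # each patch spans K time steps and all C_in features
    stride=(1, in_dim),
    dilation=(dilation, 1),
)                                  # (B, K * C_in, L) with L = T - dilation * (K - 1)
old = linear(frames.transpose(1, 2))   # (B, L, C_out)

# New path: reinterpret the Linear weight as a Conv1d kernel and convolve directly.
weight = linear.weight.view(out_dim, kernel_size, in_dim).transpose(1, 2)  # (C_out, C_in, K)
new = F.conv1d(x.transpose(1, 2), weight, linear.bias, dilation=dilation).transpose(1, 2)

print(torch.allclose(old, new, atol=1e-5))  # expected to print True
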
24 changes: 15 additions & 9 deletions src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -41,6 +41,7 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_peft_available,
     logging,
     replace_return_docstrings,
 )
@@ -1796,16 +1797,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
24 changes: 15 additions & 9 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -44,6 +44,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     cached_file,
+    is_peft_available,
     is_safetensors_available,
     logging,
     replace_return_docstrings,
@@ -2287,16 +2288,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
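
The new warning fires when PEFT has wrapped the TDNN `kernel` Linear in a LoRA adapter, because the conv1d fast path reads `self.kernel.weight` directly and bypasses the adapter. The sketch below is not from the commit; it shows one way a user might keep TDNN layers out of LoRA, assuming `peft` is installed. The checkpoint name and target-module choices are illustrative assumptions, not prescribed by this PR.

from peft import LoraConfig, get_peft_model
from transformers import Wav2Vec2ForXVector

# Illustrative speaker-verification checkpoint; substitute your own.
model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv")

# Target only the attention projections. The TDNN layers' `kernel` Linear modules are
# left out of target_modules, so the conv1d path stays exact and the warning above
# never fires.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
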
25 changes: 16 additions & 9 deletions src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2-BERT model."""

 import math
+import warnings
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -39,6 +40,7 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_peft_available,
     logging,
 )
 from .configuration_wav2vec2_bert import Wav2Vec2BertConfig
@@ -1516,16 +1518,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
25 changes: 16 additions & 9 deletions src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2-Conformer model."""

 import math
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union

@@ -40,6 +41,7 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_peft_available,
     logging,
     replace_return_docstrings,
 )
@@ -1948,16 +1950,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
31 changes: 21 additions & 10 deletions src/transformers/models/wavlm/modeling_wavlm.py
@@ -36,7 +36,13 @@
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_peft_available,
+    logging,
+)
 from .configuration_wavlm import WavLMConfig


@@ -1674,16 +1680,21 @@ def __init__(self, config, layer_id=0):
         self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
         self.activation = nn.ReLU()

-    def forward(self, hidden_states):
-        hidden_states = hidden_states.unsqueeze(1)
-        hidden_states = nn.functional.unfold(
-            hidden_states,
-            (self.kernel_size, self.in_conv_dim),
-            stride=(1, self.in_conv_dim),
-            dilation=(self.dilation, 1),
-        )
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
         hidden_states = hidden_states.transpose(1, 2)
-        hidden_states = self.kernel(hidden_states)

         hidden_states = self.activation(hidden_states)
         return hidden_states
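
All six files receive the same rewrite: instead of framing the sequence with unfold and feeding it to a Linear, the forward now convolves directly with conv1d. The rough micro-benchmark below is not part of the commit; the sizes are arbitrary and timings depend on hardware and backend, but it illustrates the kind of comparison behind the "speed up" comment in the diff.

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, time_steps, in_dim, out_dim, kernel_size, dilation = 8, 1000, 512, 512, 5, 1
linear = nn.Linear(in_dim * kernel_size, out_dim)
x = torch.randn(batch, time_steps, in_dim)

def unfold_path(x):
    # old TDNN path: frame the sequence with unfold, then one big matmul
    frames = F.unfold(
        x.unsqueeze(1), (kernel_size, in_dim), stride=(1, in_dim), dilation=(dilation, 1)
    )
    return linear(frames.transpose(1, 2))

def conv_path(x):
    # new TDNN path: reinterpret the Linear weight as a Conv1d kernel
    weight = linear.weight.view(out_dim, kernel_size, in_dim).transpose(1, 2)
    out = F.conv1d(x.transpose(1, 2), weight, linear.bias, dilation=dilation)
    return out.transpose(1, 2)

with torch.no_grad():
    for fn in (unfold_path, conv_path):
        fn(x)  # warm-up
        start = time.perf_counter()
        for _ in range(10):
            fn(x)
        print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")
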
