forked from pytorch/android-demo-app
torchaudio based wav2vec2 with no model input length limit (pytorch#141)
* initial commit
* Revert "initial commit" (this reverts commit 5a65775)
* main readme and helloworld/demo app readme updates
* updated script to create torchaudio based wav2vec2 model with no recording length limit; android code update
* README update
* README update
* updated script, build gradle and README for torch 1.9.0 and torchaudio 0.9.0
Showing 4 changed files with 96 additions and 67 deletions.
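The headline change, shown in the diff below, is the switch from torch.jit.trace, which records operations against one fixed-size example input (the 65,024-sample INPUT_SIZE hard-coded in the Android code), to torch.jit.script, which compiles the model from source and so places no limit on the recording length. A minimal sketch of that general trace/script difference, using a made-up toy module (the `Head` class and the 70,000-sample threshold are illustrative only, not from this repo):

```python
import torch

class Head(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent control flow: tracing records only the branch
        # taken for the example input; scripting compiles both branches.
        if x.shape[-1] > 70000:
            return x[..., :70000].sum()
        return x.sum()

head = Head()
example = torch.ones(1, 65024)           # fixed-size example, like the old export
traced = torch.jit.trace(head, example)  # bakes in the "short input" branch
scripted = torch.jit.script(head)        # length-agnostic

longer = torch.ones(1, 160000)
print(head(longer), scripted(longer))    # tensor(70000.) tensor(70000.)
print(traced(longer))                    # tensor(160000.): the baked-in branch ignores the new length
```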
The diff of the updated model-export script (one of the four changed files):

```diff
@@ -1,24 +1,64 @@
-import soundfile as sf
 import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+from torch import Tensor
+from torch.utils.mobile_optimizer import optimize_for_mobile
+import torchaudio
+from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
+from transformers import Wav2Vec2ForCTC
 
-tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
-model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-model.eval()
+
+# Wav2vec2 model emits sequences of probability (logits) distributions over the characters
+# The following class adds steps to decode the transcript (best path)
+class SpeechRecognizer(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.labels = [
+            "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
+            "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
+            "J", "Q", "Z"]
+
+    def forward(self, waveforms: Tensor) -> str:
+        """Given single-channel speech data, return the transcription.
+
+        Args:
+            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
-audio_input, _ = sf.read("scent_of_a_woman_future.wav")
-input_values = tokenizer(audio_input, return_tensors="pt").input_values
-print(input_values.shape) # input_values is 65024 long, matching INPUT_SIZE defined in the Android code
+
+        Returns:
+            str: The resulting transcript
+        """
+        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
+        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
+        prev = ''
+        hypothesis = ''
+        for i in best_path:
+            char = self.labels[i]
+            if char == prev:
+                continue
+            if char == '<s>':
+                prev = ''
+                continue
+            hypothesis += char
+            prev = char
+        return hypothesis.replace('|', ' ')
+
+
+# Load the Wav2Vec2 pretrained model from Hugging Face Hub
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+# Convert the model to torchaudio format, which supports TorchScript.
+model = import_huggingface_model(model)
+# Remove weight normalization, which is not supported by quantization.
+model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
+model = model.eval()
+# Attach the decoder
+model = SpeechRecognizer(model)
 
-logits = model(input_values).logits
-predicted_ids = torch.argmax(logits, dim=-1)
-transcription = tokenizer.batch_decode(predicted_ids)[0]
-print(transcription)
+# Apply quantization / script / optimize for mobile
+quantized_model = torch.quantization.quantize_dynamic(
+    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
+scripted_model = torch.jit.script(quantized_model)
+optimized_model = optimize_for_mobile(scripted_model)
 
-traced_model = torch.jit.trace(model, input_values, strict=False)
-model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
-traced_quantized_model = torch.jit.trace(model_dynamic_quantized, input_values, strict=False)
+# Sanity check
+waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
+print('Result:', optimized_model(waveform))
 
-optimized_traced_quantized_model = optimize_for_mobile(traced_quantized_model)
-optimized_traced_quantized_model.save("app/src/main/assets/wav2vec_traced_quantized.pt")
+optimized_model.save("wav2vec2.pt")
```
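With the exported file in hand, the model can be sanity-checked from plain Python before it is bundled into the app, on a recording of any duration. A small sketch assuming torch 1.9.0 / torchaudio 0.9.0 (per the commit message); the resampling step is an added assumption for inputs that are not already at the 16 kHz rate wav2vec2-base-960h was trained on:

```python
import torch
import torchaudio

# Load the TorchScript model saved by the script above.
model = torch.jit.load("wav2vec2.pt")

# Any mono recording works; there is no fixed INPUT_SIZE any more.
waveform, sample_rate = torchaudio.load("scent_of_a_woman_future.wav")
if sample_rate != 16000:
    # Assumption: resample to the 16 kHz rate the model expects.
    waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

# forward() returns the decoded transcript directly as a string.
print(model(waveform))
```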