torchaudio based wav2vec2 with no model input length limit (pytorch#141)
* initial commit

* Revert "initial commit"

This reverts commit 5a65775.

* main readme and helloworld/demo app readme updates

* updated script to create torchaudio based wav2vec2 model with no recording length limit; android code update

* README update

* README update

* updated script, build gradle and README for torch 1.9.0 and torchaudio 0.9.0
jeffxtang authored Jun 16, 2021
1 parent 5f3057a commit 367d2d9
Showing 4 changed files with 96 additions and 67 deletions.
38 changes: 28 additions & 10 deletions SpeechRecognition/README.md
@@ -4,36 +4,54 @@

Facebook AI's [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) is one of the leading models in speech recognition. It is also available in the [Huggingface Transformers](https://github.com/huggingface/transformers) library, which is also used in another PyTorch Android demo app for [Question Answering](https://github.com/pytorch/android-demo-app/tree/master/QuestionAnswering).

In this demo app, we'll show how to quantize and convert the wav2vec2 model to TorchScript and how to use the scripted model on an Android demo app to perform speech recognition.
In this demo app, we'll show how to quantize, script, and optimize the wav2vec2 model, powered by the newly released torchaudio 0.9.0, and how to use the converted model in an Android demo app to perform speech recognition.

## Prerequisites

* PyTorch 1.8 (Optional)
* PyTorch 1.9.0 and torchaudio 0.9.0 (Optional)
* Python 3.8 (Optional)
* Android Pytorch library 1.8
* Android PyTorch library 1.9.0
* Android Studio 4.0.1 or later

## Quick Start

### 1. Prepare the Model
### 1. Get the Repo

Simply run the commands below:

First, run the following commands on a Terminal:
```
git clone https://github.com/pytorch/android-demo-app
cd android-demo-app/SpeechRecognition
```

If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the quantized scripted wav2vec2 model compressed in a zip file [here](https://drive.google.com/file/d/1wW6qs-OR76usbBXvEyqUH_mRqa0ShMfT/view?usp=sharing), then unzip it to the assets folder, and continue to Step 2.
If you don't have PyTorch 1.9.0 and torchaudio 0.9.0 installed or want to have a quick try of the demo app, you can download the quantized scripted wav2vec2 model file [here](https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view?usp=sharing), then drag and drop it into the `app/src/main/assets` folder inside `android-demo-app/SpeechRecognition`, and continue to Step 3.

### 2. Prepare the Model

To install PyTorch 1.9.0, torchaudio 0.9.0 and the Hugging Face transformers, you can do something like this:

```
conda create -n wav2vec2 python=3.8.5
conda activate wav2vec2
pip install torch torchaudio
pip install transformers
```
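
Note that `pip install torch torchaudio` pulls whatever versions are currently published, which may be newer than 1.9.0 and 0.9.0. As a quick, optional check (a minimal sketch, assuming the `wav2vec2` conda environment created above), you can confirm the installed versions before creating the model, since they should match the `org.pytorch:pytorch_android:1.9.0` dependency in `app/build.gradle`:
```
import torch
import torchaudio
import transformers

# Expect 1.9.0 for torch and 0.9.0 for torchaudio to match the Android library version.
print(torch.__version__, torchaudio.__version__, transformers.__version__)
```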

Be aware that the downloadable model file was created with PyTorch 1.8.0, matching the PyTorch Android library 1.8.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.8.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
Now with PyTorch 1.9.0 and torchaudio 0.9.0 installed, run the following commands on a Terminal:

With PyTorch 1.8 installed, first install the Huggingface `transformers` by running `pip install transformers` (the version that has been tested is 4.3.2), then run `python create_wav2vec2.py`, which creates `wav2vec_traced_quantized.pt` in the `app/src/main/assets` folder. [Dynamic quantization](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html) is used to quantize the model to reduce its size.
```
python create_wav2vec2.py
```
Note that the sample scent_of_a_woman_future.wav file used to trace the model is about 6 seconds long, so 6 seconds is the limit of the recorded audio for speech recognition in the demo app. If your speech is shorter than 6 seconds, padding is applied in the code to make the model work correctly.

This will create the model file `wav2vec2.pt`. Copy it to the Android app:
```
mkdir -p app/src/main/assets
cp wav2vec2.pt app/src/main/assets
```
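
Optionally, before copying the model over, you can sanity-check the exported TorchScript file on your desktop, along the lines of the check at the end of `create_wav2vec2.py`. This is a minimal sketch, assuming the sample `scent_of_a_woman_future.wav` used by that script (any 16 kHz mono clip should work):
```
import torch
import torchaudio

model = torch.jit.load("wav2vec2.pt")
waveform, _ = torchaudio.load("scent_of_a_woman_future.wav")  # shape [1, num_frames] for a mono clip
print(model(waveform))  # the scripted model returns the decoded transcript as a string
```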

### 3. Build and run with Android Studio

Start Android Studio, open the project located in `android-demo-app/SpeechRecognition`, build and run the app on an Android device. After the app runs, tap the Start button and start saying something; after 6 seconds, the model will run inference to recognize your speech. Only basic decoding of the recognition result, from an array of floating-point logits to a list of tokens, is provided in this demo app, but it is easy to see, without further post-processing, whether the model can recognize your utterances. Some example recognition results are:
Start Android Studio, open the project located in `android-demo-app/SpeechRecognition`, build and run the app on an Android device. After the app runs, tap the Start button and start saying something; after 12 seconds (you can change `private final static int AUDIO_LEN_IN_SECOND = 12;` in `MainActivity.java` for a shorter or longer recording length), the model will run inference to recognize your speech. Some example recognition results are:

![](screenshot1.png)
![](screenshot2.png)
3 changes: 1 addition & 2 deletions SpeechRecognition/app/build.gradle
@@ -34,13 +34,12 @@ android {
}

dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'

implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android:1.9.0'

}
MainActivity.java
@@ -35,13 +35,10 @@ public class MainActivity extends AppCompatActivity implements Runnable {
private TextView mTextView;
private Button mButton;

private final static String[] tokens = {"<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S", "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X", "J", "Q", "Z"};
private final static int INPUT_SIZE = 65024;
private final static int AUDIO_LEN_LIMIT = 6;

private final static int REQUEST_RECORD_AUDIO = 13;
private final static int AUDIO_LEN_IN_SECOND = 12;
private final static int SAMPLE_RATE = 16000;
private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_LIMIT;
private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;

private final static String LOG_TAG = MainActivity.class.getSimpleName();

@@ -55,7 +52,7 @@ public void run() {

MainActivity.this.runOnUiThread(
() -> {
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_LIMIT - mStart));
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
mStart += 1;
});
}
@@ -90,7 +87,7 @@ protected void onCreate(Bundle savedInstanceState) {

mButton.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_LIMIT));
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
mButton.setEnabled(false);

Thread thread = new Thread(MainActivity.this);
@@ -197,46 +194,21 @@ public void run() {
private String recognize(float[] floatInputBuffer) {
if (mModuleEncoder == null) {
final String moduleFileAbsoluteFilePath = new File(
assetFilePath(this, "wav2vec_traced_quantized.pt")).getAbsolutePath();
assetFilePath(this, "wav2vec2.pt")).getAbsolutePath();
mModuleEncoder = Module.load(moduleFileAbsoluteFilePath);
}

double wav2vecinput[] = new double[INPUT_SIZE];
for (int n = 0; n < INPUT_SIZE; n++)
double wav2vecinput[] = new double[RECORDING_LENGTH];
for (int n = 0; n < RECORDING_LENGTH; n++)
wav2vecinput[n] = floatInputBuffer[n];

FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(INPUT_SIZE);
FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
for (double val : wav2vecinput)
inTensorBuffer.put((float)val);

Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, INPUT_SIZE});

final Map<String, IValue> map = mModuleEncoder.forward(IValue.from(inTensor)).toDictStringKey();
final Tensor logitsTensor = map.get("logits").toTensor();
final float[] values = logitsTensor.getDataAsFloatArray();
Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
final String result = mModuleEncoder.forward(IValue.from(inTensor)).toStr();

String result = "";
float row[] = new float[tokens.length];
for (int i = 0; i < values.length; i++) {
row[i % tokens.length] = values[i];
if (i > 0 && i % tokens.length == 0) {
int tid = argmax(row);
if (tid > 4) result = String.format("%s%s", result, tokens[tid]);
else if (tid == 4) result = String.format("%s ", result);
}
}
return result;
}

private int argmax(float[] array) {
int maxIdx = 0;
double maxVal = -Double.MAX_VALUE;
for (int j=0; j<array.length; j++) {
if (array[j] > maxVal) {
maxVal = array[j];
maxIdx = j;
}
}
return maxIdx;
}
}
74 changes: 57 additions & 17 deletions SpeechRecognition/create_wav2vec2.py
@@ -1,24 +1,64 @@
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from torch import Tensor
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from transformers import Wav2Vec2ForCTC

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()
# The Wav2Vec2 model emits sequences of probability distributions (logits) over the characters
# The following class adds steps to decode the transcript (best path)
class SpeechRecognizer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.labels = [
"<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
"R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
"J", "Q", "Z"]

def forward(self, waveforms: Tensor) -> str:
"""Given a single channel speech data, return transcription.
Args:
waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
audio_input, _ = sf.read("scent_of_a_woman_future.wav")
input_values = tokenizer(audio_input, return_tensors="pt").input_values
print(input_values.shape) # input_values is of 65024 long, matched INPUT_SIZE defined in Android code
Returns:
str: The resulting transcript
"""
logits, _ = self.model(waveforms) # [batch, num_seq, num_label]
best_path = torch.argmax(logits[0], dim=-1) # [num_seq,]
prev = ''
hypothesis = ''
for i in best_path:
char = self.labels[i]
if char == prev:
continue
if char == '<s>':
prev = ''
continue
hypothesis += char
prev = char
return hypothesis.replace('|', ' ')


# Load Wav2Vec2 pretrained model from Hugging Face Hub
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Convert the model to torchaudio format, which supports TorchScript.
model = import_huggingface_model(model)
# Remove weight normalization which is not supported by quantization.
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
model = model.eval()
# Attach decoder
model = SpeechRecognizer(model)

logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]
print(transcription)
# Apply quantization / scripting / optimization for mobile
quantized_model = torch.quantization.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_model = torch.jit.script(quantized_model)
optimized_model = optimize_for_mobile(scripted_model)

traced_model = torch.jit.trace(model, input_values, strict=False)
model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
traced_quantized_model = torch.jit.trace(model_dynamic_quantized, input_values, strict=False)
# Sanity check
waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
print('Result:', optimized_model(waveform))

optimized_traced_quantized_model = optimize_for_mobile(traced_quantized_model)
optimized_traced_quantized_model.save("app/src/main/assets/wav2vec_traced_quantized.pt")
optimized_model.save("wav2vec2.pt")
