diff --git a/SpeechRecognition/README.md b/SpeechRecognition/README.md
index ac6aad7d..101ce7d8 100644
--- a/SpeechRecognition/README.md
+++ b/SpeechRecognition/README.md
@@ -4,36 +4,54 @@
Facebook AI's [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/master/examples/wav2vec) is one of the leading models in speech recognition. It is also available in the [Huggingface Transformers](https://github.com/huggingface/transformers) library, which is also used in another PyTorch Android demo app for [Question Answering](https://github.com/pytorch/android-demo-app/tree/master/QuestionAnswering).
-In this demo app, we'll show how to quantize and convert the wav2vec2 model to TorchScript and how to use the scripted model on an Android demo app to perform speech recognition.
+In this demo app, we'll show how to quantize, script, and optimize the wav2vec2 model, powered by the newly released torchaudio 0.9.0, and how to use the converted model in the Android demo app to perform speech recognition.
## Prerequisites
-* PyTorch 1.8 (Optional)
+* PyTorch 1.9.0 and torchaudio 0.9.0 (Optional)
* Python 3.8 (Optional)
-* Android Pytorch library 1.8
+* Android PyTorch library 1.9.0
* Android Studio 4.0.1 or later
## Quick Start
-### 1. Prepare the Model
+### 1. Get the Repo
+
+Simply run the commands below:
-First, run the following commands on a Terminal:
```
git clone https://github.com/pytorch/android-demo-app
cd android-demo-app/SpeechRecognition
```
-If you don't have PyTorch installed or want to have a quick try of the demo app, you can download the quantized scripted wav2vec2 model compressed in a zip file [here](https://drive.google.com/file/d/1wW6qs-OR76usbBXvEyqUH_mRqa0ShMfT/view?usp=sharing), then unzip it to the assets folder, and continue to Step 2.
+If you don't have PyTorch 1.9.0 and torchaudio 0.9.0 installed, or just want to quickly try out the demo app, you can download the quantized, scripted wav2vec2 model file [here](https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view?usp=sharing), then drag and drop it into the `app/src/main/assets` folder inside `android-demo-app/SpeechRecognition`, and continue to Step 3.
+
+### 2. Prepare the Model
+
+To install PyTorch 1.9.0, torchaudio 0.9.0, and the Hugging Face `transformers` library, you can do something like this:
+
+```
+conda create -n wav2vec2 python=3.8.5
+conda activate wav2vec2
+pip install torch==1.9.0 torchaudio==0.9.0
+pip install transformers
+```
-Be aware that the downloadable model file was created with PyTorch 1.8.0, matching the PyTorch Android library 1.8.0 specified in the project's `build.gradle` file as `implementation 'org.pytorch:pytorch_android:1.8.0'`. If you use a different version of PyTorch to create your model by following the instructions below, make sure you specify the same PyTorch Android library version in the `build.gradle` file to avoid possible errors caused by the version mismatch. Furthermore, if you want to use the latest PyTorch master code to create the model, follow the steps at [Building PyTorch Android from Source](https://pytorch.org/mobile/android/#building-pytorch-android-from-source) and [Using the PyTorch Android Libraries Built](https://pytorch.org/mobile/android/#using-the-pytorch-android-libraries-built-from-source-or-nightly) on how to use the model in Android.
+Now with PyTorch 1.9.0 and torchaudio 0.9.0 installed, run the following command in a Terminal:
-With PyTorch 1.8 installed, first install the Huggingface `transformers` by running `pip install transformers` (the version that has been tested is 4.3.2), then run `python create_wav2vec2.py`, which creates `wav2vec_traced_quantized.pt` in the `app/src/main/assets` folder. [Dynamic quantization](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html) is used to quantize the model to reduce its size.
+```
+python create_wav2vec2.py
+```
+This will create the model file `wav2vec2.pt`. Copy it to the Android app:
-Note that the sample scent_of_a_woman_future.wav file used to trace the model is about 6 second long, so 6 second is the limit of the recorded audio for speech recognition in the demo app. If your speech is less than 6 seconds, padding is applied in the code to make the model work correctly.
+```
+mkdir -p app/src/main/assets
+cp wav2vec2.pt app/src/main/assets
+```
-### 2. Build and run with Android Studio
+### 3. Build and run with Android Studio
-Start Android Studio, open the project located in `android-demo-app/SpeechRecognition`, build and run the app on an Android device. After the app runs, tap the Start button and start saying something; after 6 seconds, the model will infer to recognize your speech. Only basic decoding of the recognition result from an array of floating numbers of logits to a list of tokens is provided in this demo app, but it is easy to see, without further post-processing, whether the model can recognize your utterances. Some example recognition results are:
+Start Android Studio, open the project located in `android-demo-app/SpeechRecognition`, build and run the app on an Android device. After the app runs, tap the Start button and start saying something; after 12 seconds (you can change `private final static int AUDIO_LEN_IN_SECOND = 12;` in `MainActivity.java` for a shorter or longer recording length), the model will run inference to recognize your speech. Some example recognition results are:
![](screenshot1.png)
![](screenshot2.png)
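Before dropping `wav2vec2.pt` into `app/src/main/assets`, it can be worth a quick desktop-side sanity check that the TorchScript file loads and transcribes the bundled sample clip. This is a minimal sketch, not part of this diff, assuming `wav2vec2.pt` and `scent_of_a_woman_future.wav` sit in the current directory:

```
import torch
import torchaudio

# Load the scripted, quantized model produced by create_wav2vec2.py
# (the optimize_for_mobile output saved with .save() loads with torch.jit.load).
model = torch.jit.load("wav2vec2.pt")

# The model expects 16 kHz, single-channel audio of shape [1, num_frames].
waveform, sample_rate = torchaudio.load("scent_of_a_woman_future.wav")
print(sample_rate)      # expected: 16000
print(model(waveform))  # the scripted forward() returns the transcript as a str
```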
diff --git a/SpeechRecognition/app/build.gradle b/SpeechRecognition/app/build.gradle
index ce2b173b..61afe76e 100644
--- a/SpeechRecognition/app/build.gradle
+++ b/SpeechRecognition/app/build.gradle
@@ -34,13 +34,12 @@ android {
}
dependencies {
- implementation fileTree(dir: "libs", include: ["*.jar"])
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
- implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
+ implementation 'org.pytorch:pytorch_android:1.9.0'
}
\ No newline at end of file
diff --git a/SpeechRecognition/app/src/main/java/org/pytorch/demo/speechrecognition/MainActivity.java b/SpeechRecognition/app/src/main/java/org/pytorch/demo/speechrecognition/MainActivity.java
index 1c3491c1..4215f206 100644
--- a/SpeechRecognition/app/src/main/java/org/pytorch/demo/speechrecognition/MainActivity.java
+++ b/SpeechRecognition/app/src/main/java/org/pytorch/demo/speechrecognition/MainActivity.java
@@ -35,13 +35,10 @@ public class MainActivity extends AppCompatActivity implements Runnable {
private TextView mTextView;
private Button mButton;
- private final static String[] tokens = {"<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S", "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X", "J", "Q", "Z"};
- private final static int INPUT_SIZE = 65024;
- private final static int AUDIO_LEN_LIMIT = 6;
-
private final static int REQUEST_RECORD_AUDIO = 13;
+ private final static int AUDIO_LEN_IN_SECOND = 12;
private final static int SAMPLE_RATE = 16000;
- private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_LIMIT;
+ private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;
private final static String LOG_TAG = MainActivity.class.getSimpleName();
@@ -55,7 +52,7 @@ public void run() {
MainActivity.this.runOnUiThread(
() -> {
- mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_LIMIT - mStart));
+ mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
mStart += 1;
});
}
@@ -90,7 +87,7 @@ protected void onCreate(Bundle savedInstanceState) {
mButton.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
- mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_LIMIT));
+ mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
mButton.setEnabled(false);
Thread thread = new Thread(MainActivity.this);
@@ -197,46 +194,21 @@ public void run() {
private String recognize(float[] floatInputBuffer) {
if (mModuleEncoder == null) {
final String moduleFileAbsoluteFilePath = new File(
- assetFilePath(this, "wav2vec_traced_quantized.pt")).getAbsolutePath();
+ assetFilePath(this, "wav2vec2.pt")).getAbsolutePath();
mModuleEncoder = Module.load(moduleFileAbsoluteFilePath);
}
- double wav2vecinput[] = new double[INPUT_SIZE];
- for (int n = 0; n < INPUT_SIZE; n++)
+ double wav2vecinput[] = new double[RECORDING_LENGTH];
+ for (int n = 0; n < RECORDING_LENGTH; n++)
wav2vecinput[n] = floatInputBuffer[n];
- FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(INPUT_SIZE);
+ FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
for (double val : wav2vecinput)
inTensorBuffer.put((float)val);
- Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, INPUT_SIZE});
-
- final Map<String, IValue> map = mModuleEncoder.forward(IValue.from(inTensor)).toDictStringKey();
- final Tensor logitsTensor = map.get("logits").toTensor();
- final float[] values = logitsTensor.getDataAsFloatArray();
+ Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
+ final String result = mModuleEncoder.forward(IValue.from(inTensor)).toStr();
- String result = "";
- float row[] = new float[tokens.length];
- for (int i = 0; i < values.length; i++) {
- row[i % tokens.length] = values[i];
- if (i > 0 && i % tokens.length == 0) {
- int tid = argmax(row);
- if (tid > 4) result = String.format("%s%s", result, tokens[tid]);
- else if (tid == 4) result = String.format("%s ", result);
- }
- }
return result;
}
-
- private int argmax(float[] array) {
- int maxIdx = 0;
- double maxVal = -Double.MAX_VALUE;
- for (int j=0; j<array.length; j++) {
- if (array[j] > maxVal) {
- maxVal = array[j];
- maxIdx = j;
- }
- }
- return maxIdx;
- }
}
\ No newline at end of file
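With the decoding folded into the TorchScript model, `recognize()` now just fills a `[1, RECORDING_LENGTH]` float tensor (12 s at 16000 Hz, i.e. 192,000 samples) and reads the transcript back as a plain string. Below is a rough Python equivalent of that call pattern, useful for checking that the exported model accepts the app's input shape; it is only a sketch, and the all-zeros buffer stands in for the microphone recording, so it exercises shapes and types rather than producing a meaningful transcript:

```
import torch

SAMPLE_RATE = 16000
AUDIO_LEN_IN_SECOND = 12                               # mirrors the constant in MainActivity.java
RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND   # 192000 samples

model = torch.jit.load("wav2vec2.pt")
waveform = torch.zeros(1, RECORDING_LENGTH)   # stand-in for the recorded float buffer
result = model(waveform)                      # same call the app makes via forward()/toStr()
print(type(result), repr(result))             # type is str
```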
diff --git a/SpeechRecognition/create_wav2vec2.py b/SpeechRecognition/create_wav2vec2.py
index 8c259f76..c477bbf2 100644
--- a/SpeechRecognition/create_wav2vec2.py
+++ b/SpeechRecognition/create_wav2vec2.py
@@ -1,24 +1,64 @@
-import soundfile as sf
import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+from torch import Tensor
from torch.utils.mobile_optimizer import optimize_for_mobile
+import torchaudio
+from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
+from transformers import Wav2Vec2ForCTC
-tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
-model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-model.eval()
+# The Wav2Vec2 model emits a sequence of probability distributions (logits) over the characters
+# The following class adds steps to decode the transcript (best path)
+class SpeechRecognizer(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.labels = [
+            "<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S",
+            "R", "D", "L", "U", "M", "W", "C", "F", "G", "Y", "P", "B", "V", "K", "'", "X",
+            "J", "Q", "Z"]
+
+    def forward(self, waveforms: Tensor) -> str:
+        """Given a single channel speech data, return transcription.
+
+        Args:
+            waveforms (Tensor): Speech tensor. Shape `[1, num_frames]`.
-audio_input, _ = sf.read("scent_of_a_woman_future.wav")
-input_values = tokenizer(audio_input, return_tensors="pt").input_values
-print(input_values.shape) # input_values is of 65024 long, matched INPUT_SIZE defined in Android code
+        Returns:
+            str: The resulting transcript
+        """
+        logits, _ = self.model(waveforms)  # [batch, num_seq, num_label]
+        best_path = torch.argmax(logits[0], dim=-1)  # [num_seq,]
+        prev = ''
+        hypothesis = ''
+        for i in best_path:
+            char = self.labels[i]
+            if char == prev:
+                continue
+            if char == '<s>':
+                prev = ''
+                continue
+            hypothesis += char
+            prev = char
+        return hypothesis.replace('|', ' ')
+
+
+# Load Wav2Vec2 pretrained model from Hugging Face Hub
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+# Convert the model to torchaudio format, which supports TorchScript.
+model = import_huggingface_model(model)
+# Remove weight normalization which is not supported by quantization.
+model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
+model = model.eval()
+# Attach decoder
+model = SpeechRecognizer(model)
-logits = model(input_values).logits
-predicted_ids = torch.argmax(logits, dim=-1)
-transcription = tokenizer.batch_decode(predicted_ids)[0]
-print(transcription)
+# Apply quantization / scripting / optimization for mobile
+quantized_model = torch.quantization.quantize_dynamic(
+ model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
+scripted_model = torch.jit.script(quantized_model)
+optimized_model = optimize_for_mobile(scripted_model)
-traced_model = torch.jit.trace(model, input_values, strict=False)
-model_dynamic_quantized = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
-traced_quantized_model = torch.jit.trace(model_dynamic_quantized, input_values, strict=False)
+# Sanity check
+waveform, _ = torchaudio.load('scent_of_a_woman_future.wav')
+print('Result:', optimized_model(waveform))
-optimized_traced_quantized_model = optimize_for_mobile(traced_quantized_model)
-optimized_traced_quantized_model.save("app/src/main/assets/wav2vec_traced_quantized.pt")
+optimized_model.save("wav2vec2.pt")
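The exported model assumes 16 kHz, single-channel input, which is what the app records at `SAMPLE_RATE = 16000`. If you want to try the saved `wav2vec2.pt` on your own recordings, a sketch along these lines should work; `my_clip.wav` is a hypothetical file name, and the downmix/resampling steps are only needed when the clip isn't already 16 kHz mono:

```
import torch
import torchaudio

model = torch.jit.load("wav2vec2.pt")

waveform, sample_rate = torchaudio.load("my_clip.wav")   # shape [channels, num_frames]
waveform = waveform.mean(dim=0, keepdim=True)            # downmix to a single channel
if sample_rate != 16000:
    # Resample to the 16 kHz rate that wav2vec2-base-960h was trained on
    waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

print(model(waveform))
```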