diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 29cdfeb2a9..ef2370c878 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -86,6 +86,7 @@ rtc_source_set("rnn_vad_layers") {
   ]
   deps = [
     ":rnn_vad_common",
+    ":vector_math",
     "..:cpu_features",
     "../../../../api:array_view",
     "../../../../api:function_view",
@@ -94,6 +95,9 @@ rtc_source_set("rnn_vad_layers") {
     "../../../../rtc_base/system:arch",
     "//third_party/rnnoise:rnn_vad",
   ]
+  if (current_cpu == "x86" || current_cpu == "x64") {
+    deps += [ ":vector_math_avx2" ]
+  }
   absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
 }
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index c1bded1af3..f828a248c3 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -50,6 +50,7 @@ RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
               kHiddenGruBias,
               kHiddenGruWeights,
               kHiddenGruRecurrentWeights,
+              cpu_features,
               /*layer_name=*/"GRU1"),
       output_(kHiddenLayerOutputSize,
               kOutputLayerOutputSize,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
index c586ed291f..900ce63121 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
@@ -46,12 +46,12 @@ constexpr std::array kFullyConnectedExpectedOutput = {
     0.983443f, 0.999991f, -0.824335f, 0.984742f, 0.990208f, 0.938179f,
     0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f};
 
-class RnnParametrization
+class RnnFcParametrization
     : public ::testing::TestWithParam<AvailableCpuFeatures> {};
 
 // Checks that the output of a fully connected layer is within tolerance given
 // test input data.
-TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) {
+TEST_P(RnnFcParametrization, CheckFullyConnectedLayerOutput) {
   FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
                          kInputDenseBias, kInputDenseWeights,
                          ActivationFunction::kTansigApproximated,
@@ -61,7 +61,7 @@ TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) {
   ExpectNearAbsolute(kFullyConnectedExpectedOutput, fc, 1e-5f);
 }
 
-TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
+TEST_P(RnnFcParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
   const AvailableCpuFeatures cpu_features = GetParam();
   FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
                          kInputDenseBias, kInputDenseWeights,
@@ -87,16 +87,14 @@ std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
   v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
   AvailableCpuFeatures available = GetAvailableCpuFeatures();
   if (available.sse2) {
-    AvailableCpuFeatures features(
-        {/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
-    v.push_back(features);
+    v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
   }
   return v;
 }
 
 INSTANTIATE_TEST_SUITE_P(
     RnnVadTest,
-    RnnParametrization,
+    RnnFcParametrization,
     ::testing::ValuesIn(GetCpuFeaturesToTest()),
     [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
       return info.param.ToString();
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
index f37fc2af51..482016e8d3 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -43,47 +43,79 @@ std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
   return tensor_dst;
 }
 
-void ComputeGruUpdateResetGates(int input_size,
-                                int output_size,
-                                rtc::ArrayView<const float> weights,
-                                rtc::ArrayView<const float> recurrent_weights,
-                                rtc::ArrayView<const float> bias,
-                                rtc::ArrayView<const float> input,
-                                rtc::ArrayView<const float> state,
-                                rtc::ArrayView<float> gate) {
+// Computes the output for the update or the reset gate.
+// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
+// - `g`: output gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `s`: state gate vector
+// - `b`: bias vector
+void ComputeUpdateResetGate(int input_size,
+                            int output_size,
+                            const VectorMath& vector_math,
+                            rtc::ArrayView<const float> input,
+                            rtc::ArrayView<const float> state,
+                            rtc::ArrayView<const float> bias,
+                            rtc::ArrayView<const float> weights,
+                            rtc::ArrayView<const float> recurrent_weights,
+                            rtc::ArrayView<float> gate) {
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_EQ(state.size(), output_size);
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+  RTC_DCHECK_GE(gate.size(), output_size);  // `gate` is over-allocated.
   for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s];
-    }
-    gate[o] = ::rnnoise::SigmoidApproximated(gate[o]);
+    float x = bias[o];
+    x += vector_math.DotProduct(input,
+                                weights.subview(o * input_size, input_size));
+    x += vector_math.DotProduct(
+        state, recurrent_weights.subview(o * output_size, output_size));
+    gate[o] = ::rnnoise::SigmoidApproximated(x);
   }
 }
 
-void ComputeGruOutputGate(int input_size,
-                          int output_size,
-                          rtc::ArrayView<const float> weights,
-                          rtc::ArrayView<const float> recurrent_weights,
-                          rtc::ArrayView<const float> bias,
-                          rtc::ArrayView<const float> input,
-                          rtc::ArrayView<const float> state,
-                          rtc::ArrayView<const float> reset,
-                          rtc::ArrayView<float> gate) {
+// Computes the output for the state gate.
+// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
+// - `s'`: output state gate vector
+// - `s`: previous state gate vector
+// - `u`: update gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `r`: reset gate vector
+// - `b`: bias vector
+// - `.*` element-wise product
+void ComputeStateGate(int input_size,
+                      int output_size,
+                      const VectorMath& vector_math,
+                      rtc::ArrayView<const float> input,
+                      rtc::ArrayView<const float> update,
+                      rtc::ArrayView<const float> reset,
+                      rtc::ArrayView<const float> bias,
+                      rtc::ArrayView<const float> weights,
+                      rtc::ArrayView<const float> recurrent_weights,
+                      rtc::ArrayView<float> state) {
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_GE(update.size(), output_size);  // `update` is over-allocated.
+  RTC_DCHECK_GE(reset.size(), output_size);  // `reset` is over-allocated.
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+  RTC_DCHECK_EQ(state.size(), output_size);
+  std::array<float, kGruLayerMaxUnits> reset_x_state;
   for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
-    }
-    // Rectified linear unit.
-    if (gate[o] < 0.f) {
-      gate[o] = 0.f;
-    }
+    reset_x_state[o] = state[o] * reset[o];
+  }
+  for (int o = 0; o < output_size; ++o) {
+    float x = bias[o];
+    x += vector_math.DotProduct(input,
+                                weights.subview(o * input_size, input_size));
+    x += vector_math.DotProduct(
+        {reset_x_state.data(), static_cast<size_t>(output_size)},
+        recurrent_weights.subview(o * output_size, output_size));
+    state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
   }
 }
 
@@ -95,12 +127,14 @@ GatedRecurrentLayer::GatedRecurrentLayer(
     const rtc::ArrayView<const int8_t> bias,
     const rtc::ArrayView<const int8_t> weights,
     const rtc::ArrayView<const int8_t> recurrent_weights,
+    const AvailableCpuFeatures& cpu_features,
     absl::string_view layer_name)
     : input_size_(input_size),
       output_size_(output_size),
       bias_(PreprocessGruTensor(bias, output_size)),
       weights_(PreprocessGruTensor(weights, output_size)),
-      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)) {
+      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
+      vector_math_(cpu_features) {
   RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
       << "Insufficient GRU layer over-allocation (" << layer_name << ").";
   RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
@@ -126,44 +160,38 @@ void GatedRecurrentLayer::Reset() {
 
 void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
   RTC_DCHECK_EQ(input.size(), input_size_);
-  // TODO(bugs.chromium.org/10480): Add AVX2.
-  // TODO(bugs.chromium.org/10480): Add Neon.
-
-  // Stride and offset used to read parameter arrays.
-  const int stride_in = input_size_ * output_size_;
-  const int stride_out = output_size_ * output_size_;
-
+  // The tensors below are organized as a sequence of flattened tensors for the
+  // `update`, `reset` and `state` gates.
   rtc::ArrayView<const float> bias(bias_);
   rtc::ArrayView<const float> weights(weights_);
   rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+  // Strides to access to the flattened tensors for a specific gate.
+  const int stride_weights = input_size_ * output_size_;
+  const int stride_recurrent_weights = output_size_ * output_size_;
+
+  rtc::ArrayView<float> state(state_.data(), output_size_);
 
   // Update gate.
   std::array<float, kGruLayerMaxUnits> update;
-  ComputeGruUpdateResetGates(
-      input_size_, output_size_, weights.subview(0, stride_in),
-      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size_),
-      input, state_, update);
-
+  ComputeUpdateResetGate(
+      input_size_, output_size_, vector_math_, input, state,
+      bias.subview(0, output_size_), weights.subview(0, stride_weights),
+      recurrent_weights.subview(0, stride_recurrent_weights), update);
   // Reset gate.
   std::array<float, kGruLayerMaxUnits> reset;
-  ComputeGruUpdateResetGates(
-      input_size_, output_size_, weights.subview(stride_in, stride_in),
-      recurrent_weights.subview(stride_out, stride_out),
-      bias.subview(output_size_, output_size_), input, state_, reset);
-
-  // Output gate.
-  std::array<float, kGruLayerMaxUnits> output;
-  ComputeGruOutputGate(input_size_, output_size_,
-                       weights.subview(2 * stride_in, stride_in),
-                       recurrent_weights.subview(2 * stride_out, stride_out),
-                       bias.subview(2 * output_size_, output_size_), input,
-                       state_, reset, output);
-
-  // Update output through the update gates and update the state.
-  for (int o = 0; o < output_size_; ++o) {
-    output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
-    state_[o] = output[o];
-  }
+  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
+                         bias.subview(output_size_, output_size_),
+                         weights.subview(stride_weights, stride_weights),
+                         recurrent_weights.subview(stride_recurrent_weights,
+                                                   stride_recurrent_weights),
+                         reset);
+  // State gate.
+  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
+                   reset, bias.subview(2 * output_size_, output_size_),
+                   weights.subview(2 * stride_weights, stride_weights),
+                   recurrent_weights.subview(2 * stride_recurrent_weights,
+                                             stride_recurrent_weights),
+                   state);
 }
 
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
index f66b048b7d..3407dfcdf1 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
@@ -17,6 +17,7 @@
 #include "absl/strings/string_view.h"
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -34,6 +35,7 @@ class GatedRecurrentLayer {
                       rtc::ArrayView<const int8_t> bias,
                       rtc::ArrayView<const int8_t> weights,
                       rtc::ArrayView<const int8_t> recurrent_weights,
+                      const AvailableCpuFeatures& cpu_features,
                       absl::string_view layer_name);
   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
@@ -57,6 +59,7 @@ class GatedRecurrentLayer {
   const std::vector<float> bias_;
   const std::vector<float> weights_;
   const std::vector<float> recurrent_weights_;
+  const VectorMath vector_math_;
   // Over-allocated array with size equal to `output_size_`.
   std::array<float, kGruLayerMaxUnits> state_;
 };
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
index 4e8b524d6f..ee8bdac994 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -11,6 +11,8 @@
 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
 
 #include <array>
+#include <memory>
+#include <vector>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -18,6 +20,7 @@
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "test/gtest.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -101,24 +104,44 @@ constexpr std::array kGruExpectedOutputSequence = {
     0.00781069f, 0.75267816f, 0.f,         0.02579715f,
     0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
 
+class RnnGruParametrization
+    : public ::testing::TestWithParam<AvailableCpuFeatures> {};
+
 // Checks that the output of a GRU layer is within tolerance given test input
 // data.
-TEST(RnnVadTest, CheckGatedRecurrentLayer) {
+TEST_P(RnnGruParametrization, CheckGatedRecurrentLayer) {
   GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+                          kGruRecurrentWeights,
+                          /*cpu_features=*/GetParam(),
+                          /*layer_name=*/"GRU");
   TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
 }
 
-TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
-  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+TEST_P(RnnGruParametrization, DISABLED_BenchmarkGatedRecurrentLayer) {
+  // Prefetch test data.
+  std::unique_ptr<FileReader> reader = CreateGruInputReader();
+  std::vector<float> gru_input_sequence(reader->size());
+  reader->ReadChunk(gru_input_sequence);
+
+  using ::rnnoise::kHiddenGruBias;
+  using ::rnnoise::kHiddenGruRecurrentWeights;
+  using ::rnnoise::kHiddenGruWeights;
+  using ::rnnoise::kHiddenLayerOutputSize;
+  using ::rnnoise::kInputLayerOutputSize;
 
-  rtc::ArrayView<const float> input_sequence(kGruInputSequence);
-  static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
-  constexpr int input_sequence_length =
-      kGruInputSequence.size() / kGruInputSize;
+  GatedRecurrentLayer gru(kInputLayerOutputSize, kHiddenLayerOutputSize,
+                          kHiddenGruBias, kHiddenGruWeights,
+                          kHiddenGruRecurrentWeights,
+                          /*cpu_features=*/GetParam(),
+                          /*layer_name=*/"GRU");
 
-  constexpr int kNumTests = 10000;
+  rtc::ArrayView<const float> input_sequence(gru_input_sequence);
+  ASSERT_EQ(input_sequence.size() % kInputLayerOutputSize,
+            static_cast<size_t>(0));
+  const int input_sequence_length =
+      input_sequence.size() / kInputLayerOutputSize;
+
+  constexpr int kNumTests = 100;
   ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
   for (int k = 0; k < kNumTests; ++k) {
     perf_timer.StartTimer();
@@ -133,6 +156,28 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
                     << " ms";
 }
 
+// Finds the relevant CPU features combinations to test.
+std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
+  std::vector<AvailableCpuFeatures> v;
+  AvailableCpuFeatures available = GetAvailableCpuFeatures();
+  v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
+  if (available.avx2) {
+    v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
+  }
+  if (available.sse2) {
+    v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
+  }
+  return v;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RnnVadTest,
+    RnnGruParametrization,
+    ::testing::ValuesIn(GetCpuFeaturesToTest()),
+    [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
+      return info.param.ToString();
+    });
+
 }  // namespace
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 3db6774450..b8ca9c3669 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -111,6 +111,12 @@ ChunksFileReader CreateLpResidualAndPitchInfoReader() {
   return {kChunkSize, num_chunks, std::move(reader)};
 }
 
+std::unique_ptr<FileReader> CreateGruInputReader() {
+  return std::make_unique<FloatFileReader<float>>(
+      /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
+                                      "dat"));
+}
+
 std::unique_ptr<FileReader> CreateVadProbsReader() {
   return std::make_unique<FloatFileReader<float>>(
       /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index 86af5e0076..e366e1837e 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -77,6 +77,9 @@ ChunksFileReader CreatePitchBuffer24kHzReader();
 // Creates a reader for the LP residual and pitch information test data.
 ChunksFileReader CreateLpResidualAndPitchInfoReader();
 
+// Creates a reader for the sequence of GRU input vectors.
+std::unique_ptr<FileReader> CreateGruInputReader();
+
 // Creates a reader for the VAD probabilities test data.
 std::unique_ptr<FileReader> CreateVadProbsReader();
diff --git a/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
new file mode 100644
index 0000000000..f78c40e6c4
--- /dev/null
+++ b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
@@ -0,0 +1 @@
+402abf7a4e5d35abb78906fff2b3f4d8d24aa629
\ No newline at end of file