From c0c7d36e80b4b51d4406bc6d5f5ab7d672b6050e Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Tue, 23 Apr 2019 16:34:22 +0200 Subject: [PATCH] RNN VAD: clean-up unit tests - add test that checks that the computed VAD probability is within tolerance *1 - speed-up some tests by reducing the input length and skipping frames - remove unused code in test_utils - fix some comments *1: RnnVadTest::RnnBitExactness is replaced by RnnVadTest::RnnVadProbabilityWithinTolerance Bug: webrtc:10480 Change-Id: I19332d06eacffbbe671bf7749ff4c92798bdc55c Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133910 Reviewed-by: Alex Loiko Commit-Queue: Alessio Bazzica Cr-Commit-Position: refs/heads/master@{#27803} --- .../audio_processing/agc2/rnn_vad/BUILD.gn | 2 - .../agc2/rnn_vad/auto_correlation_unittest.cc | 12 +- .../agc2/rnn_vad/lp_residual_unittest.cc | 61 ++++---- .../rnn_vad/pitch_search_internal_unittest.cc | 140 ++++++++++-------- .../agc2/rnn_vad/pitch_search_unittest.cc | 16 +- .../agc2/rnn_vad/rnn_unittest.cc | 42 +----- .../agc2/rnn_vad/rnn_vad_unittest.cc | 61 ++++++++ .../spectral_features_internal_unittest.cc | 12 +- .../rnn_vad/spectral_features_unittest.cc | 9 +- .../agc2/rnn_vad/test_utils.cc | 39 +---- .../agc2/rnn_vad/test_utils.h | 43 ++++-- .../agc2/rnn_vad/fft.dat.sha1 | 1 - .../agc2/rnn_vad/sil_features.dat.sha1 | 1 - .../agc2/rnn_vad/vad_prob.dat.sha1 | 2 +- 14 files changed, 246 insertions(+), 195 deletions(-) delete mode 100644 resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 delete mode 100644 resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 902082ef7f..7cf6e3d3b6 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -64,12 +64,10 @@ if (rtc_include_tests) { unittest_resources = [ "../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat", - "../../../../resources/audio_processing/agc2/rnn_vad/fft.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat", "../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm", - "../../../../resources/audio_processing/agc2/rnn_vad/sil_features.dat", "../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat", ] diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc index a5e456a4de..f66c0b299b 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc @@ -10,6 +10,7 @@ #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "test/gtest.h" @@ -18,6 +19,8 @@ namespace webrtc { namespace rnn_vad { namespace test { +// Checks that the auto correlation function produces output within tolerance +// given test input data. TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { PitchTestData test_data; std::array pitch_buf_decimated; @@ -35,7 +38,7 @@ TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { computed_output, 3e-3f); } -// Check that the auto correlation function computes the right thing for a +// Checks that the auto correlation function computes the right thing for a // simple use case. TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { // Create constant signal with no pitch. @@ -49,11 +52,12 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, computed_output); } - // The expected output is constantly the length of the fixed 'x' - // array in ComputePitchAutoCorrelation. + // The expected output is a vector filled with the same expected + // auto-correlation value. The latter equals the length of a 20 ms frame. + constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2; std::array expected_output; std::fill(expected_output.begin(), expected_output.end(), - kBufSize12kHz - kMaxPitch12kHz); + static_cast(kFrameSize20ms12kHz)); ExpectNearAbsolute(expected_output, computed_output, 4e-5f); } diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc index 47d8bf531f..1e80ee0631 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc @@ -10,7 +10,9 @@ #include "modules/audio_processing/agc2/rnn_vad/lp_residual.h" +#include #include +#include #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" @@ -22,6 +24,7 @@ namespace webrtc { namespace rnn_vad { namespace test { +// Checks that the LP residual can be computed on an empty frame. TEST(RnnVadTest, LpResidualOfEmptyFrame) { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; @@ -37,44 +40,44 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) { ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual); } -// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. +// Checks that the computed LP residual is bit-exact given test input data. TEST(RnnVadTest, LpResidualPipelineBitExactness) { - // Pitch buffer 24 kHz data reader. + // Input and expected output readers. auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader(); - const size_t num_frames = pitch_buf_24kHz_reader.second; - std::array pitch_buf_data; - // Read ground-truth. auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); - ASSERT_EQ(num_frames, lp_residual_reader.second); - std::array expected_lp_residual; - rtc::ArrayView expected_lp_residual_view( - expected_lp_residual.data(), expected_lp_residual.size()); - // Init pipeline. + + // Buffers. + std::vector pitch_buf_data(kBufSize24kHz); std::array lpc_coeffs; - rtc::ArrayView lpc_coeffs_view( - lpc_coeffs.data(), kNumLpcCoefficients); - std::array computed_lp_residual; - rtc::ArrayView computed_lp_residual_view( - computed_lp_residual.data(), computed_lp_residual.size()); + std::vector computed_lp_residual(kBufSize24kHz); + std::vector expected_lp_residual(kBufSize24kHz); + + // Test length. + const size_t num_frames = std::min(pitch_buf_24kHz_reader.second, + static_cast(300)); // Max 3 s. + ASSERT_GE(lp_residual_reader.second, num_frames); + { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; for (size_t i = 0; i < num_frames; ++i) { - SCOPED_TRACE(i); - // Read input and expected output. - pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data); - lp_residual_reader.first->ReadChunk(expected_lp_residual_view); - // Skip pitch gain and period. + // Read input. + ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data)); + // Read expected output (ignore pitch gain and period). + ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual)); float unused; - lp_residual_reader.first->ReadValue(&unused); - lp_residual_reader.first->ReadValue(&unused); - // Run pipeline. - ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs_view); - ComputeLpResidual(lpc_coeffs_view, pitch_buf_data, - computed_lp_residual_view); - // Compare. - ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view, - kFloatMin); + ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused)); + ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused)); + + // Check every 200 ms. + if (i % 20 != 0) { + continue; + } + + SCOPED_TRACE(i); + ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs); + ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual); + ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin); } } } diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 7e29417baf..23ff49a2fc 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -23,39 +23,45 @@ namespace rnn_vad { namespace test { namespace { -constexpr std::array kTestPitchPeriods = { - 3 * kMinPitch48kHz / 2, - (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2, -}; -constexpr std::array kTestPitchGains = {0.35f, 0.75f}; +constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2; +constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2; + +constexpr float kTestPitchGainsLow = 0.35f; +constexpr float kTestPitchGainsHigh = 0.75f; } // namespace class ComputePitchGainThresholdTest : public ::testing::Test, - public ::testing::WithParamInterface< - std::tuple> {}; + public ::testing::WithParamInterface> {}; -TEST_P(ComputePitchGainThresholdTest, BitExactness) { +// Checks that the computed pitch gain is within tolerance given test input +// data. +TEST_P(ComputePitchGainThresholdTest, WithinTolerance) { const auto params = GetParam(); const size_t candidate_pitch_period = std::get<0>(params); const size_t pitch_period_ratio = std::get<1>(params); const size_t initial_pitch_period = std::get<2>(params); const float initial_pitch_gain = std::get<3>(params); const size_t prev_pitch_period = std::get<4>(params); - const size_t prev_pitch_gain = std::get<5>(params); + const float prev_pitch_gain = std::get<5>(params); const float threshold = std::get<6>(params); - { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - EXPECT_NEAR( threshold, ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio, initial_pitch_period, initial_pitch_gain, prev_pitch_period, prev_pitch_gain), - 3e-6f); + 5e-7f); } } @@ -77,7 +83,9 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f), std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); -TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { +// Checks that the frame-wise sliding square energy function produces output +// within tolerance given test input data. +TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) { PitchTestData test_data; std::array computed_output; { @@ -91,6 +99,7 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { computed_output, 3e-2f); } +// Checks that the estimated pitch period is bit-exact given test input data. TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; @@ -104,14 +113,13 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()}, pitch_buf_decimated, kMaxPitch12kHz); } - const std::array expected_output = {140, 142}; - EXPECT_EQ(expected_output, pitch_candidates_inv_lags); + EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast(140)); + EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast(142)); } +// Checks that the refined pitch period is bit-exact given test input data. TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { PitchTestData test_data; - std::array pitch_buf_decimated; - Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); size_t pitch_inv_lag; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -125,10 +133,17 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { class CheckLowerPitchPeriodsAndComputePitchGainTest : public ::testing::Test, - public ::testing::WithParamInterface< - std::tuple> {}; + public ::testing::WithParamInterface> {}; -TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { +// Checks that the computed pitch period is bit-exact and that the computed +// pitch gain is within tolerance given test input data. +TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, + PeriodBitExactnessGainWithinTolerance) { const auto params = GetParam(); const int initial_pitch_period = std::get<0>(params); const int prev_pitch_period = std::get<1>(params); @@ -147,48 +162,49 @@ TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { } } -INSTANTIATE_TEST_SUITE_P(RnnVadTest, - CheckLowerPitchPeriodsAndComputePitchGainTest, - ::testing::Values(std::make_tuple(kTestPitchPeriods[0], - kTestPitchPeriods[0], - kTestPitchGains[0], - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriods[0], - kTestPitchPeriods[0], - kTestPitchGains[1], - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriods[0], - kTestPitchPeriods[1], - kTestPitchGains[0], - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriods[0], - kTestPitchPeriods[1], - kTestPitchGains[1], - 91, - -0.0188608f), - std::make_tuple(kTestPitchPeriods[1], - kTestPitchPeriods[0], - kTestPitchGains[0], - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriods[1], - kTestPitchPeriods[0], - kTestPitchGains[1], - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriods[1], - kTestPitchPeriods[1], - kTestPitchGains[0], - 475, - -0.0904344f), - std::make_tuple(kTestPitchPeriods[1], - kTestPitchPeriods[1], - kTestPitchGains[1], - 475, - -0.0904344f))); +INSTANTIATE_TEST_SUITE_P( + RnnVadTest, + CheckLowerPitchPeriodsAndComputePitchGainTest, + ::testing::Values(std::make_tuple(kTestPitchPeriodsLow, + kTestPitchPeriodsLow, + kTestPitchGainsLow, + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriodsLow, + kTestPitchPeriodsLow, + kTestPitchGainsHigh, + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriodsLow, + kTestPitchPeriodsHigh, + kTestPitchGainsLow, + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriodsLow, + kTestPitchPeriodsHigh, + kTestPitchGainsHigh, + 91, + -0.0188608f), + std::make_tuple(kTestPitchPeriodsHigh, + kTestPitchPeriodsLow, + kTestPitchGainsLow, + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriodsHigh, + kTestPitchPeriodsLow, + kTestPitchGainsHigh, + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriodsHigh, + kTestPitchPeriodsHigh, + kTestPitchGainsLow, + 475, + -0.0904344f), + std::make_tuple(kTestPitchPeriodsHigh, + kTestPitchPeriodsHigh, + kTestPitchGainsHigh, + 475, + -0.0904344f))); } // namespace test } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index eac332edbf..494dfe7a98 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -12,7 +12,8 @@ #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" -#include +#include +#include #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -23,11 +24,13 @@ namespace webrtc { namespace rnn_vad { namespace test { -// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. -TEST(RnnVadTest, PitchSearchBitExactness) { +// Checks that the computed pitch period is bit-exact and that the computed +// pitch gain is within tolerance given test input data. +TEST(RnnVadTest, PitchSearchWithinTolerance) { auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); - const size_t num_frames = lp_residual_reader.second; - std::array lp_residual; + const size_t num_frames = std::min(lp_residual_reader.second, + static_cast(300)); // Max 3 s. + std::vector lp_residual(kBufSize24kHz); float expected_pitch_period, expected_pitch_gain; PitchEstimator pitch_estimator; { @@ -38,7 +41,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) { lp_residual_reader.first->ReadChunk(lp_residual); lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_gain); - PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual); + PitchInfo pitch_info = + pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); EXPECT_EQ(static_cast(expected_pitch_period), pitch_info.period); EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); } diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index 289ce8d759..933b555402 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -8,9 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include #include -#include #include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" @@ -63,7 +61,8 @@ void TestGatedRecurrentLayer( } // namespace -// Bit-exactness check for fully connected layers. +// Checks that the output of a fully connected layer is within tolerance given +// test input data. TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { const std::array bias = {-50}; const std::array weights = { @@ -106,6 +105,8 @@ TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { } } +// Checks that the output of a GRU layer is within tolerance given test input +// data. TEST(RnnVadTest, CheckGatedRecurrentLayer) { const std::array bias = {96, -99, -81, -114, 49, 119, -118, 68, -76, 91, 121, 125}; @@ -139,41 +140,6 @@ TEST(RnnVadTest, CheckGatedRecurrentLayer) { } } -// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. -// Bit-exactness test checking that precomputed frame-wise features lead to the -// expected VAD probabilities. -TEST(RnnVadTest, RnnBitExactness) { - // Init. - auto features_reader = CreateSilenceFlagsFeatureMatrixReader(); - auto vad_probs_reader = CreateVadProbsReader(); - ASSERT_EQ(features_reader.second, vad_probs_reader.second); - const size_t num_frames = features_reader.second; - // Frame-wise buffers. - float expected_vad_probability; - float is_silence; - std::array features; - - // Compute VAD probability using the precomputed features. - RnnBasedVad vad; - for (size_t i = 0; i < num_frames; ++i) { - SCOPED_TRACE(i); - // Read frame data. - RTC_CHECK(vad_probs_reader.first->ReadValue(&expected_vad_probability)); - // The features file also includes a silence flag for each frame. - RTC_CHECK(features_reader.first->ReadValue(&is_silence)); - RTC_CHECK(features_reader.first->ReadChunk(features)); - // Compute and check VAD probability. - float vad_probability = vad.ComputeVadProbability(features, is_silence); - ASSERT_TRUE(is_silence == 0.f || is_silence == 1.f); - if (is_silence == 1.f) { - ASSERT_EQ(0.f, expected_vad_probability); - EXPECT_EQ(0.f, vad_probability); - } else { - EXPECT_NEAR(expected_vad_probability, vad_probability, 3e-6f); - } - } -} - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 4afe24b9f1..8583d4bc1b 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -9,6 +9,7 @@ */ #include +#include #include #include "common_audio/resampler/push_sinc_resampler.h" @@ -43,8 +44,68 @@ void DumpPerfStats(size_t num_samples, RTC_LOG(LS_INFO) << "speed: " << speed << "x"; } +// When the RNN VAD model is updated and the expected output changes, set the +// constant below to true in order to write new expected output binary files. +constexpr bool kWriteComputedOutputToFile = false; + } // namespace +// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false +// when the expected output files are re-exported. +TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) { + ASSERT_FALSE(kWriteComputedOutputToFile) + << "Cannot land if kWriteComputedOutput is true."; +} + +// Checks that the computed VAD probability for a test input sequence sampled at +// 48 kHz is within tolerance. +TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) { + // Init resampler, feature extractor and RNN. + PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz); + FeaturesExtractor features_extractor; + RnnBasedVad rnn_vad; + + // Init input samples and expected output readers. + auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); + auto expected_vad_prob_reader = CreateVadProbsReader(); + + // Input length. + const size_t num_frames = samples_reader.second; + ASSERT_GE(expected_vad_prob_reader.second, num_frames); + + // Init buffers. + std::vector samples_48k(kFrameSize10ms48kHz); + std::vector samples_24k(kFrameSize10ms24kHz); + std::vector feature_vector(kFeatureVectorSize); + std::vector computed_vad_prob(num_frames); + std::vector expected_vad_prob(num_frames); + + // Read expected output. + ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob)); + + // Compute VAD probabilities on the downsampled input. + float cumulative_error = 0.f; + for (size_t i = 0; i < num_frames; ++i) { + samples_reader.first->ReadChunk(samples_48k); + decimator.Resample(samples_48k.data(), samples_48k.size(), + samples_24k.data(), samples_24k.size()); + bool is_silence = features_extractor.CheckSilenceComputeFeatures( + {samples_24k.data(), kFrameSize10ms24kHz}, + {feature_vector.data(), kFeatureVectorSize}); + computed_vad_prob[i] = rnn_vad.ComputeVadProbability( + {feature_vector.data(), kFeatureVectorSize}, is_silence); + EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f); + cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]); + } + // Check average error. + EXPECT_LT(cumulative_error / num_frames, 1e-4f); + + if (kWriteComputedOutputToFile) { + BinaryFileWriter vad_prob_writer("new_vad_prob.dat"); + vad_prob_writer.WriteChunk(computed_vad_prob); + } +} + // Performance test for the RNN VAD (pre-fetching and downsampling are // excluded). Keep disabled and only enable locally to measure performance as // follows: diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc index d112eb713f..ec81295094 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc @@ -85,14 +85,18 @@ TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) { } } +// Checks that the computed band-wise auto-correlation is non-negative for a +// simple input vector of FFT coefficients. TEST(RnnVadTest, SpectralCorrelatorValidOutput) { - SpectralCorrelator e; + // Input: vector of (1, 1j) values. Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal); auto in = fft.CreateBuffer(); std::array out; auto in_view = in->GetView(); std::fill(in_view.begin(), in_view.end(), 1.f); in_view[1] = 0.f; // Nyquist frequency. + // Compute and check output. + SpectralCorrelator e; e.ComputeAutoCorrelation(in_view, out); for (size_t i = 0; i < kOpusBands24kHz; ++i) { SCOPED_TRACE(i); @@ -100,6 +104,8 @@ TEST(RnnVadTest, SpectralCorrelatorValidOutput) { } } +// Checks that the computed smoothed log magnitude spectrum is within tolerance +// given hard-coded test input data. TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) { constexpr std::array input = { {86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f, @@ -124,7 +130,9 @@ TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) { } } -TEST(RnnVadTest, ComputeDctBitExactness) { +// Checks that the computed DCT is within tolerance given hard-coded test input +// data. +TEST(RnnVadTest, ComputeDctWithinTolerance) { constexpr std::array input = { {0.232155621052f, 0.678957760334f, 0.220818966627f, -0.077363930643f, -0.559227049351f, 0.432545185089f, 0.353900641203f, 0.398993015289f, diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc index 39b9f93eb1..bc00e2c500 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc @@ -67,6 +67,8 @@ constexpr float kInitialFeatureVal = -9999.f; } // namespace +// Checks that silence is detected when the input signal is 0 and that the +// feature vector is written only if the input signal is not tagged as silence. TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) { // Initialize. SpectralFeaturesExtractor sfe; @@ -108,9 +110,10 @@ TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) { [](float x) { return x == kInitialFeatureVal; })); } -// When the input signal does not change, the cepstral coefficients average does -// not change and the derivatives are zero. Similarly, the cepstral variability -// score does not change either. +// Feeds a constant input signal and checks that: +// - the cepstral coefficients average does not change; +// - the derivatives are zero; +// - the cepstral variability score does not change. TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) { // Initialize. SpectralFeaturesExtractor sfe; diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 14b84a461c..8236d5f750 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -46,13 +46,6 @@ void ExpectNearAbsolute(rtc::ArrayView expected, } } -std::unique_ptr> CreatePitchSearchTestDataReader() { - constexpr size_t cols = 1396; - return absl::make_unique>( - ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"), - cols); -} - std::pair>, const size_t> CreatePcmSamplesReader(const size_t frame_length) { auto ptr = absl::make_unique>( @@ -78,25 +71,6 @@ ReaderPairType CreateLpResidualAndPitchPeriodGainReader() { rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)}; } -ReaderPairType CreateFftCoeffsReader() { - constexpr size_t num_fft_points = 481; - constexpr size_t row_size = 2 * num_fft_points; // Real and imaginary values. - auto ptr = absl::make_unique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/fft", "dat"), - num_fft_points); - return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), row_size)}; -} - -ReaderPairType CreateSilenceFlagsFeatureMatrixReader() { - constexpr size_t feature_vector_size = 42; - auto ptr = absl::make_unique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/sil_features", "dat"), - feature_vector_size); - // Features and silence flag. - return {std::move(ptr), - rtc::CheckedDivExact(ptr->data_length(), feature_vector_size + 1)}; -} - ReaderPairType CreateVadProbsReader() { auto ptr = absl::make_unique>( test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat")); @@ -104,23 +78,26 @@ ReaderPairType CreateVadProbsReader() { } PitchTestData::PitchTestData() { - auto test_data_reader = CreatePitchSearchTestDataReader(); - test_data_reader->ReadChunk(test_data_); + BinaryFileReader test_data_reader( + ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"), + static_cast(1396)); + test_data_reader.ReadChunk(test_data_); } PitchTestData::~PitchTestData() = default; -rtc::ArrayView PitchTestData::GetPitchBufView() { +rtc::ArrayView PitchTestData::GetPitchBufView() + const { return {test_data_.data(), kBufSize24kHz}; } rtc::ArrayView -PitchTestData::GetPitchBufSquareEnergiesView() { +PitchTestData::GetPitchBufSquareEnergiesView() const { return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; } rtc::ArrayView -PitchTestData::GetPitchBufAutoCorrCoeffsView() { +PitchTestData::GetPitchBufAutoCorrCoeffsView() const { return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, kNumPitchBufAutoCorrCoeffs}; } diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index c11af7f8a3..fbb270faf8 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -46,11 +47,10 @@ void ExpectNearAbsolute(rtc::ArrayView expected, template class BinaryFileReader { public: - explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 1) + explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0) : is_(file_path, std::ios::binary | std::ios::ate), data_length_(is_.tellg() / sizeof(T)), chunk_size_(chunk_size) { - RTC_CHECK_LT(0, chunk_size_); RTC_CHECK(is_); SeekBeginning(); buf_.resize(chunk_size_); @@ -69,9 +69,11 @@ class BinaryFileReader { } return is_.gcount() == sizeof(T); } + // If |chunk_size| was specified in the ctor, it will check that the size of + // |dst| equals |chunk_size|. bool ReadChunk(rtc::ArrayView dst) { - RTC_DCHECK_EQ(chunk_size_, dst.size()); - const std::streamsize bytes_to_read = chunk_size_ * sizeof(T); + RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size())); + const std::streamsize bytes_to_read = dst.size() * sizeof(T); if (std::is_same::value) { is_.read(reinterpret_cast(dst.data()), bytes_to_read); } else { @@ -91,9 +93,26 @@ class BinaryFileReader { std::vector buf_; }; +// Writer for binary files. +template +class BinaryFileWriter { + public: + explicit BinaryFileWriter(const std::string& file_path) + : os_(file_path, std::ios::binary) {} + BinaryFileWriter(const BinaryFileWriter&) = delete; + BinaryFileWriter& operator=(const BinaryFileWriter&) = delete; + ~BinaryFileWriter() = default; + static_assert(std::is_arithmetic::value, ""); + void WriteChunk(rtc::ArrayView value) { + const std::streamsize bytes_to_write = value.size() * sizeof(T); + os_.write(reinterpret_cast(value.data()), bytes_to_write); + } + + private: + std::ofstream os_; +}; + // Factories for resource file readers. -// Creates a reader for the pitch search test data. -std::unique_ptr> CreatePitchSearchTestDataReader(); // The functions below return a pair where the first item is a reader unique // pointer and the second the number of chunks that can be read from the file. // Creates a reader for the PCM samples that casts from S16 to float and reads @@ -107,12 +126,6 @@ CreatePitchBuffer24kHzReader(); // and gain values. std::pair>, const size_t> CreateLpResidualAndPitchPeriodGainReader(); -// Creates a reader for the FFT coefficients. -std::pair>, const size_t> -CreateFftCoeffsReader(); -// Creates a reader for the silence flags and the feature matrix. -std::pair>, const size_t> -CreateSilenceFlagsFeatureMatrixReader(); // Creates a reader for the VAD probabilities. std::pair>, const size_t> CreateVadProbsReader(); @@ -128,11 +141,11 @@ class PitchTestData { public: PitchTestData(); ~PitchTestData(); - rtc::ArrayView GetPitchBufView(); + rtc::ArrayView GetPitchBufView() const; rtc::ArrayView - GetPitchBufSquareEnergiesView(); + GetPitchBufSquareEnergiesView() const; rtc::ArrayView - GetPitchBufAutoCorrCoeffsView(); + GetPitchBufAutoCorrCoeffsView() const; private: std::array test_data_; diff --git a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 deleted file mode 100644 index ebd5124d59..0000000000 --- a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 +++ /dev/null @@ -1 +0,0 @@ -e62364d35abd123663bfc800fa233071d6d7fffd \ No newline at end of file diff --git a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 deleted file mode 100644 index bc591e9d6c..0000000000 --- a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 +++ /dev/null @@ -1 +0,0 @@ -e0a92782c2903be9da10385d924d34e8bf212d5e \ No newline at end of file diff --git a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 index 1aa3bd0d83..8ee78b101a 100644 --- a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 +++ b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 @@ -1 +1 @@ -05735ede0b457318e307d12f5acfd11bbbbd0afd \ No newline at end of file +68640327266262c3fe047ec7f07a46a355ff90b9 \ No newline at end of file