RNN VAD: clean-up unit tests

- add test that checks that the computed VAD probability is within
  tolerance *1
- speed-up some tests by reducing the input length and skipping frames
- remove unused code in test_utils
- fix some comments

*1: RnnVadTest::RnnBitExactness is replaced by
    RnnVadTest::RnnVadProbabilityWithinTolerance

Bug: webrtc:10480
Change-Id: I19332d06eacffbbe671bf7749ff4c92798bdc55c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133910
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27803}
This commit is contained in:
Alessio Bazzica 2019-04-23 16:34:22 +02:00 committed by Commit Bot
parent fc02a793c2
commit c0c7d36e80
14 changed files with 246 additions and 195 deletions

View file

@ -64,12 +64,10 @@ if (rtc_include_tests) {
unittest_resources = [ unittest_resources = [
"../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat", "../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/fft.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat", "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm", "../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm",
"../../../../resources/audio_processing/agc2/rnn_vad/sil_features.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat", "../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat",
] ]

View file

@ -10,6 +10,7 @@
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "test/gtest.h" #include "test/gtest.h"
@ -18,6 +19,8 @@ namespace webrtc {
namespace rnn_vad { namespace rnn_vad {
namespace test { namespace test {
// Checks that the auto correlation function produces output within tolerance
// given test input data.
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
PitchTestData test_data; PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated; std::array<float, kBufSize12kHz> pitch_buf_decimated;
@ -35,7 +38,7 @@ TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
computed_output, 3e-3f); computed_output, 3e-3f);
} }
// Check that the auto correlation function computes the right thing for a // Checks that the auto correlation function computes the right thing for a
// simple use case. // simple use case.
TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
// Create constant signal with no pitch. // Create constant signal with no pitch.
@ -49,11 +52,12 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
computed_output); computed_output);
} }
// The expected output is constantly the length of the fixed 'x' // The expected output is a vector filled with the same expected
// array in ComputePitchAutoCorrelation. // auto-correlation value. The latter equals the length of a 20 ms frame.
constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output; std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
std::fill(expected_output.begin(), expected_output.end(), std::fill(expected_output.begin(), expected_output.end(),
kBufSize12kHz - kMaxPitch12kHz); static_cast<float>(kFrameSize20ms12kHz));
ExpectNearAbsolute(expected_output, computed_output, 4e-5f); ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
} }

View file

@ -10,7 +10,9 @@
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h" #include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
#include <algorithm>
#include <array> #include <array>
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@ -22,6 +24,7 @@ namespace webrtc {
namespace rnn_vad { namespace rnn_vad {
namespace test { namespace test {
// Checks that the LP residual can be computed on an empty frame.
TEST(RnnVadTest, LpResidualOfEmptyFrame) { TEST(RnnVadTest, LpResidualOfEmptyFrame) {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer; // FloatingPointExceptionObserver fpe_observer;
@ -37,44 +40,44 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) {
ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual); ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
} }
// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. // Checks that the computed LP residual is bit-exact given test input data.
TEST(RnnVadTest, LpResidualPipelineBitExactness) { TEST(RnnVadTest, LpResidualPipelineBitExactness) {
// Pitch buffer 24 kHz data reader. // Input and expected output readers.
auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader(); auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
const size_t num_frames = pitch_buf_24kHz_reader.second;
std::array<float, kBufSize24kHz> pitch_buf_data;
// Read ground-truth.
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
ASSERT_EQ(num_frames, lp_residual_reader.second);
std::array<float, kBufSize24kHz> expected_lp_residual; // Buffers.
rtc::ArrayView<float, kBufSize24kHz> expected_lp_residual_view( std::vector<float> pitch_buf_data(kBufSize24kHz);
expected_lp_residual.data(), expected_lp_residual.size());
// Init pipeline.
std::array<float, kNumLpcCoefficients> lpc_coeffs; std::array<float, kNumLpcCoefficients> lpc_coeffs;
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs_view( std::vector<float> computed_lp_residual(kBufSize24kHz);
lpc_coeffs.data(), kNumLpcCoefficients); std::vector<float> expected_lp_residual(kBufSize24kHz);
std::array<float, kBufSize24kHz> computed_lp_residual;
rtc::ArrayView<float, kBufSize24kHz> computed_lp_residual_view( // Test length.
computed_lp_residual.data(), computed_lp_residual.size()); const size_t num_frames = std::min(pitch_buf_24kHz_reader.second,
static_cast<size_t>(300)); // Max 3 s.
ASSERT_GE(lp_residual_reader.second, num_frames);
{ {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer; // FloatingPointExceptionObserver fpe_observer;
for (size_t i = 0; i < num_frames; ++i) { for (size_t i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i); // Read input.
// Read input and expected output. ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data); // Read expected output (ignore pitch gain and period).
lp_residual_reader.first->ReadChunk(expected_lp_residual_view); ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
// Skip pitch gain and period.
float unused; float unused;
lp_residual_reader.first->ReadValue(&unused); ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
lp_residual_reader.first->ReadValue(&unused); ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
// Run pipeline.
ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs_view); // Check every 200 ms.
ComputeLpResidual(lpc_coeffs_view, pitch_buf_data, if (i % 20 != 0) {
computed_lp_residual_view); continue;
// Compare. }
ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view,
kFloatMin); SCOPED_TRACE(i);
ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
} }
} }
} }

View file

@ -23,39 +23,45 @@ namespace rnn_vad {
namespace test { namespace test {
namespace { namespace {
constexpr std::array<int, 2> kTestPitchPeriods = { constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
3 * kMinPitch48kHz / 2, constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
(3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
}; constexpr float kTestPitchGainsLow = 0.35f;
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f}; constexpr float kTestPitchGainsHigh = 0.75f;
} // namespace } // namespace
class ComputePitchGainThresholdTest class ComputePitchGainThresholdTest
: public ::testing::Test, : public ::testing::Test,
public ::testing::WithParamInterface< public ::testing::WithParamInterface<std::tuple<
std::tuple<size_t, size_t, size_t, float, size_t, float, float>> {}; /*candidate_pitch_period=*/size_t,
/*pitch_period_ratio=*/size_t,
/*initial_pitch_period=*/size_t,
/*initial_pitch_gain=*/float,
/*prev_pitch_period=*/size_t,
/*prev_pitch_gain=*/float,
/*threshold=*/float>> {};
TEST_P(ComputePitchGainThresholdTest, BitExactness) { // Checks that the computed pitch gain is within tolerance given test input
// data.
TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
const auto params = GetParam(); const auto params = GetParam();
const size_t candidate_pitch_period = std::get<0>(params); const size_t candidate_pitch_period = std::get<0>(params);
const size_t pitch_period_ratio = std::get<1>(params); const size_t pitch_period_ratio = std::get<1>(params);
const size_t initial_pitch_period = std::get<2>(params); const size_t initial_pitch_period = std::get<2>(params);
const float initial_pitch_gain = std::get<3>(params); const float initial_pitch_gain = std::get<3>(params);
const size_t prev_pitch_period = std::get<4>(params); const size_t prev_pitch_period = std::get<4>(params);
const size_t prev_pitch_gain = std::get<5>(params); const float prev_pitch_gain = std::get<5>(params);
const float threshold = std::get<6>(params); const float threshold = std::get<6>(params);
{ {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer; // FloatingPointExceptionObserver fpe_observer;
EXPECT_NEAR( EXPECT_NEAR(
threshold, threshold,
ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio, ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
initial_pitch_period, initial_pitch_gain, initial_pitch_period, initial_pitch_gain,
prev_pitch_period, prev_pitch_gain), prev_pitch_period, prev_pitch_gain),
3e-6f); 5e-7f);
} }
} }
@ -77,7 +83,9 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f), std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { // Checks that the frame-wise sliding square energy function produces output
// within tolerance given test input data.
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
PitchTestData test_data; PitchTestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output; std::array<float, kNumPitchBufSquareEnergies> computed_output;
{ {
@ -91,6 +99,7 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
computed_output, 3e-2f); computed_output, 3e-2f);
} }
// Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
PitchTestData test_data; PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated; std::array<float, kBufSize12kHz> pitch_buf_decimated;
@ -104,14 +113,13 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()}, FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
pitch_buf_decimated, kMaxPitch12kHz); pitch_buf_decimated, kMaxPitch12kHz);
} }
const std::array<size_t, 2> expected_output = {140, 142}; EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
EXPECT_EQ(expected_output, pitch_candidates_inv_lags); EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
} }
// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
PitchTestData test_data; PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
size_t pitch_inv_lag; size_t pitch_inv_lag;
{ {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -125,10 +133,17 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
class CheckLowerPitchPeriodsAndComputePitchGainTest class CheckLowerPitchPeriodsAndComputePitchGainTest
: public ::testing::Test, : public ::testing::Test,
public ::testing::WithParamInterface< public ::testing::WithParamInterface<std::tuple<
std::tuple<int, int, float, int, float>> {}; /*initial_pitch_period=*/int,
/*prev_pitch_period=*/int,
/*prev_pitch_gain=*/float,
/*expected_pitch_period=*/int,
/*expected_pitch_gain=*/float>> {};
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { // Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
PeriodBitExactnessGainWithinTolerance) {
const auto params = GetParam(); const auto params = GetParam();
const int initial_pitch_period = std::get<0>(params); const int initial_pitch_period = std::get<0>(params);
const int prev_pitch_period = std::get<1>(params); const int prev_pitch_period = std::get<1>(params);
@ -147,48 +162,49 @@ TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
} }
} }
INSTANTIATE_TEST_SUITE_P(RnnVadTest, INSTANTIATE_TEST_SUITE_P(
CheckLowerPitchPeriodsAndComputePitchGainTest, RnnVadTest,
::testing::Values(std::make_tuple(kTestPitchPeriods[0], CheckLowerPitchPeriodsAndComputePitchGainTest,
kTestPitchPeriods[0], ::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
kTestPitchGains[0], kTestPitchPeriodsLow,
91, kTestPitchGainsLow,
-0.0188608f), 91,
std::make_tuple(kTestPitchPeriods[0], -0.0188608f),
kTestPitchPeriods[0], std::make_tuple(kTestPitchPeriodsLow,
kTestPitchGains[1], kTestPitchPeriodsLow,
91, kTestPitchGainsHigh,
-0.0188608f), 91,
std::make_tuple(kTestPitchPeriods[0], -0.0188608f),
kTestPitchPeriods[1], std::make_tuple(kTestPitchPeriodsLow,
kTestPitchGains[0], kTestPitchPeriodsHigh,
91, kTestPitchGainsLow,
-0.0188608f), 91,
std::make_tuple(kTestPitchPeriods[0], -0.0188608f),
kTestPitchPeriods[1], std::make_tuple(kTestPitchPeriodsLow,
kTestPitchGains[1], kTestPitchPeriodsHigh,
91, kTestPitchGainsHigh,
-0.0188608f), 91,
std::make_tuple(kTestPitchPeriods[1], -0.0188608f),
kTestPitchPeriods[0], std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchGains[0], kTestPitchPeriodsLow,
475, kTestPitchGainsLow,
-0.0904344f), 475,
std::make_tuple(kTestPitchPeriods[1], -0.0904344f),
kTestPitchPeriods[0], std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchGains[1], kTestPitchPeriodsLow,
475, kTestPitchGainsHigh,
-0.0904344f), 475,
std::make_tuple(kTestPitchPeriods[1], -0.0904344f),
kTestPitchPeriods[1], std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchGains[0], kTestPitchPeriodsHigh,
475, kTestPitchGainsLow,
-0.0904344f), 475,
std::make_tuple(kTestPitchPeriods[1], -0.0904344f),
kTestPitchPeriods[1], std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchGains[1], kTestPitchPeriodsHigh,
475, kTestPitchGainsHigh,
-0.0904344f))); 475,
-0.0904344f)));
} // namespace test } // namespace test
} // namespace rnn_vad } // namespace rnn_vad

View file

@ -12,7 +12,8 @@
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <array> #include <algorithm>
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -23,11 +24,13 @@ namespace webrtc {
namespace rnn_vad { namespace rnn_vad {
namespace test { namespace test {
// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed. // Checks that the computed pitch period is bit-exact and that the computed
TEST(RnnVadTest, PitchSearchBitExactness) { // pitch gain is within tolerance given test input data.
TEST(RnnVadTest, PitchSearchWithinTolerance) {
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
const size_t num_frames = lp_residual_reader.second; const size_t num_frames = std::min(lp_residual_reader.second,
std::array<float, 864> lp_residual; static_cast<size_t>(300)); // Max 3 s.
std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_gain; float expected_pitch_period, expected_pitch_gain;
PitchEstimator pitch_estimator; PitchEstimator pitch_estimator;
{ {
@ -38,7 +41,8 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
lp_residual_reader.first->ReadChunk(lp_residual); lp_residual_reader.first->ReadChunk(lp_residual);
lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain); lp_residual_reader.first->ReadValue(&expected_pitch_gain);
PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual); PitchInfo pitch_info =
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period); EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
} }

View file

@ -8,9 +8,7 @@
* be found in the AUTHORS file in the root of the source tree. * be found in the AUTHORS file in the root of the source tree.
*/ */
#include <algorithm>
#include <array> #include <array>
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@ -63,7 +61,8 @@ void TestGatedRecurrentLayer(
} // namespace } // namespace
// Bit-exactness check for fully connected layers. // Checks that the output of a fully connected layer is within tolerance given
// test input data.
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
const std::array<int8_t, 1> bias = {-50}; const std::array<int8_t, 1> bias = {-50};
const std::array<int8_t, 24> weights = { const std::array<int8_t, 24> weights = {
@ -106,6 +105,8 @@ TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
} }
} }
// Checks that the output of a GRU layer is within tolerance given test input
// data.
TEST(RnnVadTest, CheckGatedRecurrentLayer) { TEST(RnnVadTest, CheckGatedRecurrentLayer) {
const std::array<int8_t, 12> bias = {96, -99, -81, -114, 49, 119, const std::array<int8_t, 12> bias = {96, -99, -81, -114, 49, 119,
-118, 68, -76, 91, 121, 125}; -118, 68, -76, 91, 121, 125};
@ -139,41 +140,6 @@ TEST(RnnVadTest, CheckGatedRecurrentLayer) {
} }
} }
// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
// Bit-exactness test checking that precomputed frame-wise features lead to the
// expected VAD probabilities.
TEST(RnnVadTest, RnnBitExactness) {
// Init.
auto features_reader = CreateSilenceFlagsFeatureMatrixReader();
auto vad_probs_reader = CreateVadProbsReader();
ASSERT_EQ(features_reader.second, vad_probs_reader.second);
const size_t num_frames = features_reader.second;
// Frame-wise buffers.
float expected_vad_probability;
float is_silence;
std::array<float, kFeatureVectorSize> features;
// Compute VAD probability using the precomputed features.
RnnBasedVad vad;
for (size_t i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i);
// Read frame data.
RTC_CHECK(vad_probs_reader.first->ReadValue(&expected_vad_probability));
// The features file also includes a silence flag for each frame.
RTC_CHECK(features_reader.first->ReadValue(&is_silence));
RTC_CHECK(features_reader.first->ReadChunk(features));
// Compute and check VAD probability.
float vad_probability = vad.ComputeVadProbability(features, is_silence);
ASSERT_TRUE(is_silence == 0.f || is_silence == 1.f);
if (is_silence == 1.f) {
ASSERT_EQ(0.f, expected_vad_probability);
EXPECT_EQ(0.f, vad_probability);
} else {
EXPECT_NEAR(expected_vad_probability, vad_probability, 3e-6f);
}
}
}
} // namespace test } // namespace test
} // namespace rnn_vad } // namespace rnn_vad
} // namespace webrtc } // namespace webrtc

View file

@ -9,6 +9,7 @@
*/ */
#include <array> #include <array>
#include <string>
#include <vector> #include <vector>
#include "common_audio/resampler/push_sinc_resampler.h" #include "common_audio/resampler/push_sinc_resampler.h"
@ -43,8 +44,68 @@ void DumpPerfStats(size_t num_samples,
RTC_LOG(LS_INFO) << "speed: " << speed << "x"; RTC_LOG(LS_INFO) << "speed: " << speed << "x";
} }
// When the RNN VAD model is updated and the expected output changes, set the
// constant below to true in order to write new expected output binary files.
constexpr bool kWriteComputedOutputToFile = false;
} // namespace } // namespace
// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
// when the expected output files are re-exported.
TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
ASSERT_FALSE(kWriteComputedOutputToFile)
<< "Cannot land if kWriteComputedOutput is true.";
}
// Checks that the computed VAD probability for a test input sequence sampled at
// 48 kHz is within tolerance.
TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
// Init resampler, feature extractor and RNN.
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
FeaturesExtractor features_extractor;
RnnBasedVad rnn_vad;
// Init input samples and expected output readers.
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
auto expected_vad_prob_reader = CreateVadProbsReader();
// Input length.
const size_t num_frames = samples_reader.second;
ASSERT_GE(expected_vad_prob_reader.second, num_frames);
// Init buffers.
std::vector<float> samples_48k(kFrameSize10ms48kHz);
std::vector<float> samples_24k(kFrameSize10ms24kHz);
std::vector<float> feature_vector(kFeatureVectorSize);
std::vector<float> computed_vad_prob(num_frames);
std::vector<float> expected_vad_prob(num_frames);
// Read expected output.
ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
// Compute VAD probabilities on the downsampled input.
float cumulative_error = 0.f;
for (size_t i = 0; i < num_frames; ++i) {
samples_reader.first->ReadChunk(samples_48k);
decimator.Resample(samples_48k.data(), samples_48k.size(),
samples_24k.data(), samples_24k.size());
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
{samples_24k.data(), kFrameSize10ms24kHz},
{feature_vector.data(), kFeatureVectorSize});
computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
{feature_vector.data(), kFeatureVectorSize}, is_silence);
EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
}
// Check average error.
EXPECT_LT(cumulative_error / num_frames, 1e-4f);
if (kWriteComputedOutputToFile) {
BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
vad_prob_writer.WriteChunk(computed_vad_prob);
}
}
// Performance test for the RNN VAD (pre-fetching and downsampling are // Performance test for the RNN VAD (pre-fetching and downsampling are
// excluded). Keep disabled and only enable locally to measure performance as // excluded). Keep disabled and only enable locally to measure performance as
// follows: // follows:

View file

@ -85,14 +85,18 @@ TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) {
} }
} }
// Checks that the computed band-wise auto-correlation is non-negative for a
// simple input vector of FFT coefficients.
TEST(RnnVadTest, SpectralCorrelatorValidOutput) { TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
SpectralCorrelator e; // Input: vector of (1, 1j) values.
Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal); Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal);
auto in = fft.CreateBuffer(); auto in = fft.CreateBuffer();
std::array<float, kOpusBands24kHz> out; std::array<float, kOpusBands24kHz> out;
auto in_view = in->GetView(); auto in_view = in->GetView();
std::fill(in_view.begin(), in_view.end(), 1.f); std::fill(in_view.begin(), in_view.end(), 1.f);
in_view[1] = 0.f; // Nyquist frequency. in_view[1] = 0.f; // Nyquist frequency.
// Compute and check output.
SpectralCorrelator e;
e.ComputeAutoCorrelation(in_view, out); e.ComputeAutoCorrelation(in_view, out);
for (size_t i = 0; i < kOpusBands24kHz; ++i) { for (size_t i = 0; i < kOpusBands24kHz; ++i) {
SCOPED_TRACE(i); SCOPED_TRACE(i);
@ -100,6 +104,8 @@ TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
} }
} }
// Checks that the computed smoothed log magnitude spectrum is within tolerance
// given hard-coded test input data.
TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) { TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) {
constexpr std::array<float, kNumBands> input = { constexpr std::array<float, kNumBands> input = {
{86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f, {86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f,
@ -124,7 +130,9 @@ TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) {
} }
} }
TEST(RnnVadTest, ComputeDctBitExactness) { // Checks that the computed DCT is within tolerance given hard-coded test input
// data.
TEST(RnnVadTest, ComputeDctWithinTolerance) {
constexpr std::array<float, kNumBands> input = { constexpr std::array<float, kNumBands> input = {
{0.232155621052f, 0.678957760334f, 0.220818966627f, -0.077363930643f, {0.232155621052f, 0.678957760334f, 0.220818966627f, -0.077363930643f,
-0.559227049351f, 0.432545185089f, 0.353900641203f, 0.398993015289f, -0.559227049351f, 0.432545185089f, 0.353900641203f, 0.398993015289f,

View file

@ -67,6 +67,8 @@ constexpr float kInitialFeatureVal = -9999.f;
} // namespace } // namespace
// Checks that silence is detected when the input signal is 0 and that the
// feature vector is written only if the input signal is not tagged as silence.
TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) { TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
// Initialize. // Initialize.
SpectralFeaturesExtractor sfe; SpectralFeaturesExtractor sfe;
@ -108,9 +110,10 @@ TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
[](float x) { return x == kInitialFeatureVal; })); [](float x) { return x == kInitialFeatureVal; }));
} }
// When the input signal does not change, the cepstral coefficients average does // Feeds a constant input signal and checks that:
// not change and the derivatives are zero. Similarly, the cepstral variability // - the cepstral coefficients average does not change;
// score does not change either. // - the derivatives are zero;
// - the cepstral variability score does not change.
TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) { TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
// Initialize. // Initialize.
SpectralFeaturesExtractor sfe; SpectralFeaturesExtractor sfe;

View file

@ -46,13 +46,6 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
} }
} }
std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader() {
constexpr size_t cols = 1396;
return absl::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
cols);
}
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t> std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length) { CreatePcmSamplesReader(const size_t frame_length) {
auto ptr = absl::make_unique<BinaryFileReader<int16_t, float>>( auto ptr = absl::make_unique<BinaryFileReader<int16_t, float>>(
@ -78,25 +71,6 @@ ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)}; rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
} }
ReaderPairType CreateFftCoeffsReader() {
constexpr size_t num_fft_points = 481;
constexpr size_t row_size = 2 * num_fft_points; // Real and imaginary values.
auto ptr = absl::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/fft", "dat"),
num_fft_points);
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), row_size)};
}
ReaderPairType CreateSilenceFlagsFeatureMatrixReader() {
constexpr size_t feature_vector_size = 42;
auto ptr = absl::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/sil_features", "dat"),
feature_vector_size);
// Features and silence flag.
return {std::move(ptr),
rtc::CheckedDivExact(ptr->data_length(), feature_vector_size + 1)};
}
ReaderPairType CreateVadProbsReader() { ReaderPairType CreateVadProbsReader() {
auto ptr = absl::make_unique<BinaryFileReader<float>>( auto ptr = absl::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat")); test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
@ -104,23 +78,26 @@ ReaderPairType CreateVadProbsReader() {
} }
PitchTestData::PitchTestData() { PitchTestData::PitchTestData() {
auto test_data_reader = CreatePitchSearchTestDataReader(); BinaryFileReader<float> test_data_reader(
test_data_reader->ReadChunk(test_data_); ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
static_cast<size_t>(1396));
test_data_reader.ReadChunk(test_data_);
} }
PitchTestData::~PitchTestData() = default; PitchTestData::~PitchTestData() = default;
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView() { rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
const {
return {test_data_.data(), kBufSize24kHz}; return {test_data_.data(), kBufSize24kHz};
} }
rtc::ArrayView<const float, kNumPitchBufSquareEnergies> rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
PitchTestData::GetPitchBufSquareEnergiesView() { PitchTestData::GetPitchBufSquareEnergiesView() const {
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
} }
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs> rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
PitchTestData::GetPitchBufAutoCorrCoeffsView() { PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
kNumPitchBufAutoCorrCoeffs}; kNumPitchBufAutoCorrCoeffs};
} }

View file

@ -17,6 +17,7 @@
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -46,11 +47,10 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
template <typename T, typename D = T> template <typename T, typename D = T>
class BinaryFileReader { class BinaryFileReader {
public: public:
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 1) explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
: is_(file_path, std::ios::binary | std::ios::ate), : is_(file_path, std::ios::binary | std::ios::ate),
data_length_(is_.tellg() / sizeof(T)), data_length_(is_.tellg() / sizeof(T)),
chunk_size_(chunk_size) { chunk_size_(chunk_size) {
RTC_CHECK_LT(0, chunk_size_);
RTC_CHECK(is_); RTC_CHECK(is_);
SeekBeginning(); SeekBeginning();
buf_.resize(chunk_size_); buf_.resize(chunk_size_);
@ -69,9 +69,11 @@ class BinaryFileReader {
} }
return is_.gcount() == sizeof(T); return is_.gcount() == sizeof(T);
} }
// If |chunk_size| was specified in the ctor, it will check that the size of
// |dst| equals |chunk_size|.
bool ReadChunk(rtc::ArrayView<D> dst) { bool ReadChunk(rtc::ArrayView<D> dst) {
RTC_DCHECK_EQ(chunk_size_, dst.size()); RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
const std::streamsize bytes_to_read = chunk_size_ * sizeof(T); const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, D>::value) { if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read); is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else { } else {
@ -91,9 +93,26 @@ class BinaryFileReader {
std::vector<T> buf_; std::vector<T> buf_;
}; };
// Writer for binary files.
template <typename T>
class BinaryFileWriter {
public:
explicit BinaryFileWriter(const std::string& file_path)
: os_(file_path, std::ios::binary) {}
BinaryFileWriter(const BinaryFileWriter&) = delete;
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
~BinaryFileWriter() = default;
static_assert(std::is_arithmetic<T>::value, "");
void WriteChunk(rtc::ArrayView<const T> value) {
const std::streamsize bytes_to_write = value.size() * sizeof(T);
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
}
private:
std::ofstream os_;
};
// Factories for resource file readers. // Factories for resource file readers.
// Creates a reader for the pitch search test data.
std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader();
// The functions below return a pair where the first item is a reader unique // The functions below return a pair where the first item is a reader unique
// pointer and the second the number of chunks that can be read from the file. // pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads // Creates a reader for the PCM samples that casts from S16 to float and reads
@ -107,12 +126,6 @@ CreatePitchBuffer24kHzReader();
// and gain values. // and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t> std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateLpResidualAndPitchPeriodGainReader(); CreateLpResidualAndPitchPeriodGainReader();
// Creates a reader for the FFT coefficients.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateFftCoeffsReader();
// Creates a reader for the silence flags and the feature matrix.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateSilenceFlagsFeatureMatrixReader();
// Creates a reader for the VAD probabilities. // Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t> std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader(); CreateVadProbsReader();
@ -128,11 +141,11 @@ class PitchTestData {
public: public:
PitchTestData(); PitchTestData();
~PitchTestData(); ~PitchTestData();
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView(); rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
rtc::ArrayView<const float, kNumPitchBufSquareEnergies> rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
GetPitchBufSquareEnergiesView(); GetPitchBufSquareEnergiesView() const;
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs> rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
GetPitchBufAutoCorrCoeffsView(); GetPitchBufAutoCorrCoeffsView() const;
private: private:
std::array<float, kPitchTestDataSize> test_data_; std::array<float, kPitchTestDataSize> test_data_;

View file

@ -1 +0,0 @@
e62364d35abd123663bfc800fa233071d6d7fffd

View file

@ -1 +0,0 @@
e0a92782c2903be9da10385d924d34e8bf212d5e

View file

@ -1 +1 @@
05735ede0b457318e307d12f5acfd11bbbbd0afd 68640327266262c3fe047ec7f07a46a355ff90b9