mirror of
https://github.com/mollyim/webrtc.git
synced 2025-05-13 13:50:40 +01:00

This CL adds a component in the TS implementation to return a delayed version of the voice probability values observed when `Suppress()` is called. That is needed in order to temporally align the voice probability values to the processed audio since TS adds algorithmic delay. Bug: webrtc:13663 Change-Id: I5041ace3939d2ce7ba084ae703428e66f1aa06be Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/255860 Reviewed-by: Hanna Silen <silen@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/main@{#36496}
175 lines
6.4 KiB
C++
175 lines
6.4 KiB
C++
/*
|
|
* Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
|
|
*
|
|
* Use of this source code is governed by a BSD-style license
|
|
* that can be found in the LICENSE file in the root of the source
|
|
* tree. An additional intellectual property rights grant can be found
|
|
* in the file PATENTS. All contributing project authors may
|
|
* be found in the AUTHORS file in the root of the source tree.
|
|
*/
|
|
|
|
#include "modules/audio_processing/transient/transient_suppressor.h"

#include <algorithm>
#include <vector>

#include "absl/types/optional.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
|
|
|
|
namespace webrtc {
|
|
namespace {
|
|
constexpr int kMono = 1;
|
|
|
|
// Returns the index of the first non-zero sample in `samples` or an unspecified
|
|
// value if no value is zero.
|
|
absl::optional<int> FindFirstNonZeroSample(const std::vector<float>& samples) {
|
|
for (size_t i = 0; i < samples.size(); ++i) {
|
|
if (samples[i] != 0.0f) {
|
|
return i;
|
|
}
|
|
}
|
|
return absl::nullopt;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
class TransientSuppressorVadModeParametrization
|
|
: public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
|
|
|
|
// Exercises the key-press driven logic of `TransientSuppressorImpl`:
// - a key-press enables detection;
// - two consecutive key-presses enable suppression;
// - four seconds without key-presses disable both.
TEST_P(TransientSuppressorVadModeParametrization,
       TypingDetectionLogicWorksAsExpectedForMono) {
  // Time bounds (ms) up to which chunks of `ts::kChunkSizeMs` are fed.
  constexpr int kAlmostFourSecondsMs = 3990;
  constexpr int kAlmostOneSecondMs = 990;
  constexpr int kOneSecondMs = 1000;
  constexpr int kNumRepetitions = 100;

  TransientSuppressorImpl suppressor(GetParam(), ts::kSampleRate16kHz,
                                     ts::kSampleRate16kHz, kMono);

  // Detection starts disabled, and each key-press enables it.
  EXPECT_FALSE(suppressor.detection_enabled_);
  suppressor.UpdateKeypress(true);
  EXPECT_TRUE(suppressor.detection_enabled_);

  // Detection is disabled only after four seconds without any key-press.
  for (int elapsed_ms = 0; elapsed_ms < kAlmostFourSecondsMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_TRUE(suppressor.detection_enabled_);
  }
  suppressor.UpdateKeypress(false);
  EXPECT_FALSE(suppressor.detection_enabled_);

  // Key-presses that are more than a second apart from each other do not
  // enable suppression.
  for (int rep = 0; rep < kNumRepetitions; ++rep) {
    EXPECT_FALSE(suppressor.suppression_enabled_);
    suppressor.UpdateKeypress(true);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_FALSE(suppressor.suppression_enabled_);
    for (int elapsed_ms = 0; elapsed_ms < kAlmostOneSecondMs;
         elapsed_ms += ts::kChunkSizeMs) {
      suppressor.UpdateKeypress(false);
      EXPECT_TRUE(suppressor.detection_enabled_);
      EXPECT_FALSE(suppressor.suppression_enabled_);
    }
    suppressor.UpdateKeypress(false);
  }

  // Two consecutive key-presses are enough to enable suppression.
  suppressor.UpdateKeypress(true);
  EXPECT_FALSE(suppressor.suppression_enabled_);
  suppressor.UpdateKeypress(true);
  EXPECT_TRUE(suppressor.suppression_enabled_);

  // Key-presses that are less than a second apart from each other disable
  // neither detection nor suppression.
  for (int rep = 0; rep < kNumRepetitions; ++rep) {
    for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
         elapsed_ms += ts::kChunkSizeMs) {
      suppressor.UpdateKeypress(false);
      EXPECT_TRUE(suppressor.detection_enabled_);
      EXPECT_TRUE(suppressor.suppression_enabled_);
    }
    suppressor.UpdateKeypress(true);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_TRUE(suppressor.suppression_enabled_);
  }

  // Four seconds without any key-press disable detection and suppression.
  for (int elapsed_ms = 0; elapsed_ms < kAlmostFourSecondsMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_TRUE(suppressor.detection_enabled_);
    EXPECT_TRUE(suppressor.suppression_enabled_);
  }
  for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
       elapsed_ms += ts::kChunkSizeMs) {
    suppressor.UpdateKeypress(false);
    EXPECT_FALSE(suppressor.detection_enabled_);
    EXPECT_FALSE(suppressor.suppression_enabled_);
  }
}
|
|
|
|
// Runs every VAD-mode parametrized test once for each listed VAD mode.
INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
                         TransientSuppressorVadModeParametrization,
                         ::testing::Values(
                             TransientSuppressor::VadMode::kDefault,
                             TransientSuppressor::VadMode::kRnnVad,
                             TransientSuppressor::VadMode::kNoVad));
|
|
|
|
class TransientSuppressorSampleRateParametrization
|
|
: public ::testing::TestWithParam<int> {};
|
|
|
|
// Checks that voice probability and processed audio data are temporally aligned
|
|
// after `Suppress()` is called.
|
|
TEST_P(TransientSuppressorSampleRateParametrization,
|
|
CheckAudioAndVoiceProbabilityTemporallyAligned) {
|
|
const int sample_rate_hz = GetParam();
|
|
TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault,
|
|
sample_rate_hz,
|
|
/*detection_rate_hz=*/sample_rate_hz, kMono);
|
|
|
|
const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000;
|
|
std::vector<float> frame(frame_size);
|
|
|
|
constexpr int kMaxAttempts = 3;
|
|
for (int i = 0; i < kMaxAttempts; ++i) {
|
|
SCOPED_TRACE(i);
|
|
|
|
// Call `Suppress()` on frames of non-zero audio samples.
|
|
std::fill(frame.begin(), frame.end(), 1000.0f);
|
|
float delayed_voice_probability = ts.Suppress(
|
|
frame.data(), frame.size(), kMono, /*detection_data=*/nullptr,
|
|
/*detection_length=*/frame_size, /*reference_data=*/nullptr,
|
|
/*reference_length=*/frame_size, /*voice_probability=*/1.0f,
|
|
/*key_pressed=*/false);
|
|
|
|
// Detect the algorithmic delay of `TransientSuppressorImpl`.
|
|
absl::optional<int> frame_delay = FindFirstNonZeroSample(frame);
|
|
|
|
// Check that the delayed voice probability is delayed according to the
|
|
// measured delay.
|
|
if (frame_delay.has_value()) {
|
|
if (*frame_delay == 0) {
|
|
// When the delay is a multiple integer of the frame duration,
|
|
// `Suppress()` returns a copy of a previously observed voice
|
|
// probability value.
|
|
EXPECT_EQ(delayed_voice_probability, 1.0f);
|
|
} else {
|
|
// Instead, when the delay is fractional, `Suppress()` returns an
|
|
// interpolated value. Since the exact value depends on the
|
|
// interpolation method, we only check that the delayed voice
|
|
// probability is not zero as it must converge towards the previoulsy
|
|
// observed value.
|
|
EXPECT_GT(delayed_voice_probability, 0.0f);
|
|
}
|
|
break;
|
|
} else {
|
|
// The algorithmic delay is longer than the duration of a single frame.
|
|
// Until the delay is detected, the delayed voice probability is zero.
|
|
EXPECT_EQ(delayed_voice_probability, 0.0f);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Runs every sample-rate parametrized test at 8, 16, 32 and 48 kHz.
INSTANTIATE_TEST_SUITE_P(
    TransientSuppressorImplTest,
    TransientSuppressorSampleRateParametrization,
    ::testing::Values(ts::kSampleRate8kHz,
                      ts::kSampleRate16kHz,
                      ts::kSampleRate32kHz,
                      ts::kSampleRate48kHz));
|
|
|
|
} // namespace webrtc
|