APM: fix TS initialization bugs with WebRTC-Audio-GainController2

When the `WebRTC-Audio-GainController2` field trial is used, the
initial APM configuration is adjusted depending on its original
values and the field trial parameters.

This CL fixes two cases when the code crashes:
1. when, in the initial APM config, AGC1 is enabled, AGC2 is
   disabled and TS is enabled
2. when the initial APM sample rate is different from the
   capture one and the VAD APM sub-module is not re-initialized

This CL also improves the unit tests coverage and it has been
tested offline to check that the VAD sub-module is created only
when expected and that AGC2 uses its internal VAD when expected.
The tests ran on a few Wav files with different sample rates and
one AEC dump and on 16 different APM and field trial
configurations.

Bug: chromium:1407341, b/265112132
Change-Id: I7cc267ea81cb02be92c1f37f273b7ae93b6e4634
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/290988
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Olga Sharonova <olka@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39118}
This commit is contained in:
Alessio Bazzica 2023-01-16 20:19:48 +01:00 committed by WebRTC LUCI CQ
parent f7a46e55cb
commit 40b5bd72d0
5 changed files with 125 additions and 131 deletions

View file

@ -506,13 +506,16 @@ AudioProcessing::Config AudioProcessingImpl::AdjustConfig(
!config.gain_controller2.enabled; !config.gain_controller2.enabled;
const bool one_and_only_one_input_volume_controller = const bool one_and_only_one_input_volume_controller =
hybrid_agc_config_detected != full_agc1_config_detected; hybrid_agc_config_detected != full_agc1_config_detected;
const bool agc2_input_volume_controller_enabled =
config.gain_controller2.enabled &&
config.gain_controller2.input_volume_controller.enabled;
if (!one_and_only_one_input_volume_controller || if (!one_and_only_one_input_volume_controller ||
config.gain_controller2.input_volume_controller.enabled) { agc2_input_volume_controller_enabled) {
RTC_LOG(LS_ERROR) << "Cannot adjust AGC config (precondition failed)"; RTC_LOG(LS_ERROR) << "Cannot adjust AGC config (precondition failed)";
if (!one_and_only_one_input_volume_controller) if (!one_and_only_one_input_volume_controller)
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "One and only one input volume controller must be enabled."; << "One and only one input volume controller must be enabled.";
if (config.gain_controller2.input_volume_controller.enabled) if (agc2_input_volume_controller_enabled)
RTC_LOG(LS_ERROR) RTC_LOG(LS_ERROR)
<< "The AGC2 input volume controller must be disabled."; << "The AGC2 input volume controller must be disabled.";
} else { } else {
@ -530,19 +533,18 @@ AudioProcessing::Config AudioProcessingImpl::AdjustConfig(
return adjusted_config; return adjusted_config;
} }
TransientSuppressor::VadMode AudioProcessingImpl::GetTransientSuppressorVadMode( bool AudioProcessingImpl::UseApmVadSubModule(
const absl::optional<AudioProcessingImpl::GainController2ExperimentParams>& const AudioProcessing::Config& config,
params) { const absl::optional<GainController2ExperimentParams>& experiment_params) {
if (params.has_value() && params->agc2_config.has_value() && // The VAD as an APM sub-module is needed only in one case, that is when TS
!params->disallow_transient_suppressor_usage) { // and AGC2 are both enabled and when the AGC2 experiment is running and its
// When the experiment is active, the gain control switches to AGC2 and TS // parameters require to fully switch the gain control to AGC2.
// can be active, use the RNN VAD to control TS. This choice will also return config.transient_suppression.enabled &&
// disable the internal RNN VAD in AGC2. config.gain_controller2.enabled &&
return TransientSuppressor::VadMode::kRnnVad; (config.gain_controller2.input_volume_controller.enabled ||
} config.gain_controller2.adaptive_digital.enabled) &&
// If TS is disabled, the returned value does not matter. If enabled, use the experiment_params.has_value() &&
// default VAD. experiment_params->agc2_config.has_value();
return TransientSuppressor::VadMode::kDefault;
} }
AudioProcessingImpl::SubmoduleStates::SubmoduleStates( AudioProcessingImpl::SubmoduleStates::SubmoduleStates(
@ -663,8 +665,7 @@ AudioProcessingImpl::AudioProcessingImpl(
use_setup_specific_default_aec3_config_( use_setup_specific_default_aec3_config_(
UseSetupSpecificDefaultAec3Congfig()), UseSetupSpecificDefaultAec3Congfig()),
gain_controller2_experiment_params_(GetGainController2ExperimentParams()), gain_controller2_experiment_params_(GetGainController2ExperimentParams()),
transient_suppressor_vad_mode_( transient_suppressor_vad_mode_(TransientSuppressor::VadMode::kDefault),
GetTransientSuppressorVadMode(gain_controller2_experiment_params_)),
capture_runtime_settings_(RuntimeSettingQueueSize()), capture_runtime_settings_(RuntimeSettingQueueSize()),
render_runtime_settings_(RuntimeSettingQueueSize()), render_runtime_settings_(RuntimeSettingQueueSize()),
capture_runtime_settings_enqueuer_(&capture_runtime_settings_), capture_runtime_settings_enqueuer_(&capture_runtime_settings_),
@ -809,8 +810,8 @@ void AudioProcessingImpl::InitializeLocked() {
InitializeHighPassFilter(true); InitializeHighPassFilter(true);
InitializeResidualEchoDetector(); InitializeResidualEchoDetector();
InitializeEchoController(); InitializeEchoController();
InitializeGainController2(/*config_has_changed=*/true); InitializeGainController2();
InitializeVoiceActivityDetector(/*config_has_changed=*/true); InitializeVoiceActivityDetector();
InitializeNoiseSuppressor(); InitializeNoiseSuppressor();
InitializeAnalyzer(); InitializeAnalyzer();
InitializePostProcessor(); InitializePostProcessor();
@ -977,8 +978,12 @@ void AudioProcessingImpl::ApplyConfig(const AudioProcessing::Config& config) {
config_.gain_controller2 = AudioProcessing::Config::GainController2(); config_.gain_controller2 = AudioProcessing::Config::GainController2();
} }
InitializeGainController2(agc2_config_changed); if (agc2_config_changed || ts_config_changed) {
InitializeVoiceActivityDetector(agc2_config_changed); // AGC2 also depends on TS because of the possible dependency on the APM VAD
// sub-module.
InitializeGainController2();
InitializeVoiceActivityDetector();
}
if (pre_amplifier_config_changed || gain_adjustment_config_changed) { if (pre_amplifier_config_changed || gain_adjustment_config_changed) {
InitializeCaptureLevelsAdjuster(); InitializeCaptureLevelsAdjuster();
@ -2144,10 +2149,20 @@ bool AudioProcessingImpl::UpdateActiveSubmoduleStates() {
} }
void AudioProcessingImpl::InitializeTransientSuppressor() { void AudioProcessingImpl::InitializeTransientSuppressor() {
// Choose the VAD mode for TS and detect a VAD mode change.
const TransientSuppressor::VadMode previous_vad_mode =
transient_suppressor_vad_mode_;
transient_suppressor_vad_mode_ = TransientSuppressor::VadMode::kDefault;
if (UseApmVadSubModule(config_, gain_controller2_experiment_params_)) {
transient_suppressor_vad_mode_ = TransientSuppressor::VadMode::kRnnVad;
}
const bool vad_mode_changed =
previous_vad_mode != transient_suppressor_vad_mode_;
if (config_.transient_suppression.enabled && if (config_.transient_suppression.enabled &&
!constants_.transient_suppressor_forced_off) { !constants_.transient_suppressor_forced_off) {
// Attempt to create a transient suppressor, if one is not already created. // Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) { if (!submodules_.transient_suppressor || vad_mode_changed) {
submodules_.transient_suppressor = CreateTransientSuppressor( submodules_.transient_suppressor = CreateTransientSuppressor(
submodule_creation_overrides_, transient_suppressor_vad_mode_, submodule_creation_overrides_, transient_suppressor_vad_mode_,
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate, proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
@ -2341,54 +2356,48 @@ void AudioProcessingImpl::InitializeGainController1() {
capture_.capture_output_used); capture_.capture_output_used);
} }
void AudioProcessingImpl::InitializeGainController2(bool config_has_changed) { void AudioProcessingImpl::InitializeGainController2() {
if (!config_has_changed) {
return;
}
if (!config_.gain_controller2.enabled) { if (!config_.gain_controller2.enabled) {
submodules_.gain_controller2.reset(); submodules_.gain_controller2.reset();
return; return;
} }
if (!submodules_.gain_controller2 || config_has_changed) { // Override the input volume controller configuration if the AGC2 experiment
const bool use_internal_vad = // is running and its parameters require to fully switch the gain control to
transient_suppressor_vad_mode_ != TransientSuppressor::VadMode::kRnnVad; // AGC2.
const bool input_volume_controller_config_overridden = const bool input_volume_controller_config_overridden =
gain_controller2_experiment_params_.has_value() && gain_controller2_experiment_params_.has_value() &&
gain_controller2_experiment_params_->agc2_config.has_value(); gain_controller2_experiment_params_->agc2_config.has_value();
const InputVolumeController::Config input_volume_controller_config = const InputVolumeController::Config input_volume_controller_config =
input_volume_controller_config_overridden input_volume_controller_config_overridden
? gain_controller2_experiment_params_->agc2_config ? gain_controller2_experiment_params_->agc2_config
->input_volume_controller ->input_volume_controller
: InputVolumeController::Config{}; : InputVolumeController::Config{};
submodules_.gain_controller2 = std::make_unique<GainController2>( // If the APM VAD sub-module is not used, let AGC2 use its internal VAD.
config_.gain_controller2, input_volume_controller_config, const bool use_internal_vad =
proc_fullband_sample_rate_hz(), num_proc_channels(), use_internal_vad); !UseApmVadSubModule(config_, gain_controller2_experiment_params_);
submodules_.gain_controller2->SetCaptureOutputUsed( submodules_.gain_controller2 = std::make_unique<GainController2>(
capture_.capture_output_used); config_.gain_controller2, input_volume_controller_config,
} proc_fullband_sample_rate_hz(), num_proc_channels(), use_internal_vad);
submodules_.gain_controller2->SetCaptureOutputUsed(
capture_.capture_output_used);
} }
void AudioProcessingImpl::InitializeVoiceActivityDetector( void AudioProcessingImpl::InitializeVoiceActivityDetector() {
bool config_has_changed) { if (!UseApmVadSubModule(config_, gain_controller2_experiment_params_)) {
if (!config_has_changed) {
return;
}
const bool use_vad =
transient_suppressor_vad_mode_ == TransientSuppressor::VadMode::kRnnVad &&
config_.gain_controller2.enabled &&
(config_.gain_controller2.adaptive_digital.enabled ||
config_.gain_controller2.input_volume_controller.enabled);
if (!use_vad) {
submodules_.voice_activity_detector.reset(); submodules_.voice_activity_detector.reset();
return; return;
} }
if (!submodules_.voice_activity_detector || config_has_changed) {
if (!submodules_.voice_activity_detector) {
RTC_DCHECK(!!submodules_.gain_controller2); RTC_DCHECK(!!submodules_.gain_controller2);
// TODO(bugs.webrtc.org/13663): Cache CPU features in APM and use here. // TODO(bugs.webrtc.org/13663): Cache CPU features in APM and use here.
submodules_.voice_activity_detector = submodules_.voice_activity_detector =
std::make_unique<VoiceActivityDetectorWrapper>( std::make_unique<VoiceActivityDetectorWrapper>(
submodules_.gain_controller2->GetCpuFeatures(), submodules_.gain_controller2->GetCpuFeatures(),
proc_fullband_sample_rate_hz()); proc_fullband_sample_rate_hz());
} else {
submodules_.voice_activity_detector->Initialize(
proc_fullband_sample_rate_hz());
} }
} }

View file

@ -227,10 +227,12 @@ class AudioProcessingImpl : public AudioProcessing {
static AudioProcessing::Config AdjustConfig( static AudioProcessing::Config AdjustConfig(
const AudioProcessing::Config& config, const AudioProcessing::Config& config,
const absl::optional<GainController2ExperimentParams>& experiment_params); const absl::optional<GainController2ExperimentParams>& experiment_params);
static TransientSuppressor::VadMode GetTransientSuppressorVadMode( // Returns true if the APM VAD sub-module should be used.
static bool UseApmVadSubModule(
const AudioProcessing::Config& config,
const absl::optional<GainController2ExperimentParams>& experiment_params); const absl::optional<GainController2ExperimentParams>& experiment_params);
const TransientSuppressor::VadMode transient_suppressor_vad_mode_; TransientSuppressor::VadMode transient_suppressor_vad_mode_;
SwapQueue<RuntimeSetting> capture_runtime_settings_; SwapQueue<RuntimeSetting> capture_runtime_settings_;
SwapQueue<RuntimeSetting> render_runtime_settings_; SwapQueue<RuntimeSetting> render_runtime_settings_;
@ -317,14 +319,15 @@ class AudioProcessingImpl : public AudioProcessing {
void InitializeGainController1() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializeGainController1() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
void InitializeTransientSuppressor() void InitializeTransientSuppressor()
RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
// Initializes the `GainController2` sub-module. If the sub-module is enabled // Initializes the `GainController2` sub-module. If the sub-module is enabled,
// and `config_has_changed` is true, recreates the sub-module. // recreates it.
void InitializeGainController2(bool config_has_changed) void InitializeGainController2() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
// Initializes the `VoiceActivityDetectorWrapper` sub-module. If the // Initializes the `VoiceActivityDetectorWrapper` sub-module. If the
// sub-module is enabled and `config_has_changed` is true, recreates the // sub-module is enabled, recreates it. Call `InitializeGainController2()`
// sub-module. // first.
void InitializeVoiceActivityDetector(bool config_has_changed) // TODO(bugs.webrtc.org/13663): Remove if TS is removed otherwise remove call
// order requirement - i.e., decouple from `InitializeGainController2()`.
void InitializeVoiceActivityDetector()
RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
void InitializeNoiseSuppressor() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_); void InitializeNoiseSuppressor() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
void InitializeCaptureLevelsAdjuster() void InitializeCaptureLevelsAdjuster()

View file

@ -1287,7 +1287,10 @@ TEST_P(Agc2FieldTrialParametrizedTest,
TEST_P(Agc2FieldTrialParametrizedTest, ProcessSucceedsWithTs) { TEST_P(Agc2FieldTrialParametrizedTest, ProcessSucceedsWithTs) {
AudioProcessing::Config config = GetParam(); AudioProcessing::Config config = GetParam();
config.transient_suppression.enabled = true; if (!config.transient_suppression.enabled) {
GTEST_SKIP() << "TS is disabled, skip.";
}
webrtc::test::ScopedFieldTrials field_trials( webrtc::test::ScopedFieldTrials field_trials(
"WebRTC-Audio-GainController2/Disabled/"); "WebRTC-Audio-GainController2/Disabled/");
auto apm = AudioProcessingBuilder().SetConfig(config).Create(); auto apm = AudioProcessingBuilder().SetConfig(config).Create();
@ -1340,7 +1343,10 @@ TEST_P(Agc2FieldTrialParametrizedTest, ProcessSucceedsWithoutTs) {
TEST_P(Agc2FieldTrialParametrizedTest, TEST_P(Agc2FieldTrialParametrizedTest,
ProcessSucceedsWhenSwitchToFullAgc2WithTs) { ProcessSucceedsWhenSwitchToFullAgc2WithTs) {
AudioProcessing::Config config = GetParam(); AudioProcessing::Config config = GetParam();
config.transient_suppression.enabled = true; if (!config.transient_suppression.enabled) {
GTEST_SKIP() << "TS is disabled, skip.";
}
webrtc::test::ScopedFieldTrials field_trials( webrtc::test::ScopedFieldTrials field_trials(
"WebRTC-Audio-GainController2/Enabled," "WebRTC-Audio-GainController2/Enabled,"
"switch_to_agc2:true," "switch_to_agc2:true,"
@ -1397,15 +1403,34 @@ INSTANTIATE_TEST_SUITE_P(
AudioProcessingImplTest, AudioProcessingImplTest,
Agc2FieldTrialParametrizedTest, Agc2FieldTrialParametrizedTest,
::testing::Values( ::testing::Values(
// Full AGC1. // Full AGC1, TS disabled.
AudioProcessing::Config{ AudioProcessing::Config{
.transient_suppression = {.enabled = false},
.gain_controller1 = .gain_controller1 =
{.enabled = true, {.enabled = true,
.analog_gain_controller = {.enabled = true, .analog_gain_controller = {.enabled = true,
.enable_digital_adaptive = true}}, .enable_digital_adaptive = true}},
.gain_controller2 = {.enabled = false}}, .gain_controller2 = {.enabled = false}},
// Hybrid AGC. // Full AGC1, TS enabled.
AudioProcessing::Config{ AudioProcessing::Config{
.transient_suppression = {.enabled = true},
.gain_controller1 =
{.enabled = true,
.analog_gain_controller = {.enabled = true,
.enable_digital_adaptive = true}},
.gain_controller2 = {.enabled = false}},
// Hybrid AGC, TS disabled.
AudioProcessing::Config{
.transient_suppression = {.enabled = false},
.gain_controller1 =
{.enabled = true,
.analog_gain_controller = {.enabled = true,
.enable_digital_adaptive = false}},
.gain_controller2 = {.enabled = true,
.adaptive_digital = {.enabled = true}}},
// Hybrid AGC, TS enabled.
AudioProcessing::Config{
.transient_suppression = {.enabled = true},
.gain_controller1 = .gain_controller1 =
{.enabled = true, {.enabled = true,
.analog_gain_controller = {.enabled = true, .analog_gain_controller = {.enabled = true,

View file

@ -183,8 +183,13 @@ void GainController2::Process(absl::optional<float> speech_probability,
audio->num_frames()); audio->num_frames());
// Compute speech probability. // Compute speech probability.
if (vad_) { if (vad_) {
// When the VAD component runs, `speech_probability` should not be specified
// because APM should not run the same VAD twice (as an APM sub-module and
// internally in AGC2).
RTC_DCHECK(!speech_probability.has_value());
speech_probability = vad_->Analyze(float_frame); speech_probability = vad_->Analyze(float_frame);
} else if (speech_probability.has_value()) { }
if (speech_probability.has_value()) {
RTC_DCHECK_GE(*speech_probability, 0.0f); RTC_DCHECK_GE(*speech_probability, 0.0f);
RTC_DCHECK_LE(*speech_probability, 1.0f); RTC_DCHECK_LE(*speech_probability, 1.0f);
} }

View file

@ -455,74 +455,26 @@ TEST(GainController2, CheckFinalGainWithAdaptiveDigitalController) {
EXPECT_NEAR(applied_gain_db, kExpectedGainDb, kToleranceDb); EXPECT_NEAR(applied_gain_db, kExpectedGainDb, kToleranceDb);
} }
// Processes a test audio file and checks that the injected speech probability #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
// is ignored when the internal VAD is used. // Checks that `GainController2` crashes in debug mode if it runs its internal
TEST(GainController2, // VAD and the speech probability values are provided by the caller.
CheckInjectedVadProbabilityNotUsedWithAdaptiveDigitalController) { TEST(GainController2DeathTest,
DebugCrashIfUseInternalVadAndSpeechProbabilityGiven) {
constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz; constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz;
constexpr int kStereo = 2; constexpr int kStereo = 2;
// Create AGC2 enabling only the adaptive digital controller.
Agc2Config config;
config.fixed_digital.gain_db = 0.0f;
config.adaptive_digital.enabled = true;
GainController2 agc2(config, /*input_volume_controller_config=*/{},
kSampleRateHz, kStereo,
/*use_internal_vad=*/true);
GainController2 agc2_reference(config, /*input_volume_controller_config=*/{},
kSampleRateHz, kStereo,
/*use_internal_vad=*/true);
test::InputAudioFile input_file(
test::GetApmCaptureTestVectorFileName(kSampleRateHz),
/*loop_at_end=*/true);
const StreamConfig stream_config(kSampleRateHz, kStereo);
// Init buffers.
constexpr int kFrameDurationMs = 10;
std::vector<float> frame(kStereo * stream_config.num_frames());
AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo, AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo,
kSampleRateHz, kStereo); kSampleRateHz, kStereo);
AudioBuffer audio_buffer_reference(kSampleRateHz, kStereo, kSampleRateHz, // Create AGC2 so that the interval VAD is also created.
kStereo, kSampleRateHz, kStereo); GainController2 agc2(/*config=*/{.adaptive_digital = {.enabled = true}},
/*input_volume_controller_config=*/{}, kSampleRateHz,
kStereo,
/*use_internal_vad=*/true);
// Simulate. EXPECT_DEATH(agc2.Process(/*speech_probability=*/0.123f,
constexpr float kGainDb = -6.0f; /*input_volume_changed=*/false, &audio_buffer),
const float gain = std::pow(10.0f, kGainDb / 20.0f); "");
constexpr int kDurationMs = 10000;
constexpr int kNumFramesToProcess = kDurationMs / kFrameDurationMs;
constexpr float kSpeechProbabilities[] = {1.0f, 0.3f};
constexpr float kEpsilon = 0.0001f;
bool all_samples_zero = true;
for (int i = 0, j = 0; i < kNumFramesToProcess; ++i, j = 1 - j) {
ReadFloatSamplesFromStereoFile(stream_config.num_frames(),
stream_config.num_channels(), &input_file,
frame);
// Apply a fixed gain to the input audio.
for (float& x : frame) {
x *= gain;
}
test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
agc2.Process(kSpeechProbabilities[j], /*input_volume_changed=*/false,
&audio_buffer);
test::CopyVectorToAudioBuffer(stream_config, frame,
&audio_buffer_reference);
agc2_reference.Process(/*speech_probability=*/absl::nullopt,
/*input_volume_changed=*/false,
&audio_buffer_reference);
// Check the output buffers.
for (int i = 0; i < kStereo; ++i) {
for (int j = 0; j < static_cast<int>(audio_buffer.num_frames()); ++j) {
all_samples_zero &=
fabs(audio_buffer.channels_const()[i][j]) < kEpsilon;
EXPECT_FLOAT_EQ(audio_buffer.channels_const()[i][j],
audio_buffer_reference.channels_const()[i][j]);
}
}
}
EXPECT_FALSE(all_samples_zero);
} }
#endif
// Processes a test audio file and checks that the injected speech probability // Processes a test audio file and checks that the injected speech probability
// is not ignored when the internal VAD is not used. // is not ignored when the internal VAD is not used.