AEC3: Misadjustment estimator of the linear filter.

In this work the performance of the linear filter is
estimated. The estimation aims at capture situations when the linear
filter is largely over-estimating the echo. In those circumstances,
the linear filter is scaled with the purpose of accelerating its
convergence.

Change-Id: I05ea3739d82838a6f08673432da92125c47943e0
Bug: webrtc:9466,chromium:857426
Reviewed-on: https://webrtc-review.googlesource.com/86133
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Jesus de Vicente Pena <devicentepena@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23789}
This commit is contained in:
Jesús de Vicente Peña 2018-06-29 16:35:08 +02:00 committed by Commit Bot
parent 916ec7dadf
commit 2e79d2b398
5 changed files with 119 additions and 1 deletions

View file

@ -616,4 +616,18 @@ void AdaptiveFirFilter::Constrain() {
: 0;
}
void AdaptiveFirFilter::ScaleFilter(float factor) {
for (auto& H : H_) {
for (auto& re : H.re) {
re *= factor;
}
for (auto& im : H.im) {
im *= factor;
}
}
for (auto& h : h_) {
h *= factor;
}
}
} // namespace webrtc

View file

@ -143,6 +143,9 @@ class AdaptiveFirFilter {
h_.resize(current_size);
}
// Scale the filter impulse response and spectrum by a factor.
void ScaleFilter(float factor);
private:
// Constrain the filter partitions in a cyclic manner.
void Constrain();

View file

@ -27,6 +27,10 @@ bool EnableAdaptationDuringSaturation() {
return !field_trial::IsEnabled("WebRTC-Aec3RapidAgcGainRecoveryKillSwitch");
}
bool EnableMisadjustmentEstimator() {
return !field_trial::IsEnabled("WebRTC-Aec3MisadjustmentEstimatorKillSwitch");
}
void PredictionError(const Aec3Fft& fft,
const FftData& S,
rtc::ArrayView<const float> y,
@ -72,6 +76,7 @@ Subtractor::Subtractor(const EchoCanceller3Config& config,
optimization_(optimization),
config_(config),
adaptation_during_saturation_(EnableAdaptationDuringSaturation()),
enable_misadjustment_estimator_(EnableMisadjustmentEstimator()),
main_filter_(config_.filter.main.length_blocks,
config_.filter.main_initial.length_blocks,
config.filter.config_change_duration_blocks,
@ -182,6 +187,15 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
main_filter_once_converged_ || main_filter_converged_;
main_filter_diverged_ = e2_main > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize;
if (enable_misadjustment_estimator_) {
filter_misadjustment_estimator_.Update(e2_main, y2);
if (filter_misadjustment_estimator_.IsAdjustmentNeeded()) {
float scale = filter_misadjustment_estimator_.GetMisadjustment();
main_filter_.ScaleFilter(scale);
output->ScaleOutputMainFilter(scale);
filter_misadjustment_estimator_.Reset();
}
}
// Compute spectra for future use.
E_shadow.Spectrum(optimization_, output->E2_shadow);
E_main.Spectrum(optimization_, output->E2_main);
@ -206,7 +220,7 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.re);
data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.im);
filter_misadjustment_estimator_.Dump(data_dumper_);
DumpFilters();
if (adaptation_during_saturation_) {
@ -215,4 +229,39 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
}
}
void Subtractor::FilterMisadjustmentEstimator::Update(float e2, float y2) {
e2_acum_ += e2;
y2_acum_ += y2;
if (++n_blocks_acum_ == n_blocks_) {
if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) {
float update = (e2_acum_ / y2_acum_);
if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) {
overhang_ = 4; // Duration equal to blockSizeMs * n_blocks_ * 4
} else {
overhang_ = std::max(overhang_ - 1, 0);
}
if ((update < inv_misadjustment_) || (overhang_ > 0)) {
inv_misadjustment_ += 0.1f * (update - inv_misadjustment_);
}
}
e2_acum_ = 0.f;
y2_acum_ = 0.f;
n_blocks_acum_ = 0;
}
}
void Subtractor::FilterMisadjustmentEstimator::Reset() {
e2_acum_ = 0.f;
y2_acum_ = 0.f;
n_blocks_acum_ = 0;
inv_misadjustment_ = 0.f;
overhang_ = 0.f;
}
void Subtractor::FilterMisadjustmentEstimator::Dump(
ApmDataDumper* data_dumper) const {
data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_);
}
} // namespace webrtc

View file

@ -14,6 +14,7 @@
#include <algorithm>
#include <array>
#include <vector>
#include "math.h"
#include "modules/audio_processing/aec3/adaptive_fir_filter.h"
#include "modules/audio_processing/aec3/aec3_common.h"
@ -78,11 +79,43 @@ class Subtractor {
}
private:
class FilterMisadjustmentEstimator {
public:
FilterMisadjustmentEstimator() = default;
~FilterMisadjustmentEstimator() = default;
// Update the misadjustment estimator.
void Update(float e2, float y2);
// GetMisadjustment() Returns a recommended scale for the filter so the
// prediction error energy gets closer to the energy that is seen at the
// microphone input.
float GetMisadjustment() const {
RTC_DCHECK_GT(inv_misadjustment_, 0.0f);
// It is not aiming to adjust all the estimated mismatch. Instead,
// it adjusts half of that estimated mismatch.
return 2.f / sqrtf(inv_misadjustment_);
}
// Returns true if the prediciton error energy is significantly larger
// than the microphone signal energy and, therefore, an adjustment is
// recommended.
bool IsAdjustmentNeeded() const { return inv_misadjustment_ > 10.f; }
void Reset();
void Dump(ApmDataDumper* data_dumper) const;
private:
const int n_blocks_ = 4;
int n_blocks_acum_ = 0;
float e2_acum_ = 0.f;
float y2_acum_ = 0.f;
float inv_misadjustment_ = 0.f;
int overhang_ = 0.f;
};
const Aec3Fft fft_;
ApmDataDumper* data_dumper_;
const Aec3Optimization optimization_;
const EchoCanceller3Config config_;
const bool adaptation_during_saturation_;
const bool enable_misadjustment_estimator_;
AdaptiveFirFilter main_filter_;
AdaptiveFirFilter shadow_filter_;
MainFilterUpdateGain G_main_;
@ -91,6 +124,7 @@ class Subtractor {
bool main_filter_once_converged_ = false;
bool shadow_filter_converged_ = false;
bool main_filter_diverged_ = false;
FilterMisadjustmentEstimator filter_misadjustment_estimator_;
RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor);
};

View file

@ -36,6 +36,24 @@ struct SubtractorOutput {
E2_main.fill(0.f);
E2_shadow.fill(0.f);
}
void ScaleOutputMainFilter(float factor) {
for (auto& s : s_main) {
s *= factor;
}
for (auto& e : e_main) {
e *= factor;
}
for (auto& E2 : E2_main) {
E2 *= factor * factor;
}
for (auto& re : E_main.re) {
re *= factor;
}
for (auto& im : E_main.im) {
im *= factor;
}
}
};
} // namespace webrtc