diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc index bea7868a91..ffedc85fcd 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -15,6 +15,7 @@ #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_minmax.h" +#include "system_wrappers/include/field_trial.h" namespace webrtc { namespace { @@ -123,6 +124,8 @@ MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator( size_t max_filter_lag, size_t down_sampling_factor) : block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)), + penalize_high_delays_initial_phase_( + field_trial::IsEnabled("WebRTC-Aec3PenalyzeHighDelaysInitialPhase")), histogram_( ((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2, 0) { @@ -152,9 +155,31 @@ void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate( histogram_data_[histogram_data_index_] = pre_echo_block_size; ++histogram_[histogram_data_[histogram_data_index_]]; histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size(); - int pre_echo_candidate_block_size = - std::distance(histogram_.begin(), - std::max_element(histogram_.begin(), histogram_.end())); + int pre_echo_candidate_block_size = 0; + if (penalize_high_delays_initial_phase_ && + number_updates_ < kNumBlocksPerSecond * 2) { + number_updates_++; + float penalization_per_delay = 1.0f; + float max_histogram_value = -1.0f; + for (auto it = histogram_.begin(); + it + kMatchedFilterWindowSizeSubBlocks <= histogram_.end(); + it = it + kMatchedFilterWindowSizeSubBlocks) { + auto it_max_element = + std::max_element(it, it + kMatchedFilterWindowSizeSubBlocks); + float weighted_max_value = + static_cast(*it_max_element) * penalization_per_delay; + if (weighted_max_value > max_histogram_value) { + max_histogram_value = weighted_max_value; + pre_echo_candidate_block_size = + std::distance(histogram_.begin(), it_max_element); + } + penalization_per_delay *= 0.7f; + } + } else { + pre_echo_candidate_block_size = + std::distance(histogram_.begin(), + std::max_element(histogram_.begin(), histogram_.end())); + } pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_); } diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h index c0598bf226..1287b38da0 100644 --- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h +++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -64,10 +64,12 @@ class MatchedFilterLagAggregator { private: const int block_size_log2_; + const bool penalize_high_delays_initial_phase_; std::array histogram_data_; std::vector histogram_; int histogram_data_index_ = 0; int pre_echo_candidate_ = 0; + int number_updates_ = 0; }; class HighestPeakAggregator {