mirror of
https://github.com/mollyim/webrtc.git
synced 2025-05-16 15:20:42 +01:00

This CL overrides the power-based suppressor gain decision with a coherence-based decision for the cases when that indicates a higher suppressor gain. Bug: webrtc:9159,chromium:833801 Change-Id: I0e7d82ac1b8c70ffe9d45907559bb14b1b849d71 Reviewed-on: https://webrtc-review.googlesource.com/71660 Commit-Queue: Per Åhgren <peah@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#22997}
257 lines
9.3 KiB
C++
257 lines
9.3 KiB
C++
/*
|
|
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
|
*
|
|
* Use of this source code is governed by a BSD-style license
|
|
* that can be found in the LICENSE file in the root of the source
|
|
* tree. An additional intellectual property rights grant can be found
|
|
* in the file PATENTS. All contributing project authors may
|
|
* be found in the AUTHORS file in the root of the source tree.
|
|
*/
|
|
|
|
#include "modules/audio_processing/aec3/coherence_gain.h"

#include <math.h>

#include <algorithm>
#include <array>
#include <cstdlib>

#include "rtc_base/checks.h"
|
|
|
|
namespace webrtc {
|
|
|
|
namespace {
|
|
|
|
// Per-bin multiplier applied to the smoothed overdrive when the suppression
// gain is formed in FormSuppressionGain():
//   gain[i] = h_nl[i] ^ (overdrive_scaling * kOverDriveCurve[i]).
// The curve rises from 1.0 at bin 0 to 2.0 at the last bin, so for h_nl[i] in
// [0, 1] the higher bins receive a larger exponent, i.e. a smaller gain.
//
// Matlab code to produce table (the fprintf format only controls how the
// values were printed; below they are wrapped eight per row):
// overDriveCurve = [sqrt(linspace(0,1,65))' + 1];
// fprintf(1, '\t%.4f, %.4f, %.4f, %.4f, %.4f, %.4f,\n', overDriveCurve);
const float kOverDriveCurve[kFftLengthBy2Plus1] = {
    1.0000f, 1.1250f, 1.1768f, 1.2165f, 1.2500f, 1.2795f, 1.3062f, 1.3307f,
    1.3536f, 1.3750f, 1.3953f, 1.4146f, 1.4330f, 1.4507f, 1.4677f, 1.4841f,
    1.5000f, 1.5154f, 1.5303f, 1.5449f, 1.5590f, 1.5728f, 1.5863f, 1.5995f,
    1.6124f, 1.6250f, 1.6374f, 1.6495f, 1.6614f, 1.6731f, 1.6847f, 1.6960f,
    1.7071f, 1.7181f, 1.7289f, 1.7395f, 1.7500f, 1.7603f, 1.7706f, 1.7806f,
    1.7906f, 1.8004f, 1.8101f, 1.8197f, 1.8292f, 1.8385f, 1.8478f, 1.8570f,
    1.8660f, 1.8750f, 1.8839f, 1.8927f, 1.9014f, 1.9100f, 1.9186f, 1.9270f,
    1.9354f, 1.9437f, 1.9520f, 1.9601f, 1.9682f, 1.9763f, 1.9843f, 1.9922f,
    2.0000f};
|
|
|
|
// Per-bin weight used in FormSuppressionGain() to pull a bin's raw gain
// towards the preferred-band reference level h_nl_fb whenever the bin lies
// above it:
//   h_nl[i] = kWeightCurve[i] * h_nl_fb + (1 - kWeightCurve[i]) * h_nl[i].
// Bin 0 has weight 0 (no pull); the remaining 64 bins rise from 0.1 to 0.4,
// so higher bins are pulled more strongly.
//
// Matlab code to produce table (the fprintf format only controls how the
// values were printed; below they are wrapped eight per row):
// weightCurve = [0 ; 0.3 * sqrt(linspace(0,1,64))' + 0.1];
// fprintf(1, '\t%.4f, %.4f, %.4f, %.4f, %.4f, %.4f,\n', weightCurve);
const float kWeightCurve[kFftLengthBy2Plus1] = {
    0.0000f, 0.1000f, 0.1378f, 0.1535f, 0.1655f, 0.1756f, 0.1845f, 0.1926f,
    0.2000f, 0.2069f, 0.2134f, 0.2195f, 0.2254f, 0.2309f, 0.2363f, 0.2414f,
    0.2464f, 0.2512f, 0.2558f, 0.2604f, 0.2648f, 0.2690f, 0.2732f, 0.2773f,
    0.2813f, 0.2852f, 0.2890f, 0.2927f, 0.2964f, 0.3000f, 0.3035f, 0.3070f,
    0.3104f, 0.3138f, 0.3171f, 0.3204f, 0.3236f, 0.3268f, 0.3299f, 0.3330f,
    0.3360f, 0.3390f, 0.3420f, 0.3449f, 0.3478f, 0.3507f, 0.3535f, 0.3563f,
    0.3591f, 0.3619f, 0.3646f, 0.3673f, 0.3699f, 0.3726f, 0.3752f, 0.3777f,
    0.3803f, 0.3828f, 0.3854f, 0.3878f, 0.3903f, 0.3928f, 0.3952f, 0.3976f,
    0.4000f};
|
|
|
|
// Three-way float comparator with the signature std::qsort() expects.
// Returns -1 when the float behind |a| is smaller than the one behind |b|,
// 1 when it is larger, and 0 otherwise (equal values, or when either value is
// NaN, since both relational tests are then false).
int CmpFloat(const void* a, const void* b) {
  const float lhs = *static_cast<const float*>(a);
  const float rhs = *static_cast<const float*>(b);
  if (lhs < rhs) {
    return -1;
  }
  if (lhs > rhs) {
    return 1;
  }
  return 0;
}
|
|
|
|
} // namespace
|
|
|
|
// Sets up the coherence-based gain computer.
//
// |sample_rate_scaler_| becomes 1 for sample rates below 16 kHz and 2
// otherwise. It selects the spectral smoothing constant and scales the
// preferred-band layout and tracker adaptation rates used when forming the
// gain. Only the first |num_bands_to_compute| gain bands are produced.
CoherenceGain::CoherenceGain(int sample_rate_hz, size_t num_bands_to_compute)
    : num_bands_to_compute_(num_bands_to_compute),
      sample_rate_scaler_(sample_rate_hz < 16000 ? 1 : 2) {
  // The far-end and near-end powers start at one (rather than zero) to keep
  // the first block's coherence computation numerically stable.
  spectra_.Px.fill(1.f);
  spectra_.Py.fill(1.f);
  // Residual-echo power starts at zero and the cross-spectra are reset.
  spectra_.Pe.fill(0.f);
  spectra_.Cxy.Clear();
  spectra_.Cye.Clear();
}

CoherenceGain::~CoherenceGain() = default;
|
|
|
|
void CoherenceGain::ComputeGain(const FftData& E,
|
|
const FftData& X,
|
|
const FftData& Y,
|
|
rtc::ArrayView<float> gain) {
|
|
std::array<float, kFftLengthBy2Plus1> coherence_ye;
|
|
std::array<float, kFftLengthBy2Plus1> coherence_xy;
|
|
|
|
UpdateCoherenceSpectra(E, X, Y);
|
|
ComputeCoherence(coherence_ye, coherence_xy);
|
|
FormSuppressionGain(coherence_ye, coherence_xy, gain);
|
|
}
|
|
|
|
// Updates the following smoothed Power Spectral Densities (PSD):
|
|
// - sd : near-end
|
|
// - se : residual echo
|
|
// - sx : far-end
|
|
// - sde : cross-PSD of near-end and residual echo
|
|
// - sxd : cross-PSD of near-end and far-end
|
|
//
|
|
void CoherenceGain::UpdateCoherenceSpectra(const FftData& E,
|
|
const FftData& X,
|
|
const FftData& Y) {
|
|
const float s = sample_rate_scaler_ == 1 ? 0.9f : 0.92f;
|
|
const float one_minus_s = 1.f - s;
|
|
auto& c = spectra_;
|
|
|
|
for (size_t i = 0; i < c.Py.size(); i++) {
|
|
c.Py[i] =
|
|
s * c.Py[i] + one_minus_s * (Y.re[i] * Y.re[i] + Y.im[i] * Y.im[i]);
|
|
c.Pe[i] =
|
|
s * c.Pe[i] + one_minus_s * (E.re[i] * E.re[i] + E.im[i] * E.im[i]);
|
|
// We threshold here to protect against the ill-effects of a zero farend.
|
|
// The threshold is not arbitrarily chosen, but balances protection and
|
|
// adverse interaction with the algorithm's tuning.
|
|
|
|
// Threshold to protect against the ill-effects of a zero far-end.
|
|
c.Px[i] =
|
|
s * c.Px[i] +
|
|
one_minus_s * std::max(X.re[i] * X.re[i] + X.im[i] * X.im[i], 15.f);
|
|
|
|
c.Cye.re[i] =
|
|
s * c.Cye.re[i] + one_minus_s * (Y.re[i] * E.re[i] + Y.im[i] * E.im[i]);
|
|
c.Cye.im[i] =
|
|
s * c.Cye.im[i] + one_minus_s * (Y.re[i] * E.im[i] - Y.im[i] * E.re[i]);
|
|
|
|
c.Cxy.re[i] =
|
|
s * c.Cxy.re[i] + one_minus_s * (Y.re[i] * X.re[i] + Y.im[i] * X.im[i]);
|
|
c.Cxy.im[i] =
|
|
s * c.Cxy.im[i] + one_minus_s * (Y.re[i] * X.im[i] - Y.im[i] * X.re[i]);
|
|
}
|
|
}
|
|
|
|
// Maps the per-bin coherences |coherence_ye| (Y vs E) and |coherence_xy|
// (X vs Y) to a per-band suppression gain, updating the tracking state in
// |gain_state_| along the way.
//
// Outline:
//  1. Average both coherences over the "preferred" bins
//     [min_pref_band, min_pref_band + pref_band_size), i.e. bins 4..27 when
//     |sample_rate_scaler_| is 1 and bins 2..13 when it is 2. The X/Y average
//     is stored as 1 - mean, so it is close to 1 when X and Y are incoherent.
//  2. Track the minimum of that X/Y average and update the near-end state
//     with hysteresis.
//  3. Form a raw per-bin gain |h_nl| plus the reference levels |h_nl_fb| and
//     |h_nl_fb_low|.
//  4. Track the minimum of |h_nl_fb_low| to adapt the overdrive, then smooth
//     the overdrive.
//  5. Weight each bin towards |h_nl_fb| and raise it to the overdrive power.
// Only the first |num_bands_to_compute_| entries of |gain| are written.
void CoherenceGain::FormSuppressionGain(
    rtc::ArrayView<const float> coherence_ye,
    rtc::ArrayView<const float> coherence_xy,
    rtc::ArrayView<float> gain) {
  RTC_DCHECK_EQ(kFftLengthBy2Plus1, coherence_ye.size());
  RTC_DCHECK_EQ(kFftLengthBy2Plus1, coherence_xy.size());
  RTC_DCHECK_EQ(kFftLengthBy2Plus1, gain.size());
  constexpr int kPrefBandSize = 24;
  auto& gs = gain_state_;
  // Scratch buffer for sorting the preferred bins; sized for the largest
  // (|sample_rate_scaler_| == 1) preferred-band width.
  std::array<float, kPrefBandSize> h_nl_pref;
  float h_nl_fb = 0;
  float h_nl_fb_low = 0;
  const int pref_band_size = kPrefBandSize / sample_rate_scaler_;
  const int min_pref_band = 4 / sample_rate_scaler_;

  // Step 1: mean coherences over the preferred bins.
  float h_nl_de_avg = 0.f;
  float h_nl_xd_avg = 0.f;
  for (int i = min_pref_band; i < pref_band_size + min_pref_band; ++i) {
    h_nl_xd_avg += coherence_xy[i];
    h_nl_de_avg += coherence_ye[i];
  }
  h_nl_xd_avg /= pref_band_size;
  h_nl_xd_avg = 1 - h_nl_xd_avg;
  h_nl_de_avg /= pref_band_size;

  // Step 2a: record new minima of the X/Y average, but only below 0.75. The
  // tracker drifts back up towards 1 further down.
  if (h_nl_xd_avg < 0.75f && h_nl_xd_avg < gs.h_nl_xd_avg_min) {
    gs.h_nl_xd_avg_min = h_nl_xd_avg;
  }

  // Step 2b: near-end state with hysteresis. It is entered when Y/E coherence
  // is very high and X/Y coherence low, left when either condition weakens
  // past the lower thresholds, and otherwise kept unchanged.
  if (h_nl_de_avg > 0.98f && h_nl_xd_avg > 0.9f) {
    gs.near_state = true;
  } else if (h_nl_de_avg < 0.95f || h_nl_xd_avg < 0.8f) {
    gs.near_state = false;
  }

  // Step 3: raw per-bin gain and reference levels.
  std::array<float, kFftLengthBy2Plus1> h_nl;
  // The exact float comparison is intentional. The tracker is clamped with
  // std::min(..., 1.f) below, so it equals 1 exactly once it has drifted all
  // the way back. NOTE(review): this also assumes it is initialised to 1 in
  // the header; confirm.
  if (gs.h_nl_xd_avg_min == 1) {
    // No recent low X/Y average: fall back to the minimum overdrive.
    gs.overdrive = 15.f;

    if (gs.near_state) {
      std::copy(coherence_ye.begin(), coherence_ye.end(), h_nl.begin());
      h_nl_fb = h_nl_de_avg;
      h_nl_fb_low = h_nl_de_avg;
    } else {
      for (size_t i = 0; i < h_nl.size(); ++i) {
        h_nl[i] = 1 - coherence_xy[i];
        h_nl[i] = std::max(h_nl[i], 0.f);
      }
      h_nl_fb = h_nl_xd_avg;
      h_nl_fb_low = h_nl_xd_avg;
    }
  } else {
    if (gs.near_state) {
      std::copy(coherence_ye.begin(), coherence_ye.end(), h_nl.begin());
      h_nl_fb = h_nl_de_avg;
      h_nl_fb_low = h_nl_de_avg;
    } else {
      // Use the more suppressive of the two coherence-based estimates,
      // clamped to be non-negative.
      for (size_t i = 0; i < h_nl.size(); ++i) {
        h_nl[i] = std::min(coherence_ye[i], 1 - coherence_xy[i]);
        h_nl[i] = std::max(h_nl[i], 0.f);
      }

      // Select an order statistic from the preferred bands.
      // TODO(peah): Using quicksort now, but a selection algorithm may be
      // preferred.
      std::copy(h_nl.begin() + min_pref_band,
                h_nl.begin() + min_pref_band + pref_band_size,
                h_nl_pref.begin());
      std::qsort(h_nl_pref.data(), pref_band_size, sizeof(float), CmpFloat);

      // After the ascending sort: |h_nl_fb| is the value at the 75% position
      // and |h_nl_fb_low| the value at the 50% position (indices rounded
      // down).
      constexpr float kPrefBandQuant = 0.75f;
      h_nl_fb = h_nl_pref[static_cast<int>(
          floor(kPrefBandQuant * (pref_band_size - 1)))];
      constexpr float kPrefBandQuantLow = 0.5f;
      h_nl_fb_low = h_nl_pref[static_cast<int>(
          floor(kPrefBandQuantLow * (pref_band_size - 1)))];
    }
  }

  // Track the local filter minimum to determine suppression overdrive.
  // Step 4a: a new minimum below 0.6 arms the overdrive update via
  // |h_nl_new_min| and restarts the counter.
  if (h_nl_fb_low < 0.6f && h_nl_fb_low < gs.h_nl_fb_local_min) {
    gs.h_nl_fb_local_min = h_nl_fb_low;
    gs.h_nl_fb_min = h_nl_fb_low;
    gs.h_nl_new_min = 1;
    gs.h_nl_min_ctr = 0;
  }
  // Both minimum trackers drift back up towards 1, more slowly when
  // |sample_rate_scaler_| is 2.
  gs.h_nl_fb_local_min =
      std::min(gs.h_nl_fb_local_min + 0.0008f / sample_rate_scaler_, 1.f);
  gs.h_nl_xd_avg_min =
      std::min(gs.h_nl_xd_avg_min + 0.0006f / sample_rate_scaler_, 1.f);

  // Step 4b: the counter reaches 2 on the call after the most recent new
  // minimum. The overdrive is then recomputed from that minimum and floored
  // at 15.
  if (gs.h_nl_new_min == 1) {
    ++gs.h_nl_min_ctr;
  }
  if (gs.h_nl_min_ctr == 2) {
    gs.h_nl_new_min = 0;
    gs.h_nl_min_ctr = 0;
    constexpr float epsilon = 1e-10f;
    gs.overdrive = std::max(
        -18.4f / static_cast<float>(log(gs.h_nl_fb_min + epsilon) + epsilon),
        15.f);
  }

  // Smooth the overdrive.
  // Step 4c: move slowly (1% per call) towards a lower target and quickly
  // (10% per call) towards a higher one.
  if (gs.overdrive < gs.overdrive_scaling) {
    gs.overdrive_scaling = 0.99f * gs.overdrive_scaling + 0.01f * gs.overdrive;
  } else {
    gs.overdrive_scaling = 0.9f * gs.overdrive_scaling + 0.1f * gs.overdrive;
  }

  // Apply the overdrive.
  // Step 5: bins above the reference level are pulled towards it by
  // kWeightCurve[i]. Each bin is then raised to the smoothed overdrive times
  // the per-bin kOverDriveCurve[i].
  RTC_DCHECK_LE(num_bands_to_compute_, gain.size());
  for (size_t i = 0; i < num_bands_to_compute_; ++i) {
    if (h_nl[i] > h_nl_fb) {
      h_nl[i] = kWeightCurve[i] * h_nl_fb + (1 - kWeightCurve[i]) * h_nl[i];
    }
    gain[i] = powf(h_nl[i], gs.overdrive_scaling * kOverDriveCurve[i]);
  }
}
|
|
|
|
// Derives the per-bin magnitude-squared coherences from the smoothed spectra:
//   coherence_ye[k] = |Cye[k]|^2 / (Py[k] * Pe[k] + 1e-10)
//   coherence_xy[k] = |Cxy[k]|^2 / (Px[k] * Py[k] + 1e-10)
// The small additive constant keeps the division finite when a power is zero.
// The loop length is taken from |coherence_ye|; |coherence_xy| is expected to
// be at least as long.
void CoherenceGain::ComputeCoherence(rtc::ArrayView<float> coherence_ye,
                                     rtc::ArrayView<float> coherence_xy) const {
  constexpr float kRegularizer = 1e-10f;
  const auto squared_magnitude = [](float re, float im) {
    return re * re + im * im;
  };

  const auto& sp = spectra_;
  const size_t num_bins = coherence_ye.size();
  for (size_t k = 0; k < num_bins; ++k) {
    const float ye_denominator = sp.Py[k] * sp.Pe[k] + kRegularizer;
    const float xy_denominator = sp.Px[k] * sp.Py[k] + kRegularizer;
    coherence_ye[k] =
        squared_magnitude(sp.Cye.re[k], sp.Cye.im[k]) / ye_denominator;
    coherence_xy[k] =
        squared_magnitude(sp.Cxy.re[k], sp.Cxy.im[k]) / xy_denominator;
  }
}
|
|
|
|
} // namespace webrtc
|