Make transient suppression optionally excludable via defines

This allows clients to exclude the transient suppression submodule from WebRTC builds, by defining WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR.

The changes have been shown to be bitexact for a test dataset (when the flag is _not_ defined.)

No-Try: True
Bug: webrtc:11226, webrtc:11292
Change-Id: I6931c82a280a9b40a53ee1c2a9820ed9e674a9a5
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/171421
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#30978}
This commit is contained in:
saza 2020-04-01 15:24:40 +02:00 committed by Commit Bot
parent fc23cc07e2
commit aa42ecde9a
13 changed files with 299 additions and 134 deletions

View file

@ -277,6 +277,10 @@ config("common_config") {
defines += [ "RTC_DISABLE_METRICS" ]
}
if (rtc_exclude_transient_suppressor) {
defines += [ "WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR" ]
}
cflags = []
if (build_with_chromium) {

View file

@ -184,7 +184,8 @@ rtc_library("audio_processing") {
"agc2:fixed_digital",
"agc2:gain_applier",
"ns",
"transient:transient_suppressor",
"transient:transient_suppressor_api",
"transient:transient_suppressor_creator",
"vad",
"//third_party/abseil-cpp/absl/types:optional",
]

View file

@ -27,6 +27,7 @@
#include "modules/audio_processing/common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "modules/audio_processing/transient/transient_suppressor_creator.h"
#include "rtc_base/atomic_ops.h"
#include "rtc_base/checks.h"
#include "rtc_base/constructor_magic.h"
@ -1635,12 +1636,18 @@ bool AudioProcessingImpl::UpdateActiveSubmoduleStates() {
void AudioProcessingImpl::InitializeTransientSuppressor() {
if (config_.transient_suppression.enabled) {
// Attempt to create a transient suppressor, if one is not already created.
if (!submodules_.transient_suppressor) {
submodules_.transient_suppressor.reset(new TransientSuppressor());
submodules_.transient_suppressor = CreateTransientSuppressor();
}
if (submodules_.transient_suppressor) {
submodules_.transient_suppressor->Initialize(
proc_fullband_sample_rate_hz(), capture_nonlocked_.split_rate,
num_proc_channels());
} else {
RTC_LOG(LS_WARNING)
<< "No transient suppressor created (probably disabled)";
}
submodules_.transient_suppressor->Initialize(proc_fullband_sample_rate_hz(),
capture_nonlocked_.split_rate,
num_proc_channels());
} else {
submodules_.transient_suppressor.reset();
}
@ -1843,28 +1850,28 @@ void AudioProcessingImpl::InitializeNoiseSuppressor() {
submodules_.noise_suppressor.reset();
if (config_.noise_suppression.enabled) {
auto map_level =
[](AudioProcessing::Config::NoiseSuppression::Level level) {
using NoiseSuppresionConfig =
AudioProcessing::Config::NoiseSuppression;
switch (level) {
case NoiseSuppresionConfig::kLow:
return NsConfig::SuppressionLevel::k6dB;
case NoiseSuppresionConfig::kModerate:
return NsConfig::SuppressionLevel::k12dB;
case NoiseSuppresionConfig::kHigh:
return NsConfig::SuppressionLevel::k18dB;
case NoiseSuppresionConfig::kVeryHigh:
return NsConfig::SuppressionLevel::k21dB;
default:
RTC_NOTREACHED();
}
};
auto map_level =
[](AudioProcessing::Config::NoiseSuppression::Level level) {
using NoiseSuppresionConfig =
AudioProcessing::Config::NoiseSuppression;
switch (level) {
case NoiseSuppresionConfig::kLow:
return NsConfig::SuppressionLevel::k6dB;
case NoiseSuppresionConfig::kModerate:
return NsConfig::SuppressionLevel::k12dB;
case NoiseSuppresionConfig::kHigh:
return NsConfig::SuppressionLevel::k18dB;
case NoiseSuppresionConfig::kVeryHigh:
return NsConfig::SuppressionLevel::k21dB;
default:
RTC_NOTREACHED();
}
};
NsConfig cfg;
cfg.target_level = map_level(config_.noise_suppression.level);
submodules_.noise_suppressor = std::make_unique<NoiseSuppressor>(
cfg, proc_sample_rate_hz(), num_proc_channels());
NsConfig cfg;
cfg.target_level = map_level(config_.noise_suppression.level);
submodules_.noise_suppressor = std::make_unique<NoiseSuppressor>(
cfg, proc_sample_rate_hz(), num_proc_channels());
}
}

View file

@ -8,7 +8,28 @@
import("../../../webrtc.gni")
rtc_library("transient_suppressor") {
rtc_source_set("transient_suppressor_api") {
sources = [ "transient_suppressor.h" ]
}
rtc_library("transient_suppressor_creator") {
sources = [
"transient_suppressor_creator.cc",
"transient_suppressor_creator.h",
]
deps = [
":transient_suppressor_api",
":transient_suppressor_impl",
]
}
rtc_library("transient_suppressor_impl") {
visibility = [
":transient_suppressor_creator",
":transient_suppression_test",
":transient_suppression_unittests",
":click_annotate",
]
sources = [
"common.h",
"daubechies_8_wavelet_coeffs.h",
@ -17,8 +38,8 @@ rtc_library("transient_suppressor") {
"moving_moments.h",
"transient_detector.cc",
"transient_detector.h",
"transient_suppressor.cc",
"transient_suppressor.h",
"transient_suppressor_impl.cc",
"transient_suppressor_impl.h",
"windows_private.h",
"wpd_node.cc",
"wpd_node.h",
@ -26,6 +47,7 @@ rtc_library("transient_suppressor") {
"wpd_tree.h",
]
deps = [
":transient_suppressor_api",
"../../../common_audio:common_audio",
"../../../common_audio:common_audio_c",
"../../../common_audio:fir_filter",
@ -46,7 +68,7 @@ if (rtc_include_tests) {
"file_utils.h",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"..:audio_processing",
"../../../rtc_base/system:file_wrapper",
"../../../system_wrappers",
@ -61,7 +83,7 @@ if (rtc_include_tests) {
"transient_suppression_test.cc",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"..:audio_processing",
"../../../common_audio",
"../../../rtc_base:rtc_base_approved",
@ -90,7 +112,7 @@ if (rtc_include_tests) {
"wpd_tree_unittest.cc",
]
deps = [
":transient_suppressor",
":transient_suppressor_impl",
"../../../rtc_base:stringutils",
"../../../rtc_base/system:file_wrapper",
"../../../test:fileutils",

View file

@ -20,7 +20,7 @@
#include "absl/flags/parse.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc/agc.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
@ -165,7 +165,7 @@ void void_main() {
Agc agc;
TransientSuppressor suppressor;
TransientSuppressorImpl suppressor;
suppressor.Initialize(absl::GetFlag(FLAGS_sample_rate_hz), detection_rate_hz,
absl::GetFlag(FLAGS_num_channels));

View file

@ -1,5 +1,5 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
@ -13,23 +13,19 @@
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class TransientDetector;
// Detects transients in an audio stream and suppress them using a simple
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressor {
public:
TransientSuppressor();
~TransientSuppressor();
virtual ~TransientSuppressor() {}
int Initialize(int sample_rate_hz, int detector_rate_hz, int num_channels);
virtual int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) = 0;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
@ -48,71 +44,15 @@ class TransientSuppressor {
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed);
private:
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorTest,
TypingDetectionLogicWorksAsExpectedForMono);
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
void UpdateKeypress(bool key_pressed);
void UpdateRestoration(float voice_probability);
void UpdateBuffers(float* data);
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
std::unique_ptr<TransientDetector> detector_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;
size_t buffer_delay_;
size_t complex_analysis_length_;
int num_channels_;
// Input buffer where the original samples are stored.
std::unique_ptr<float[]> in_buffer_;
std::unique_ptr<float[]> detection_buffer_;
// Output buffer where the restored samples are stored.
std::unique_ptr<float[]> out_buffer_;
// Arrays for fft.
std::unique_ptr<size_t[]> ip_;
std::unique_ptr<float[]> wfft_;
std::unique_ptr<float[]> spectral_mean_;
// Stores the data for the fft.
std::unique_ptr<float[]> fft_buffer_;
std::unique_ptr<float[]> magnitudes_;
const float* window_;
std::unique_ptr<float[]> mean_factor_;
float detector_smoothed_;
int keypress_counter_;
int chunks_since_keypress_;
bool detection_enabled_;
bool suppression_enabled_;
bool use_hard_restoration_;
int chunks_since_voice_change_;
uint32_t seed_;
bool using_reference_;
virtual int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) = 0;
};
} // namespace webrtc

View file

@ -0,0 +1,27 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor_creator.h"
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
namespace webrtc {
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor() {
#ifdef WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR
return nullptr;
#else
return std::make_unique<TransientSuppressorImpl>();
#endif
}
} // namespace webrtc

View file

@ -0,0 +1,26 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor.h"
namespace webrtc {
// Creates a transient suppressor.
// Will return nullptr if WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR is defined.
std::unique_ptr<TransientSuppressor> CreateTransientSuppressor();
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_CREATOR_H_

View file

@ -8,13 +8,15 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include <string.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <deque>
#include <limits>
#include <set>
#include "common_audio/include/audio_util.h"
@ -22,6 +24,7 @@
#include "common_audio/third_party/fft4g/fft4g.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_detector.h"
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/windows_private.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
@ -43,7 +46,7 @@ float ComplexMagnitude(float a, float b) {
} // namespace
TransientSuppressor::TransientSuppressor()
TransientSuppressorImpl::TransientSuppressorImpl()
: data_length_(0),
detection_length_(0),
analysis_length_(0),
@ -61,11 +64,11 @@ TransientSuppressor::TransientSuppressor()
seed_(182),
using_reference_(false) {}
TransientSuppressor::~TransientSuppressor() {}
TransientSuppressorImpl::~TransientSuppressorImpl() {}
int TransientSuppressor::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
int TransientSuppressorImpl::Initialize(int sample_rate_hz,
int detection_rate_hz,
int num_channels) {
switch (sample_rate_hz) {
case ts::kSampleRate8kHz:
analysis_length_ = 128u;
@ -155,15 +158,15 @@ int TransientSuppressor::Initialize(int sample_rate_hz,
return 0;
}
int TransientSuppressor::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
int TransientSuppressorImpl::Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) {
if (!data || data_length != data_length_ || num_channels != num_channels_ ||
detection_length != detection_length_ || voice_probability < 0 ||
voice_probability > 1) {
@ -222,9 +225,9 @@ int TransientSuppressor::Suppress(float* data,
// This should only be called when detection is enabled. UpdateBuffers() must
// have been called. At return, |out_buffer_| will be filled with the
// processed output.
void TransientSuppressor::Suppress(float* in_ptr,
float* spectral_mean,
float* out_ptr) {
void TransientSuppressorImpl::Suppress(float* in_ptr,
float* spectral_mean,
float* out_ptr) {
// Go to frequency domain.
for (size_t i = 0; i < analysis_length_; ++i) {
// TODO(aluebs): Rename windows
@ -270,7 +273,7 @@ void TransientSuppressor::Suppress(float* in_ptr,
}
}
void TransientSuppressor::UpdateKeypress(bool key_pressed) {
void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) {
const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs; // 4 seconds.
@ -300,7 +303,7 @@ void TransientSuppressor::UpdateKeypress(bool key_pressed) {
}
}
void TransientSuppressor::UpdateRestoration(float voice_probability) {
void TransientSuppressorImpl::UpdateRestoration(float voice_probability) {
const int kHardRestorationOffsetDelay = 3;
const int kHardRestorationOnsetDelay = 80;
@ -323,7 +326,7 @@ void TransientSuppressor::UpdateRestoration(float voice_probability) {
// Shift buffers to make way for new data. Must be called after
// |detection_enabled_| is updated by UpdateKeypress().
void TransientSuppressor::UpdateBuffers(float* data) {
void TransientSuppressorImpl::UpdateBuffers(float* data) {
// TODO(aluebs): Change to ring buffer.
memmove(in_buffer_.get(), &in_buffer_[data_length_],
(buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
@ -350,7 +353,7 @@ void TransientSuppressor::UpdateBuffers(float* data) {
// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
// the spectral mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
void TransientSuppressor::HardRestoration(float* spectral_mean) {
void TransientSuppressorImpl::HardRestoration(float* spectral_mean) {
const float detector_result =
1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
// To restore, we get the peaks in the spectrum. If higher than the previous
@ -377,7 +380,7 @@ void TransientSuppressor::HardRestoration(float* spectral_mean) {
// the spectral mean and that is lower than some function of the current block
// frequency mean. The attenuation depends on |detector_smoothed_|.
// If a restoration takes place, the |magnitudes_| are updated to the new value.
void TransientSuppressor::SoftRestoration(float* spectral_mean) {
void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) {
// Get the spectral magnitude mean of the current block.
float block_frequency_mean = 0;
for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {

View file

@ -0,0 +1,123 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
#define MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class TransientDetector;
// Detects transients in an audio stream and suppress them using a simple
// restoration algorithm that attenuates unexpected spikes in the spectrum.
class TransientSuppressorImpl : public TransientSuppressor {
public:
TransientSuppressorImpl();
~TransientSuppressorImpl() override;
int Initialize(int sample_rate_hz,
int detector_rate_hz,
int num_channels) override;
// Processes a |data| chunk, and returns it with keystrokes suppressed from
// it. The float format is assumed to be int16 ranged. If there are more than
// one channel, the chunks are concatenated one after the other in |data|.
// |data_length| must be equal to |data_length_|.
// |num_channels| must be equal to |num_channels_|.
// A sub-band, ideally the higher, can be used as |detection_data|. If it is
// NULL, |data| is used for the detection too. The |detection_data| is always
// assumed mono.
// If a reference signal (e.g. keyboard microphone) is available, it can be
// passed in as |reference_data|. It is assumed mono and must have the same
// length as |data|. NULL is accepted if unavailable.
// This suppressor performs better if voice information is available.
// |voice_probability| is the probability of voice being present in this chunk
// of audio. If voice information is not available, |voice_probability| must
// always be set to 1.
// |key_pressed| determines if a key was pressed on this audio chunk.
// Returns 0 on success and -1 otherwise.
int Suppress(float* data,
size_t data_length,
int num_channels,
const float* detection_data,
size_t detection_length,
const float* reference_data,
size_t reference_length,
float voice_probability,
bool key_pressed) override;
private:
FRIEND_TEST_ALL_PREFIXES(TransientSuppressorImplTest,
TypingDetectionLogicWorksAsExpectedForMono);
void Suppress(float* in_ptr, float* spectral_mean, float* out_ptr);
void UpdateKeypress(bool key_pressed);
void UpdateRestoration(float voice_probability);
void UpdateBuffers(float* data);
void HardRestoration(float* spectral_mean);
void SoftRestoration(float* spectral_mean);
std::unique_ptr<TransientDetector> detector_;
size_t data_length_;
size_t detection_length_;
size_t analysis_length_;
size_t buffer_delay_;
size_t complex_analysis_length_;
int num_channels_;
// Input buffer where the original samples are stored.
std::unique_ptr<float[]> in_buffer_;
std::unique_ptr<float[]> detection_buffer_;
// Output buffer where the restored samples are stored.
std::unique_ptr<float[]> out_buffer_;
// Arrays for fft.
std::unique_ptr<size_t[]> ip_;
std::unique_ptr<float[]> wfft_;
std::unique_ptr<float[]> spectral_mean_;
// Stores the data for the fft.
std::unique_ptr<float[]> fft_buffer_;
std::unique_ptr<float[]> magnitudes_;
const float* window_;
std::unique_ptr<float[]> mean_factor_;
float detector_smoothed_;
int keypress_counter_;
int chunks_since_keypress_;
bool detection_enabled_;
bool suppression_enabled_;
bool use_hard_restoration_;
int chunks_since_voice_change_;
uint32_t seed_;
bool using_reference_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_TRANSIENT_TRANSIENT_SUPPRESSOR_IMPL_H_

View file

@ -8,17 +8,17 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/transient/transient_suppressor.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "modules/audio_processing/transient/common.h"
#include "test/gtest.h"
namespace webrtc {
TEST(TransientSuppressorTest, TypingDetectionLogicWorksAsExpectedForMono) {
TEST(TransientSuppressorImplTest, TypingDetectionLogicWorksAsExpectedForMono) {
static const int kNumChannels = 1;
TransientSuppressor ts;
TransientSuppressorImpl ts;
ts.Initialize(ts::kSampleRate16kHz, ts::kSampleRate16kHz, kNumChannels);
// Each key-press enables detection.

View file

@ -106,3 +106,11 @@ argument `rtc_exclude_metrics_default` to true and GN will define the
macro for you.
[metrics_h]: https://webrtc.googlesource.com/src/+/master/system_wrappers/include/metrics.h
## `WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR`
The transient suppressor functionality in the audio processing module is not
always used. If you wish to exclude it from the build in order to preserve
binary size, then define the preprocessor macro
`WEBRTC_EXCLUDE_TRANSIENT_SUPPRESSOR`. If you use GN, you can just set the GN
argument `rtc_exclude_transient_suppressor` to true and GN will define the macro
for you.

View file

@ -255,6 +255,10 @@ declare_args() {
# Set this to true to disable webrtc metrics.
rtc_disable_metrics = false
# Set this to true to exclude the transient suppressor in the audio processing
# module from the build.
rtc_exclude_transient_suppressor = false
}
# Make it possible to provide custom locations for some libraries (move these