Support AVX2/FMA intrinsics in audio FIR filter

Bug: webrtc:11663
Test: ./common_audio_unittests --gtest_filter=FIRFilterTest.*
Change-Id: I4c2bd8577e9d964c8a424f5c781a77c1692da238
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/178627
Reviewed-by: Per Åhgren <peah@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#31963}
This commit is contained in:
Zhaoliang Ma 2020-08-19 14:43:23 +08:00 committed by Commit Bot
parent bcdfc8975e
commit 7ad1011a19
4 changed files with 139 additions and 7 deletions

View file

@ -265,7 +265,11 @@ if (current_cpu == "x86" || current_cpu == "x64") {
}
rtc_library("common_audio_avx2") {
sources = [ "resampler/sinc_resampler_avx2.cc" ]
sources = [
"fir_filter_avx2.cc",
"fir_filter_avx2.h",
"resampler/sinc_resampler_avx2.cc",
]
if (is_win) {
cflags = [ "/arch:AVX2" ]

View file

@ -0,0 +1,88 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "common_audio/fir_filter_avx2.h"
#include <immintrin.h>
#include <stdint.h>
#include <string.h>
#include <xmmintrin.h>
#include "rtc_base/checks.h"
#include "rtc_base/memory/aligned_malloc.h"
namespace webrtc {
// Builds an AVX2 FIR filter from |unaligned_coefficients_length| taps read
// from |unaligned_coefficients|.
//
// The taps are copied in reversed order into a 32-byte aligned buffer whose
// length is rounded up to a multiple of eight floats (one __m256 register),
// with the unused leading slots set to zero. |state_| is a zero-initialized,
// 32-byte aligned buffer of |max_input_length| + |state_length_| floats: the
// history kept between calls followed by room for the input of one Filter()
// call.
//
// NOTE(review): a tap count of 0 makes |state_length_| wrap around
// (0 - 1 on size_t); callers are presumed to pass at least one coefficient.
FIRFilterAVX2::FIRFilterAVX2(const float* unaligned_coefficients,
                             size_t unaligned_coefficients_length,
                             size_t max_input_length)
    :  // Closest higher multiple of eight.
      coefficients_length_((unaligned_coefficients_length + 7) & ~0x07),
      state_length_(coefficients_length_ - 1),
      coefficients_(static_cast<float*>(
          AlignedMalloc(sizeof(float) * coefficients_length_, 32))),
      state_(static_cast<float*>(
          AlignedMalloc(sizeof(float) * (max_input_length + state_length_),
                        32))) {
  // Rounding up never shrinks the length. The two lengths are equal when the
  // tap count is already a multiple of eight, in which case no padding is
  // needed, so the check must allow equality.
  RTC_DCHECK_GE(coefficients_length_, unaligned_coefficients_length);
  size_t padding = coefficients_length_ - unaligned_coefficients_length;

  // Zero the padding slots. They sit at the front of the stored array because
  // the taps are written reversed right after them.
  memset(coefficients_.get(), 0, padding * sizeof(coefficients_[0]));

  // The coefficients are reversed to compensate for the order in which the
  // input samples are acquired (most recent last).
  for (size_t i = 0; i < unaligned_coefficients_length; ++i) {
    coefficients_[i + padding] =
        unaligned_coefficients[unaligned_coefficients_length - i - 1];
  }

  // Start from silence: both the history and the input area are cleared.
  memset(state_.get(), 0,
         (max_input_length + state_length_) * sizeof(state_[0]));
}
// Defaulted: |coefficients_| and |state_| are unique_ptrs declared with
// AlignedFreeDeleter, so no explicit cleanup is written here.
FIRFilterAVX2::~FIRFilterAVX2() = default;
// Filters |length| samples from |in| and writes |length| samples to |out|.
//
// The new samples are appended to the history stored in |state_|; each output
// sample is the dot product of |coefficients_| with the window of |state_|
// starting at the corresponding sample. The last |state_length_| samples are
// then kept as history for the next call. |length| samples are copied into
// the area of |state_| that was sized from |max_input_length| at construction.
void FIRFilterAVX2::Filter(const float* in, size_t length, float* out) {
  RTC_DCHECK_GT(length, 0);

  // Place the fresh input right behind the retained history.
  memcpy(&state_[state_length_], in, length * sizeof(*in));

  const float* const taps = coefficients_.get();
  for (size_t n = 0; n < length; ++n) {
    const float* const window = &state_[n];
    __m256 acc = _mm256_setzero_ps();

    // The window slides one float at a time, so it is only occasionally on a
    // 32-byte boundary; pick the matching load instruction. The taps buffer
    // is always read with aligned loads.
    const bool window_aligned =
        (reinterpret_cast<uintptr_t>(window) & 0x1F) == 0;
    if (window_aligned) {
      for (size_t k = 0; k < coefficients_length_; k += 8) {
        const __m256 samples = _mm256_load_ps(window + k);
        acc = _mm256_fmadd_ps(samples, _mm256_load_ps(taps + k), acc);
      }
    } else {
      for (size_t k = 0; k < coefficients_length_; k += 8) {
        const __m256 samples = _mm256_loadu_ps(window + k);
        acc = _mm256_fmadd_ps(samples, _mm256_load_ps(taps + k), acc);
      }
    }

    // Horizontal reduction: 8 lanes -> 4 lanes -> 2 lanes -> 1 lane.
    const __m128 low_half = _mm256_extractf128_ps(acc, 0);
    const __m128 high_half = _mm256_extractf128_ps(acc, 1);
    const __m128 four = _mm_add_ps(low_half, high_half);
    const __m128 two = _mm_add_ps(_mm_movehl_ps(four, four), four);
    const __m128 one = _mm_add_ss(two, _mm_shuffle_ps(two, two, 1));
    _mm_store_ss(out + n, one);
  }

  // Keep the most recent |state_length_| samples as history for the next
  // call. The ranges may overlap, hence memmove.
  memmove(state_.get(), &state_[length], state_length_ * sizeof(state_[0]));
}
} // namespace webrtc

View file

@ -0,0 +1,41 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef COMMON_AUDIO_FIR_FILTER_AVX2_H_
#define COMMON_AUDIO_FIR_FILTER_AVX2_H_
#include <stddef.h>
#include <memory>
#include "common_audio/fir_filter.h"
#include "rtc_base/memory/aligned_malloc.h"
namespace webrtc {
// FIRFilter implementation that evaluates the filter with AVX2/FMA
// intrinsics, eight floats per step.
class FIRFilterAVX2 : public FIRFilter {
public:
// |coefficients| points to |coefficients_length| filter taps.
// |max_input_length| sizes the internal state buffer for the input of a
// single Filter() call.
FIRFilterAVX2(const float* coefficients,
size_t coefficients_length,
size_t max_input_length);
~FIRFilterAVX2() override;
// Filters |length| samples from |in| into |out|, using the samples retained
// from previous calls as history.
void Filter(const float* in, size_t length, float* out) override;
private:
// NOTE: the declaration order below is also the initialization order the
// constructor relies on (lengths first, then the buffers sized from them).
// Number of stored taps: the caller's tap count rounded up to a multiple of
// eight.
const size_t coefficients_length_;
// Number of past samples kept between calls: |coefficients_length_| - 1.
const size_t state_length_;
// Stored taps, freed through AlignedFreeDeleter.
std::unique_ptr<float[], AlignedFreeDeleter> coefficients_;
// History followed by the current input, freed through AlignedFreeDeleter.
std::unique_ptr<float[], AlignedFreeDeleter> state_;
};
} // namespace webrtc
#endif // COMMON_AUDIO_FIR_FILTER_AVX2_H_

View file

@ -17,6 +17,7 @@
#if defined(WEBRTC_HAS_NEON)
#include "common_audio/fir_filter_neon.h"
#elif defined(WEBRTC_ARCH_X86_FAMILY)
#include "common_audio/fir_filter_avx2.h"
#include "common_audio/fir_filter_sse.h"
#include "system_wrappers/include/cpu_features_wrapper.h" // kSSE2, WebRtc_G...
#endif
@ -34,18 +35,16 @@ FIRFilter* CreateFirFilter(const float* coefficients,
FIRFilter* filter = nullptr;
// If we know the minimum architecture at compile time, avoid CPU detection.
#if defined(WEBRTC_ARCH_X86_FAMILY)
#if defined(__SSE2__)
filter =
new FIRFilterSSE2(coefficients, coefficients_length, max_input_length);
#else
// x86 CPU detection required.
if (WebRtc_GetCPUInfo(kSSE2)) {
if (WebRtc_GetCPUInfo(kAVX2)) {
filter =
new FIRFilterAVX2(coefficients, coefficients_length, max_input_length);
} else if (WebRtc_GetCPUInfo(kSSE2)) {
filter =
new FIRFilterSSE2(coefficients, coefficients_length, max_input_length);
} else {
filter = new FIRFilterC(coefficients, coefficients_length);
}
#endif
#elif defined(WEBRTC_HAS_NEON)
filter =
new FIRFilterNEON(coefficients, coefficients_length, max_input_length);