blob: 4a8d7a796af32cf8b31f174e6c663063de903e53 [file] [log] [blame]
/*
* Copyright (c) 2025 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/neural_feature_extractor.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#include "api/array_view.h"
#include "common_audio/window_generator.h"
#include "rtc_base/checks.h"
#include "third_party/pffft/src/pffft.h"
namespace webrtc {
namespace {
// Trained moodel expects [-1,1]-scaled signals while AEC3 and APM scale
// floating point signals up by 32768 to match 16-bit fixed-point formats, so we
// convert to [-1,1] scale here.
constexpr float kScale = 1.0f / 32768;
// Exponent used to compress the power spectra.
constexpr float kSpectrumCompressionExponent = 0.15f;
std::vector<float> GetSqrtHanningWindow(int frame_size, float scale) {
std::vector<float> window(frame_size);
WindowGenerator::Hanning(frame_size, window.data());
std::transform(window.begin(), window.end(), window.begin(),
[scale](float x) { return scale * std::sqrt(x); });
return window;
}
} // namespace
void TimeDomainFeatureExtractor::PushFeaturesToModelInput(
std::vector<float>& frame,
ArrayView<float> input) {
// Shift down overlap from previous frames.
std::copy(input.begin() + frame.size(), input.end(), input.begin());
std::transform(frame.begin(), frame.end(), input.end() - frame.size(),
[](float x) { return x * kScale; });
frame.clear();
}
FrequencyDomainFeatureExtractor::FrequencyDomainFeatureExtractor(int step_size)
: step_size_(step_size),
frame_size_(2 * step_size_),
sqrt_hanning_(GetSqrtHanningWindow(frame_size_, kScale)),
data_(static_cast<float*>(
pffft_aligned_malloc(frame_size_ * sizeof(float)))),
spectrum_(static_cast<float*>(
pffft_aligned_malloc(frame_size_ * sizeof(float)))),
pffft_setup_(pffft_new_setup(frame_size_, PFFFT_REAL)) {
std::memset(data_, 0, sizeof(float) * frame_size_);
std::memset(spectrum_, 0, sizeof(float) * frame_size_);
}
FrequencyDomainFeatureExtractor::~FrequencyDomainFeatureExtractor() {
pffft_destroy_setup(pffft_setup_);
pffft_aligned_free(spectrum_);
pffft_aligned_free(data_);
}
void FrequencyDomainFeatureExtractor::PushFeaturesToModelInput(
std::vector<float>& frame,
ArrayView<float> input) {
std::memcpy(data_ + step_size_, frame.data(), sizeof(float) * step_size_);
for (int k = 0; k < frame_size_; ++k) {
data_[k] *= sqrt_hanning_[k];
}
pffft_transform_ordered(pffft_setup_, data_, spectrum_, nullptr,
PFFFT_FORWARD);
RTC_CHECK_EQ(input.size(), step_size_ + 1);
input[0] = spectrum_[0] * spectrum_[0];
input[step_size_] = spectrum_[1] * spectrum_[1];
for (int k = 1; k < step_size_; k++) {
input[k] = spectrum_[2 * k] * spectrum_[2 * k] +
spectrum_[2 * k + 1] * spectrum_[2 * k + 1];
}
// Compress the power spectra.
std::transform(input.begin(), input.end(), input.begin(), [](float a) {
return std::pow(a, kSpectrumCompressionExponent);
});
// Saving the current frame as it is used when computing the next FFT.
std::memcpy(data_, frame.data(), sizeof(float) * step_size_);
}
} // namespace webrtc