blob: 431c01fab3941f828be8710d5d09be181e29d4d2 [file] [log] [blame]
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {

// Base-2 logarithm of the length of the FFT used to compute the
// auto-correlation.
constexpr int kAutoCorrelationFftOrder = 9;  // Length-512 FFT.

// The FFT must be longer than the chunk of the pitch buffer that contains all
// the sliding frames, namely kNumLags12kHz + (kBufSize12kHz - kMaxPitch12kHz)
// samples. ComputeOnPitchBuffer() checks the same condition on its local
// constants.
static_assert((1 << kAutoCorrelationFftOrder) >
                  kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
              "The FFT length is not sufficiently big to avoid cyclic "
              "convolution errors.");

}  // namespace
// Creates a real FFT of length 2^kAutoCorrelationFftOrder and, through that
// FFT object, the three work buffers used by ComputeOnPitchBuffer().
// NOTE(review): the buffers are created by calling fft_.CreateBuffer(), so
// this presumably relies on fft_ being declared before them in the class -
// confirm against the header.
AutoCorrelationCalculator::AutoCorrelationCalculator()
: fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
// Scratch buffer: holds the time-domain input of each forward transform and,
// afterwards, the result of the frequency-domain convolution.
tmp_(fft_.CreateBuffer()),
// Receives the transform of the sliding frames chunk.
X_(fft_.CreateBuffer()),
// Receives the transform of the reversed reference frame.
H_(fft_.CreateBuffer()) {}
// Defaulted destructor, defined out of line in this source file.
AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// Computes the auto-correlation coefficients of |pitch_buf| and writes them
// into |auto_corr|. The operands are laid out as follows:
//
//   |.........|...........|  <- pitch buffer
//             [ x (fixed) ]
//   [    y_0    ]
//          [  y_{m-1}  ]
//
// x and y are sub-arrays of the pitch buffer having the same length: x stays
// at the end of the buffer, whereas y slides starting from the beginning of
// the buffer. Cross-correlating y_0 with x gives the auto-correlation for the
// maximum pitch period. Therefore, the first value in |auto_corr| has an
// inverted lag equal to 0, which corresponds to a lag equal to the maximum
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
  RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
  RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);

  constexpr int kFftLength = 1 << kAutoCorrelationFftOrder;
  // Length of x and of each sliding frame y_i.
  constexpr int kFrameLength = kBufSize12kHz - kMaxPitch12kHz;
  // Length of the leading part of the pitch buffer spanned by all the sliding
  // frames, i.e., pitch_buf[:kNumLags12kHz+kFrameLength].
  constexpr int kChunkLength = kNumLags12kHz + kFrameLength;
  static_assert(kFrameLength == kFrameSize20ms12kHz,
                "Mismatch between pitch buffer size, frame size and maximum "
                "pitch period.");
  static_assert(kFftLength > kChunkLength,
                "The FFT length is not sufficiently big to avoid cyclic "
                "convolution errors.");

  auto scratch = tmp_->GetView();

  // Step 1: transform the reference frame x = pitch_buf[-kFrameLength:].
  // It is written reversed, so that the convolution carried out below is a
  // cross-correlation, and then zero-padded up to the FFT length.
  const auto reference_end = std::reverse_copy(
      pitch_buf.end() - kFrameLength, pitch_buf.end(), scratch.begin());
  std::fill(reference_end, scratch.end(), 0.f);
  fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);

  // Step 2: transform the chunk that contains every sliding frame
  // pitch_buf[i:i+kFrameLength], with i in [0, kNumLags12kHz), again
  // zero-padded up to the FFT length.
  const auto chunk_end = std::copy(
      pitch_buf.begin(), pitch_buf.begin() + kChunkLength, scratch.begin());
  std::fill(chunk_end, scratch.end(), 0.f);
  fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);

  // Step 3: convolve in the frequency domain and go back to the time domain.
  // The scratch buffer is cleared before it receives the product, which is
  // scaled by the inverse of the FFT length.
  constexpr float kScalingFactor = 1.f / static_cast<float>(kFftLength);
  std::fill(scratch.begin(), scratch.end(), 0.f);
  fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
  fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);

  // Step 4: the coefficient for the sliding frame y_i is stored at index
  // kFrameLength - 1 + i; copy out the kNumLags12kHz coefficients.
  const auto first_coeff = scratch.begin() + (kFrameLength - 1);
  std::copy(first_coeff, first_coeff + kNumLags12kHz, auto_corr.begin());
}
} // namespace rnn_vad
} // namespace webrtc