// blob: 98da39e38ada5406e92d38f4f4793a1684a20311 [file] [log] [blame]
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include <cmath>
#include <vector>
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Returns |n| divided by |m| rounded up to the next integer. Intended for a
// non-negative |n| and a positive |m|.
constexpr int ceil(int n, int m) {
  // Adding |m| - 1 before the truncating division bumps any inexact quotient
  // up by one.
  const int biased_numerator = n + m - 1;
  return biased_numerator / m;
}
// Number of 10 ms frames required to fill a pitch buffer having size
// |kBufSize24kHz|. The division is rounded up so that a trailing partial frame
// counts as a whole frame.
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
// Number of samples for the test data; always a whole number of 10 ms frames,
// i.e., a multiple of |kFrameSize10ms24kHz|.
constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
// Verifies that the pitch in Hz is in the detectable range, i.e., that the
// pitch period obtained by dividing |kSampleRate24kHz| by |pitch_hz| (rounded
// toward zero) lies in [|kInitialMinPitch24kHz|, |kMaxPitch24kHz|].
// Returns false when |pitch_hz| does not map to a period in that range, which
// includes zero, negative and NaN values.
bool PitchIsValid(float pitch_hz) {
  // Keep the period in the floating point domain: converting a non-finite or
  // out-of-range quotient to int is undefined behavior. std::trunc() applies
  // the same toward-zero rounding as a float-to-int conversion, so the outcome
  // is unchanged for every pitch that has a representable period.
  const float pitch_period =
      std::trunc(static_cast<float>(kSampleRate24kHz) / pitch_hz);
  return static_cast<float>(kInitialMinPitch24kHz) <= pitch_period &&
         pitch_period <= static_cast<float>(kMaxPitch24kHz);
}
// Fills |dst| with a sine wave scaled by |amplitude| and having frequency
// |freq_hz|; the i-th sample is evaluated at time i / |kSampleRate24kHz|.
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
  // Loop-invariant part of the phase. The operands keep the left-to-right
  // order of 2 * pi * f * i / fs, hence the generated samples do not change.
  const auto angular_frequency = 2.f * kPi * freq_hz;
  for (int n = 0; rtc::SafeLt(n, dst.size()); ++n) {
    const auto phase = angular_frequency * n / kSampleRate24kHz;
    dst[n] = amplitude * std::sin(phase);
  }
}
// Splits |samples| into consecutive 10 ms frames and pushes them one by one
// into |features_extractor|; samples that do not fill a whole frame are
// ignored. Every call overwrites |feature_vector|, which therefore holds the
// output for the last frame. Returns the silence flag reported for the last
// frame, or true if |samples| does not contain any whole frame.
bool FeedTestData(FeaturesExtractor& features_extractor,
                  rtc::ArrayView<const float> samples,
                  rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  const int frame_count = samples.size() / kFrameSize10ms24kHz;
  bool last_frame_is_silence = true;
  // Points to the first sample of the frame to process next.
  const float* frame_start = samples.data();
  for (int remaining = frame_count; remaining > 0; --remaining) {
    last_frame_is_silence = features_extractor.CheckSilenceComputeFeatures(
        {frame_start, kFrameSize10ms24kHz}, feature_vector);
    frame_start += kFrameSize10ms24kHz;
  }
  return last_frame_is_silence;
}
// Checks that the pitch feature follows the tone frequency: the features are
// extracted for two pure tones with known frequencies and the feature that is
// proportional to the pitch period must be smaller for the higher tone.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
  // Position of the normalized scalar feature that is proportional to the
  // estimated pitch period.
  constexpr int pitch_feature_index = kFeatureVectorSize - 2;
  constexpr float tone_amplitude = 1000.f;
  constexpr float low_pitch_hz = 150.f;
  constexpr float high_pitch_hz = 250.f;
  // Both tones must fall in the range that the pitch estimator can detect.
  ASSERT_TRUE(PitchIsValid(low_pitch_hz));
  ASSERT_TRUE(PitchIsValid(high_pitch_hz));

  // Extractor under test plus the input and output buffers.
  const AvailableCpuFeatures available_cpu_features = GetAvailableCpuFeatures();
  FeaturesExtractor features_extractor(available_cpu_features);
  std::vector<float> samples(kNumTestDataSize);
  std::vector<float> feature_vector(kFeatureVectorSize);
  ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
  rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
      feature_vector.data(), kFeatureVectorSize);

  // The tone with the lower frequency has the longer period.
  CreatePureTone(tone_amplitude, low_pitch_hz, samples);
  ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
  const float high_pitch_period = feature_vector_view[pitch_feature_index];

  // The tone with the higher frequency has the shorter period. The extractor
  // is reset first so that the second tone starts from a clean state.
  features_extractor.Reset();
  CreatePureTone(tone_amplitude, high_pitch_hz, samples);
  ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
  const float low_pitch_period = feature_vector_view[pitch_feature_index];

  // The period feature must be ordered like the true tone periods.
  EXPECT_LT(low_pitch_period, high_pitch_period);
}
} // namespace
} // namespace rnn_vad
} // namespace webrtc