/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"

#include <cmath>
#include <vector>

#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"

namespace webrtc {
namespace rnn_vad {
namespace {

// Returns |n| / |m| rounded up to the next integer (for non-negative |n| and
// positive |m|). Being constexpr, it can size other compile-time constants.
constexpr int ceil(int n, int m) {
  const int padded_numerator = n + m - 1;
  return padded_numerator / m;
}

// How many whole 10 ms frames are needed so that the test signal covers a
// pitch buffer of |kBufSize24kHz| samples (rounded up to a full frame).
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
// Length of the test signal in samples, i.e. |kNumTestDataFrames| full frames.
constexpr int kNumTestDataSize = kFrameSize10ms24kHz * kNumTestDataFrames;

// Verifies that the pitch in Hz is in the detectable range, i.e. that the
// corresponding pitch period at 24 kHz (truncated to an integer number of
// samples) lies in [kInitialMinPitch24kHz, kMaxPitch24kHz].
//
// Inputs for which the period cannot be represented as an int are rejected
// before any float-to-int conversion takes place. This covers zero, negative
// and NaN frequencies, as well as frequencies so low that the period is huge or
// infinite. Converting such a float to int is undefined behavior.
bool PitchIsValid(float pitch_hz) {
  // A non-positive frequency has no meaningful period. The negated comparison
  // also rejects NaN.
  if (!(pitch_hz > 0.f)) {
    return false;
  }
  const float pitch_period_float =
      static_cast<float>(kSampleRate24kHz) / pitch_hz;
  // Any period >= kMaxPitch24kHz + 1 truncates to a value above the upper
  // bound, so it can be rejected while still in the float domain. This keeps
  // the conversion below well-defined without changing the result for
  // in-range values.
  if (pitch_period_float >= static_cast<float>(kMaxPitch24kHz) + 1.f) {
    return false;
  }
  // Truncation toward zero, exactly as the implicit conversion used to do.
  const int pitch_period = static_cast<int>(pitch_period_float);
  return kInitialMinPitch24kHz <= pitch_period &&
         pitch_period <= kMaxPitch24kHz;
}

// Overwrites every element of |dst| with a sinusoid of peak |amplitude| and
// frequency |freq_hz|, sampled at kSampleRate24kHz and starting at zero phase.
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
  int n = 0;
  for (float& sample : dst) {
    sample = amplitude * std::sin(2.f * kPi * freq_hz * n / kSampleRate24kHz);
    ++n;
  }
}

// Splits |samples| into consecutive 10 ms frames and hands each one, together
// with |feature_vector|, to |features_extractor|. Trailing samples that do not
// fill a whole frame are ignored. Returns the silence flag reported for the
// last frame, or true when |samples| does not contain a single full frame.
bool FeedTestData(FeaturesExtractor& features_extractor,
                  rtc::ArrayView<const float> samples,
                  rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  const int num_full_frames =
      static_cast<int>(samples.size() / kFrameSize10ms24kHz);
  bool last_frame_is_silence = true;
  const float* frame_begin = samples.data();
  for (int frame = 0; frame < num_full_frames; ++frame) {
    last_frame_is_silence = features_extractor.CheckSilenceComputeFeatures(
        {frame_begin, kFrameSize10ms24kHz}, feature_vector);
    frame_begin += kFrameSize10ms24kHz;
  }
  return last_frame_is_silence;
}

// Runs the feature extractor on two pure tones of known frequency and checks
// that the pitch-period feature orders them correctly. The lower-frequency
// tone has the longer period and must therefore produce the larger value.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
  constexpr float amplitude = 1000.f;
  constexpr float low_pitch_hz = 150.f;
  constexpr float high_pitch_hz = 250.f;
  // Both tones must correspond to periods inside the detectable range.
  ASSERT_TRUE(PitchIsValid(low_pitch_hz));
  ASSERT_TRUE(PitchIsValid(high_pitch_hz));

  const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
  FeaturesExtractor features_extractor(cpu_features);
  std::vector<float> tone(kNumTestDataSize);
  std::vector<float> features(kFeatureVectorSize);
  ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(features.size()));
  rtc::ArrayView<float, kFeatureVectorSize> features_view(features.data(),
                                                         kFeatureVectorSize);

  // Index of the normalized scalar feature that is proportional to the
  // estimated pitch period.
  constexpr int kPitchPeriodFeatureIndex = kFeatureVectorSize - 2;

  // Low-frequency tone, hence long period.
  CreatePureTone(amplitude, low_pitch_hz, tone);
  ASSERT_FALSE(FeedTestData(features_extractor, tone, features_view));
  const float long_period_feature = features_view[kPitchPeriodFeatureIndex];

  // High-frequency tone, hence short period. The extractor is reset before
  // the second tone is fed.
  features_extractor.Reset();
  CreatePureTone(amplitude, high_pitch_hz, tone);
  ASSERT_FALSE(FeedTestData(features_extractor, tone, features_view));
  const float short_period_feature = features_view[kPitchPeriodFeatureIndex];

  EXPECT_LT(short_period_feature, long_period_feature);
}

}  // namespace
}  // namespace rnn_vad
}  // namespace webrtc
