// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <array>
#include <cmath>
#include <limits>
#include <tuple>
#include <vector>

#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/rnnoise/src/kiss_fft.h"

namespace rnnoise {
namespace test {
namespace {

// Pi in double precision, computed as the principal value of arccos(-1).
const double kPi{std::acos(-1.0)};

// Copies the first |num_samples| real values of |samples| into |input_buf| as
// complex numbers with a zero imaginary part. Both the real and imaginary
// parts of each written entry are overwritten, so the result does not depend
// on what |input_buf| previously held (e.g. when a buffer that was used as FFT
// output is reused as FFT input). |input_buf| must have room for at least
// |num_samples| entries; entries past |num_samples| are left untouched.
void FillFftInputBuffer(const size_t num_samples,
                        const float* samples,
                        std::complex<float>* input_buf) {
  for (size_t i = 0; i < num_samples; ++i)
    input_buf[i] = std::complex<float>(samples[i], 0.f);
}

// Expects, for each index k in [0, |num_fft_points|), that |computed|[k] has a
// real part within |tolerance| of |expected_real|[k] and an imaginary part
// within |tolerance| of |expected_imag|[k]. The index is attached to any
// failure message via SCOPED_TRACE.
void CheckFftResult(const size_t num_fft_points,
                    const float* expected_real,
                    const float* expected_imag,
                    const std::complex<float>* computed,
                    const float tolerance) {
  for (size_t k = 0; k < num_fft_points; ++k) {
    SCOPED_TRACE(k);
    const std::complex<float>& actual = computed[k];
    EXPECT_NEAR(expected_real[k], actual.real(), tolerance);
    EXPECT_NEAR(expected_imag[k], actual.imag(), tolerance);
  }
}

}  // namespace

class RnnVadTest
    : public testing::Test,
      public ::testing::WithParamInterface<std::tuple<size_t, float, float>> {};

// Verifies that running ForwardFft and then ReverseFft on a real-valued
// sinusoid (10 cycles over the frame) reproduces the input. The real part
// must match the original samples and the imaginary part must be zero, both
// within the per-case tolerance.
TEST_P(RnnVadTest, KissFftForwardReverseCheckIdentity) {
  const auto params = GetParam();
  const float amplitude = std::get<0>(params);
  const size_t num_fft = std::get<1>(params);
  const float tolerance = std::get<2>(params);

  // Real test signal and the expected (all-zero) imaginary part.
  std::vector<float> samples(num_fft);
  const std::vector<float> zeros(num_fft, 0.f);
  for (size_t i = 0; i < samples.size(); ++i)
    samples[i] = amplitude * std::sin(2.f * kPi * 10 * i / num_fft);

  KissFft fft(num_fft);
  std::vector<std::complex<float>> time_domain(num_fft);
  std::vector<std::complex<float>> freq_domain(num_fft);

  FillFftInputBuffer(samples.size(), samples.data(), time_domain.data());
  {
    // TODO(alessiob): Underflow with non power of 2 frame sizes.
    // FloatingPointExceptionObserver fpe_observer;

    fft.ForwardFft(time_domain.size(), time_domain.data(), freq_domain.size(),
                   freq_domain.data());
    fft.ReverseFft(freq_domain.size(), freq_domain.data(), time_domain.size(),
                   time_domain.data());
  }
  CheckFftResult(samples.size(), samples.data(), zeros.data(),
                 time_domain.data(), tolerance);
}

INSTANTIATE_TEST_SUITE_P(FftPoints,
                         RnnVadTest,
                         ::testing::Values(std::make_tuple(1.f, 240, 3e-7f),
                                           std::make_tuple(1.f, 256, 3e-7f),
                                           std::make_tuple(1.f, 480, 3e-7f),
                                           std::make_tuple(1.f, 512, 3e-7f),
                                           std::make_tuple(1.f, 960, 4e-7f),
                                           std::make_tuple(1.f, 1024, 3e-7f),
                                           std::make_tuple(30.f, 240, 5e-6f),
                                           std::make_tuple(30.f, 256, 5e-6f),
                                           std::make_tuple(30.f, 480, 6e-6f),
                                           std::make_tuple(30.f, 512, 6e-6f),
                                           std::make_tuple(30.f, 960, 8e-6f),
                                           std::make_tuple(30.f, 1024, 6e-6f)));

// Checks ForwardFft on a fixed 32-sample real input against a stored
// reference spectrum. Only bins 0..N/2 (17 bins) are compared, which is the
// non-redundant half of the spectrum of a real input. The tolerance is the
// smallest positive normalized float, so the comparison is effectively
// bit-exact.
TEST(RnnVadTest, KissFftBitExactness) {
  constexpr size_t kFftSize = 32;
  constexpr size_t kNumBins = kFftSize / 2 + 1;

  constexpr std::array<float, kFftSize> samples = {
      {0.3524301946163177490234375f,  0.891803801059722900390625f,
       0.07706542313098907470703125f, 0.699530780315399169921875f,
       0.3789891898632049560546875f,  0.5438187122344970703125f,
       0.332781612873077392578125f,   0.449340641498565673828125f,
       0.105229437351226806640625f,   0.722373783588409423828125f,
       0.13155306875705718994140625f, 0.340857982635498046875f,
       0.970204889774322509765625f,   0.53061950206756591796875f,
       0.91507828235626220703125f,    0.830274522304534912109375f,
       0.74468600749969482421875f,    0.24320767819881439208984375f,
       0.743998110294342041015625f,   0.17574800550937652587890625f,
       0.1834825575351715087890625f,  0.63317775726318359375f,
       0.11414264142513275146484375f, 0.1612723171710968017578125f,
       0.80316197872161865234375f,    0.4979794919490814208984375f,
       0.554282128810882568359375f,   0.67189347743988037109375f,
       0.06660757958889007568359375f, 0.89568817615509033203125f,
       0.29327380657196044921875f,    0.3472573757171630859375f}};
  constexpr std::array<float, kNumBins> expected_real = {
      {0.4813065826892852783203125f, -0.0246877372264862060546875f,
       0.04095232486724853515625f, -0.0401695556938648223876953125f,
       0.00500857271254062652587890625f, 0.0160773508250713348388671875f,
       -0.011385642923414707183837890625f, -0.008461721241474151611328125f,
       0.01383177936077117919921875f, 0.0117270611226558685302734375f,
       -0.0164460353553295135498046875f, 0.0585579685866832733154296875f,
       0.02038039825856685638427734375f, -0.0209107734262943267822265625f,
       0.01046995259821414947509765625f, -0.09019653499126434326171875f,
       -0.0583711564540863037109375f}};
  constexpr std::array<float, kNumBins> expected_imag = {
      {0.f, -0.010482530109584331512451171875f, 0.04762755334377288818359375f,
       -0.0558677613735198974609375f, 0.007908363826572895050048828125f,
       -0.0071932487189769744873046875f, 0.01322011835873126983642578125f,
       -0.011227893643081188201904296875f, -0.0400779247283935546875f,
       -0.0290451310575008392333984375f, 0.01519204117357730865478515625f,
       -0.09711246192455291748046875f, -0.00136523949913680553436279296875f,
       0.038602568209171295166015625f, -0.009693108499050140380859375f,
       -0.0183933563530445098876953125f, 0.f}};

  // Every entry of both buffers starts as (0, 0).
  std::array<std::complex<float>, kFftSize> input{};
  std::array<std::complex<float>, kFftSize> spectrum{};

  KissFft fft(kFftSize);
  FillFftInputBuffer(samples.size(), samples.data(), input.data());
  fft.ForwardFft(input.size(), input.data(), spectrum.size(), spectrum.data());
  CheckFftResult(kNumBins, expected_real.data(), expected_imag.data(),
                 spectrum.data(), std::numeric_limits<float>::min());
}

}  // namespace test
}  // namespace rnnoise
