/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>

#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include "rtc_base/checks.h"

namespace webrtc {
namespace rnn_vad {
namespace {

// Total band energy of the reference frame below which the frame is treated
// as silence (CheckSilenceComputeFeatures() then skips feature extraction).
constexpr float kSilenceThreshold{0.04f};

// Computes the squared Euclidean distance between |new_spectral_coeffs| and
// each older coefficient vector held in |ring_buf| (delays 1 to
// kSpectralCoeffsHistorySize - 1) and pushes the distances, ordered by
// increasing delay, into |sym_matrix_buf|.
// The caller pushes |new_spectral_coeffs| into |ring_buf| before calling this
// function, so delay 0 is the new vector itself and is skipped.
void UpdateSpectralDifferenceStats(
    rtc::ArrayView<const float, kNumBands> new_spectral_coeffs,
    const RingBuffer<float, kNumBands, kSpectralCoeffsHistorySize>& ring_buf,
    SymmetricMatrixBuffer<float, kSpectralCoeffsHistorySize>* sym_matrix_buf) {
  RTC_DCHECK(sym_matrix_buf);
  std::array<float, kSpectralCoeffsHistorySize - 1> distances;
  for (size_t delay = 1; delay < kSpectralCoeffsHistorySize; ++delay) {
    const auto older_coeffs = ring_buf.GetArrayView(delay);
    float squared_dist = 0.f;
    for (size_t band = 0; band < kNumBands; ++band) {
      const float diff = new_spectral_coeffs[band] - older_coeffs[band];
      squared_dist += diff * diff;
    }
    distances[delay - 1] = squared_dist;
  }
  sym_matrix_buf->Push(distances);
}

}  // namespace

SpectralFeaturesView::SpectralFeaturesView(
    rtc::ArrayView<float, kNumBands - kNumLowerBands> coeffs,
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative,
    rtc::ArrayView<float, kNumLowerBands> cross_correlations,
    float* variability)
    : coeffs(coeffs),
      average(average),
      first_derivative(first_derivative),
      second_derivative(second_derivative),
      cross_correlations(cross_correlations),
      variability(variability) {}

// Defaulted member-wise copy: the copy refers to the same destinations as
// |other|.
SpectralFeaturesView::SpectralFeaturesView(
    const SpectralFeaturesView& other) = default;

SpectralFeaturesView::~SpectralFeaturesView() = default;

// Sets up the FFT helper and the two FFT output buffers, each created with
// kFrameSize20ms24kHz / 2 + 1 (the same value ComputeCrossCorrelation() uses
// as kNumFftPoints). Also precomputes the band boundary indexes and the DCT
// table; CheckSilenceComputeFeatures() reuses both on every call.
// NOTE(review): the FFT buffers are deliberately initialized with parentheses.
// If they are std::vector (check the header), braces would select the
// initializer_list constructor and produce a one-element buffer.
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
    : fft_(),
      reference_frame_fft_(kFrameSize20ms24kHz / 2 + 1),
      lagged_frame_fft_(kFrameSize20ms24kHz / 2 + 1),
      // Band boundaries for kFrameSize20ms24kHz-sample frames at
      // kSampleRate24kHz.
      band_boundaries_(
          ComputeBandBoundaryIndexes(kSampleRate24kHz, kFrameSize20ms24kHz)),
      dct_table_(ComputeDctTable()) {}

SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;

// Resets both pieces of inter-frame state: the spectral difference stats
// buffer and the ring buffer of past spectral coefficients.
void SpectralFeaturesExtractor::Reset() {
  spectral_diffs_buf_.Reset();
  spectral_coeffs_ring_buf_.Reset();
}

// Analyzes |reference_frame| and, unless it is silent, computes all the
// spectral features and writes them into |spectral_features|.
// Returns true when the total band energy of |reference_frame| is below
// kSilenceThreshold. In that case |lagged_frame| is not analyzed and neither
// the history buffers nor |spectral_features| are touched.
// Returns false after writing every feature.
bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
    rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
    SpectralFeaturesView spectral_features) {
  // Reference frame band energies. They are kept in a member because
  // ComputeCrossCorrelation() also uses them for normalization.
  fft_.ForwardFft(reference_frame, reference_frame_fft_);
  ComputeBandEnergies(reference_frame_fft_, band_boundaries_,
                      reference_frame_energy_coeffs_);
  const float reference_energy =
      std::accumulate(reference_frame_energy_coeffs_.begin(),
                      reference_frame_energy_coeffs_.end(), 0.f);
  const bool is_silent = reference_energy < kSilenceThreshold;
  if (is_silent) {
    return true;
  }

  // Lagged frame band energies, needed by the cross-correlation features.
  fft_.ForwardFft(lagged_frame, lagged_frame_fft_);
  ComputeBandEnergies(lagged_frame_fft_, band_boundaries_,
                      lagged_frame_energy_coeffs_);

  // Log band energies of the reference frame, decorrelated via DCT. The first
  // two coefficients are shifted by offsets derived from training set stats.
  constexpr float kFirstCoeffOffset = 12.f;
  constexpr float kSecondCoeffOffset = 4.f;
  std::array<float, kNumBands> log_energies;
  ComputeLogBandEnergiesCoefficients(reference_frame_energy_coeffs_,
                                     log_energies);
  std::array<float, kNumBands> decorrelated;
  ComputeDct(log_energies, dct_table_, decorrelated);
  decorrelated[0] -= kFirstCoeffOffset;
  decorrelated[1] -= kSecondCoeffOffset;

  // Store the new coefficients first, then refresh the distance stats
  // against the older ones.
  spectral_coeffs_ring_buf_.Push(decorrelated);
  UpdateSpectralDifferenceStats(decorrelated, spectral_coeffs_ring_buf_,
                                &spectral_diffs_buf_);

  // Higher bands: the newest coefficients past the lower bands, copied as-is.
  const auto newest = spectral_coeffs_ring_buf_.GetArrayView(0);
  RTC_DCHECK_EQ(newest.size() - kNumLowerBands,
                spectral_features.coeffs.size());
  std::copy(newest.begin() + kNumLowerBands, newest.end(),
            spectral_features.coeffs.begin());

  // Lower bands: temporal stats, cross-correlation and the variability score.
  ComputeAvgAndDerivatives(spectral_features.average,
                           spectral_features.first_derivative,
                           spectral_features.second_derivative);
  ComputeCrossCorrelation(spectral_features.cross_correlations);
  RTC_DCHECK(spectral_features.variability);
  *spectral_features.variability = ComputeVariability();
  return false;
}

// Convolves, band by band, the three most recent spectral coefficient vectors
// (ring buffer delays 0, 1 and 2) with three temporal kernels. Results go to
// |average| (sum, not divided by 3), |first_derivative| and
// |second_derivative|. Only the first average.size() bands are processed.
void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative) {
  const auto newest = spectral_coeffs_ring_buf_.GetArrayView(0);
  const auto middle = spectral_coeffs_ring_buf_.GetArrayView(1);
  const auto oldest = spectral_coeffs_ring_buf_.GetArrayView(2);
  RTC_DCHECK_EQ(average.size(), first_derivative.size());
  RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
  RTC_DCHECK_LE(average.size(), newest.size());
  const size_t num_bands = average.size();
  for (size_t band = 0; band < num_bands; ++band) {
    const float x0 = newest[band];
    const float x1 = middle[band];
    const float x2 = oldest[band];
    // Kernel [1, 1, 1].
    average[band] = x0 + x1 + x2;
    // Kernel [1, 0, -1].
    first_derivative[band] = x0 - x2;
    // Laplacian kernel [1, -2, 1].
    second_derivative[band] = x0 - 2 * x1 + x2;
  }
}

// Computes the band-wise cross-correlation between the reference and lagged
// frame spectra, as last written by CheckSilenceComputeFeatures(). Each band
// value is normalized by sqrt(bias + E_ref * E_lag), where E_ref and E_lag
// are the band energies of the two frames. The normalized values are
// decorrelated via DCT into |cross_correlations|, and the first two
// outputs are shifted by training-set offsets.
void SpectralFeaturesExtractor::ComputeCrossCorrelation(
    rtc::ArrayView<float, kNumLowerBands> cross_correlations) {
  const auto& x = reference_frame_fft_;
  const auto& y = lagged_frame_fft_;
  // Capture by reference. Copy-capturing |x| and |y| (which are references)
  // would copy both FFT buffers on every call.
  // NOTE(review): this assumes ComputeBandCoefficients() only invokes the
  // functor during the call and does not retain it.
  auto cross_corr = [&x, &y](const size_t freq_bin_index) -> float {
    return (x[freq_bin_index].real() * y[freq_bin_index].real() +
            x[freq_bin_index].imag() * y[freq_bin_index].imag());
  };
  std::array<float, kNumBands> cross_corr_coeffs;
  constexpr size_t kNumFftPoints = kFrameSize20ms24kHz / 2 + 1;
  ComputeBandCoefficients(cross_corr, band_boundaries_, kNumFftPoints - 1,
                          cross_corr_coeffs);
  // Normalize. The bias keeps the denominator strictly positive when a band
  // has zero energy in either frame.
  constexpr float kEnergyBias = 0.001f;
  for (size_t i = 0; i < cross_corr_coeffs.size(); ++i) {
    cross_corr_coeffs[i] =
        cross_corr_coeffs[i] /
        std::sqrt(kEnergyBias + reference_frame_energy_coeffs_[i] *
                                    lagged_frame_energy_coeffs_[i]);
  }
  // Decorrelate.
  ComputeDct(cross_corr_coeffs, dct_table_, cross_correlations);
  // Normalize (based on training set stats).
  cross_correlations[0] -= 1.3f;
  cross_correlations[1] -= 0.9f;
}

// Returns the spectral variability score. For each history slot i, take the
// smallest value in spectral_diffs_buf_ against any other slot j != i
// (presumably the squared distance pushed by UpdateSpectralDifferenceStats()).
// Average these minima over the history and shift by a training-set offset.
float SpectralFeaturesExtractor::ComputeVariability() {
  constexpr float kNoDistanceYet = std::numeric_limits<float>::max();
  float sum_of_min_dists = 0.f;
  for (size_t i = 0; i < kSpectralCoeffsHistorySize; ++i) {
    float closest = kNoDistanceYet;
    // Visit every j in ascending order except j == i (self-distance is 0).
    for (size_t j = 0; j < i; ++j) {
      closest = std::min(closest, spectral_diffs_buf_.GetValue(i, j));
    }
    for (size_t j = i + 1; j < kSpectralCoeffsHistorySize; ++j) {
      closest = std::min(closest, spectral_diffs_buf_.GetValue(i, j));
    }
    sum_of_min_dists += closest;
  }
  // Normalize (based on training set stats).
  return sum_of_min_dists / kSpectralCoeffsHistorySize - 2.1f;
}

}  // namespace rnn_vad
}  // namespace webrtc
