blob: 0344ddd4f07dfd2bd27ce412386a4bd9b413ebee [file] [log] [blame]
# Lint as: python2, python3
# Copyright (c) 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Server side audio utilities functions for Brillo."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import logging
import numpy
import os
import struct
import subprocess
import tempfile
import wave
from autotest_lib.client.common_lib import error
from six.moves import map
from six.moves import range
# Number of bits in a byte; used to turn a sample width in bytes into the bit
# depth passed to sox when generating the sine file.
_BITS_PER_BYTE=8
# Thresholds used when comparing files.
#
# Frequency tolerance, expressed as a fraction (0.01 == 1%): the dominant
# frequency of the recorded audio has to be within this fraction of the
# dominant frequency of the original audio.
_FREQUENCY_THRESHOLD = 0.01
# Noise tolerance, expressed as a fraction (0.05 == 5%) of the FFT magnitude at
# the peak frequency. The FFT magnitude of all the other frequencies in the
# recorded signal should be at most this fraction of the magnitude of the main
# frequency.
_FFT_NOISE_THRESHOLD = 0.05
# Command used to convert the generated WAV sine file into other audio formats;
# it is invoked as "<command> <input file> <output file>". If you want to test
# with something different, this should be changed.
_ENCODING_CMD = 'sox'
def extract_wav_frames(wave_file):
    """Extract all samples from a WAV file.

    @param wave_file: A Wave_read object representing a WAV file opened for
            reading.

    @return: A list of integer samples, interleaved by channel (sample i
            belongs to channel i % wave_file.getnchannels()). 8-bit samples,
            which WAV stores unsigned with an offset of 128, are shifted so
            that every sample width yields signed values.

    @raise ValueError: The sample width is not 1, 2 or 4 bytes.
    """
    sample_width = wave_file.getsampwidth()
    # struct codes for WAV PCM data: 8-bit samples are unsigned, wider ones
    # are signed.
    sample_codes = {1: 'B', 2: 'h', 4: 'i'}
    if sample_width not in sample_codes:
        raise ValueError('Unsupported sample width')
    num_channels = wave_file.getnchannels()
    frame_size = sample_width * num_channels
    data = wave_file.readframes(wave_file.getnframes())
    # Decode only the whole frames that were actually read, so a truncated
    # file yields the frames it contains instead of raising struct.error.
    num_frames = len(data) // frame_size
    # WAV data is always little-endian; '<' also disables native alignment.
    fmt = '<%d%s' % (num_frames * num_channels, sample_codes[sample_width])
    frames = list(struct.unpack(fmt, data[:num_frames * frame_size]))
    # Since 8-bit PCM is unsigned with an offset of 128, we subtract the offset
    # to make it signed since the rest of the code assumes signed numbers.
    if sample_width == 1:
        frames = [val - 128 for val in frames]
    return frames
def check_wav_file(filename, num_channels=None, sample_rate=None,
                   sample_width=None):
    """Checks a WAV file and returns its peak PCM values.

    @param filename: Input WAV file to analyze (anything wave.open accepts).
    @param num_channels: Number of channels to expect (None to not check).
    @param sample_rate: Sample rate to expect (None to not check).
    @param sample_width: Sample width to expect (None to not check).

    @return A list of the absolute maximum PCM values for each channel in the
            WAV file.

    @raise ValueError: Failed to process the WAV file or validate an attribute.
    """
    try:
        with contextlib.closing(wave.open(filename, 'r')) as chk_file:
            # Capture the attributes while the file is open so nothing below
            # touches the wave object after it has been closed.
            actual_channels = chk_file.getnchannels()
            actual_rate = chk_file.getframerate()
            actual_width = chk_file.getsampwidth()
            if num_channels is not None and actual_channels != num_channels:
                raise ValueError('Expected %d channels but got %d instead.' %
                                 (num_channels, actual_channels))
            if sample_rate is not None and actual_rate != sample_rate:
                raise ValueError('Expected sample rate %d but got %d instead.' %
                                 (sample_rate, actual_rate))
            if sample_width is not None and actual_width != sample_width:
                raise ValueError('Expected sample width %d but got %d instead.' %
                                 (sample_width, actual_width))
            frames = extract_wav_frames(chk_file)
    except (wave.Error, struct.error) as e:
        # struct.error can come from decoding a malformed data chunk.
        raise ValueError('Error processing WAV file: %s' % e)

    # Samples are interleaved: channel i is every actual_channels-th sample
    # starting at offset i.
    return [max(abs(sample) for sample in frames[i::actual_channels])
            for i in range(actual_channels)]
def _make_temp_file(temp_dir, extension):
    """Create an empty temp file named sine-*.<extension>; return its path."""
    fd, path = tempfile.mkstemp(prefix='sine-', suffix='.' + extension,
                                dir=temp_dir)
    # Only the path is needed; close the descriptor so it does not leak.
    os.close(fd)
    return path


def generate_sine_file(host, num_channels, sample_rate, sample_width,
                       duration_secs, sine_frequency, temp_dir,
                       file_format='wav'):
    """Generate a sine file and push it to the DUT.

    A WAV sine file is always generated locally with sox. When file_format is
    not 'wav', that WAV file is converted with _ENCODING_CMD and the converted
    file is what gets pushed to the DUT.

    @param host: An object representing the DUT.
    @param num_channels: Number of channels to use.
    @param sample_rate: Sample rate to use for sine wave generation.
    @param sample_width: Sample width in bytes to use for sine wave generation.
    @param duration_secs: Duration in seconds to generate sine wave for.
    @param sine_frequency: Frequency to generate sine wave at.
    @param temp_dir: A temporary directory on the host.
    @param file_format: A string representing the encoding for the audio file.

    @return A tuple of the local WAV filename on the server and the filename
            on the DUT.

    @raise subprocess.CalledProcessError: Generating or converting the file
            failed.
    """
    # sox writes WAV here (-t wav), so name it accordingly; a misleading
    # extension would confuse the conversion step below.
    local_filename = _make_temp_file(temp_dir, 'wav')
    # 8-bit PCM WAV is unsigned; wider sample widths are signed.
    encoding = 'unsigned' if sample_width == 1 else 'signed'
    gen_file_cmd = ['sox', '-n', '-t', 'wav',
                    '-c', '%d' % num_channels,
                    '-e', encoding,
                    '-b', '%d' % (sample_width * _BITS_PER_BYTE),
                    '-r', '%d' % sample_rate,
                    local_filename,
                    # str() keeps fractional durations/frequencies intact.
                    'synth', str(duration_secs), 'sine', str(sine_frequency),
                    'vol', '0.9']
    logging.info('Command to generate sine wave: %s', ' '.join(gen_file_cmd))
    # Argument lists (no shell) keep paths with spaces/metacharacters safe.
    subprocess.check_call(gen_file_cmd)
    if file_format != 'wav':
        # Convert the file to the appropriate format.
        logging.info('Converting file to %s', file_format)
        local_encoded_filename = _make_temp_file(temp_dir, file_format)
        # _ENCODING_CMD may itself contain arguments, so split it.
        cvt_file_cmd = _ENCODING_CMD.split() + [local_filename,
                                                local_encoded_filename]
        logging.info('Command to convert file: %s', ' '.join(cvt_file_cmd))
        subprocess.check_call(cvt_file_cmd)
    else:
        local_encoded_filename = local_filename
    dut_tmp_dir = '/data'
    remote_filename = os.path.join(dut_tmp_dir, 'sine.' + file_format)
    logging.info('Send file to DUT.')
    # TODO(ralphnathan): Find a better place to put this file once the SELinux
    # issues are resolved.
    logging.info('remote_filename %s', remote_filename)
    host.send_file(local_encoded_filename, remote_filename)
    return local_filename, remote_filename
def _is_outside_frequency_threshold(freq_reference, freq_rec, threshold=None):
    """Checks whether the recorded frequency deviates from the reference one.

    This function checks to see if the frequencies corresponding to the peak
    FFT values are similar, meaning that the dominant frequency in the audio
    signal is the same for the recorded audio as that in the audio played.

    @param freq_reference: The dominant frequency in the reference audio file.
    @param freq_rec: The dominant frequency in the recorded audio file.
    @param threshold: Allowed relative deviation as a fraction (0.01 == 1%).
            None means _FREQUENCY_THRESHOLD.

    @return: True if freq_rec is NOT within threshold of freq_reference, i.e.
            freq_rec / freq_reference lies outside
            [1 - threshold, 1 + threshold]; False otherwise.
    """
    if threshold is None:
        threshold = _FREQUENCY_THRESHOLD
    if freq_reference == 0:
        # A relative tolerance around 0 Hz is just the point 0; this also
        # avoids dividing by zero.
        return freq_rec != 0
    ratio = float(freq_rec) / freq_reference
    return not (1 - threshold <= ratio <= 1 + threshold)
def _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate):
    """Compares audio frames from the reference file and the recorded file.

    This method checks for two things:
        1. That the main frequency is the same in both the files. This is done
           using the FFT and observing the frequency corresponding to the
           peak.
        2. That there is no other dominant frequency in the recorded file.
           This is done by sweeping the frequency domain and checking that
           every FFT bin whose frequency is outside _FREQUENCY_THRESHOLD of
           the reference frequency has a magnitude of at most
           _FFT_NOISE_THRESHOLD times the peak magnitude.

    The key assumption here is that the reference audio file contains only
    one frequency.

    @param reference_file_frames: Interleaved audio samples from the
            reference file.
    @param rec_file_frames: Interleaved audio samples from the recorded file.
    @param num_channels: Number of channels in the files.
    @param sample_rate: Sample rate of the files.

    @raise error.TestFail: The frequency of the recorded signal doesn't
            match that of the reference signal.
    @raise error.TestFail: There is too much noise in the recorded signal.
    """
    for channel in range(num_channels):
        # Samples are interleaved; pick out this channel's samples.
        reference_data = reference_file_frames[channel::num_channels]
        rec_data = rec_file_frames[channel::num_channels]

        # Get fft magnitudes and frequencies corresponding to the fft values.
        abs_fft_reference = numpy.abs(numpy.fft.rfft(reference_data))
        abs_fft_rec = numpy.abs(numpy.fft.rfft(rec_data))
        fft_freqs_reference = numpy.fft.rfftfreq(len(reference_data),
                                                 1.0 / sample_rate)
        fft_freqs_rec = numpy.fft.rfftfreq(len(rec_data), 1.0 / sample_rate)

        # Get frequency at highest peak.
        freq_reference = fft_freqs_reference[numpy.argmax(abs_fft_reference)]
        freq_rec = fft_freqs_rec[numpy.argmax(abs_fft_rec)]

        # Compare the two frequencies.
        logging.info('Golden frequency of channel %i is %f', channel,
                     freq_reference)
        logging.info('Recorded frequency of channel %i is %f', channel,
                     freq_rec)
        if _is_outside_frequency_threshold(freq_reference, freq_rec):
            raise error.TestFail('The recorded audio frequency does not match '
                                 'that of the audio played.')

        # Check for noise in the frequency domain: each bin away from the
        # reference frequency must stay below the noise limit. (The check
        # must use the bin's own frequency; comparing freq_rec again here
        # would always pass, because that check just succeeded above.)
        noise_limit = _FFT_NOISE_THRESHOLD * numpy.max(abs_fft_rec)
        noise_detected = False
        for bin_freq, magnitude in zip(fft_freqs_rec, abs_fft_rec):
            if (magnitude > noise_limit and
                    _is_outside_frequency_threshold(freq_reference, bin_freq)):
                logging.warning('Unexpected frequency peak detected at %f '
                                'Hz.', bin_freq)
                noise_detected = True
        if noise_detected:
            raise error.TestFail('Signal is noiser than expected.')
def compare_file(reference_audio_filename, test_audio_filename):
    """Compares the recorded audio file to the reference audio file.

    @param reference_audio_filename: Reference audio file containing the
            reference signal.
    @param test_audio_filename: Audio file containing audio captured from
            the test.

    @raise error.TestFail: The recorded file's channel count or sample rate
            differs from the reference, or the recorded signal does not
            match the reference signal.
    """
    with contextlib.closing(wave.open(reference_audio_filename,
                                      'rb')) as reference_file:
        with contextlib.closing(wave.open(test_audio_filename,
                                          'rb')) as rec_file:
            num_channels = reference_file.getnchannels()
            sample_rate = reference_file.getframerate()
            # The frame comparison de-interleaves both files with the
            # reference channel count and derives frequencies from the
            # reference sample rate, so the recording must match both.
            if rec_file.getnchannels() != num_channels:
                raise error.TestFail(
                    'Recorded audio has %d channels but the reference has '
                    '%d.' % (rec_file.getnchannels(), num_channels))
            if rec_file.getframerate() != sample_rate:
                raise error.TestFail(
                    'Recorded audio sample rate is %d but the reference '
                    'sample rate is %d.' % (rec_file.getframerate(),
                                            sample_rate))
            # Extract data from files.
            reference_file_frames = extract_wav_frames(reference_file)
            rec_file_frames = extract_wav_frames(rec_file)
    _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate)