blob: 261ad631cbb32f48deca6cebe48d59667b28ae81 [file] [log] [blame]
// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::error;
use std::fmt;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::os::raw::c_int;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use audio_streams::{SampleFormat, StreamSource};
use hound::{WavReader, WavSpec, WavWriter};
use libcras::{BoxError, CrasClient, CrasNodeType};
use sys_util::{register_signal_handler, set_rt_prio_limit, set_rt_round_robin};
use crate::arguments::{AudioOptions, FileType, LoopbackType};
#[derive(Debug)]
pub enum Error {
CreateStream(BoxError),
FetchStream(BoxError),
FloatingPointSamples,
InvalidWavFile(hound::Error),
Io(io::Error),
Libcras(libcras::Error),
NoLoopbackNode(CrasNodeType),
OpenFile(hound::Error),
SampleBits(u16),
SysUtil(sys_util::Error),
}
impl error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use Error::*;
match self {
CreateStream(e) => write!(f, "Failed to create stream: {}", e),
FetchStream(e) => write!(f, "Failed to fetch buffer from stream: {}", e),
FloatingPointSamples => write!(f, "Floating point audio samples are not supported"),
InvalidWavFile(e) => write!(f, "Could not open file as WAV file: {}", e),
Io(e) => write!(f, "IO Error: {}", e),
Libcras(e) => write!(f, "Libcras Error: {}", e),
NoLoopbackNode(typ) => write!(f, "No loopback node found with type {:?}", typ),
OpenFile(e) => write!(f, "Could not open WAV file for writing: {}", e),
SampleBits(bits) => write!(
f,
"Sample size {} is not supported, only 8, 16, 24, and 32 bit samples are supported",
bits
),
SysUtil(e) => write!(f, "SysUtil Error: {}", e),
}
}
}
type Result<T> = std::result::Result<T, Error>;
static INTERRUPTED: AtomicBool = AtomicBool::new(false);
extern "C" fn sigint_handler() {
// Check if we've already received one SIGINT. If we have, the program may
// be misbehaving and not terminating, so to be safe we'll forcefully exit.
if INTERRUPTED.load(Ordering::Acquire) {
std::process::exit(1);
}
INTERRUPTED.store(true, Ordering::Release);
}
fn add_sigint_handler() -> Result<()> {
const SIGINT: c_int = 2;
let result = unsafe { register_signal_handler(SIGINT, sigint_handler) };
result.map_err(Error::SysUtil)
}
fn set_priority_to_realtime() {
const AUDIO_THREAD_RTPRIO: u16 = 10;
if set_rt_prio_limit(AUDIO_THREAD_RTPRIO as u64).is_err()
|| set_rt_round_robin(AUDIO_THREAD_RTPRIO as i32).is_err()
{
println!("Attempt to use real-time priority failed, running with default scheduler.");
}
}
fn channel_string(num_channels: usize) -> String {
match num_channels {
1 => "Mono".to_string(),
2 => "Stereo".to_string(),
_ => format!("{} Channels", num_channels),
}
}
struct WavSource {
wav_reader: WavReader<BufReader<File>>,
format: SampleFormat,
num_channels: usize,
frame_rate: u32,
}
impl WavSource {
fn try_new(opts: &AudioOptions) -> Result<Self> {
let wav_reader = WavReader::open(&opts.file_name).map_err(Error::InvalidWavFile)?;
let spec = wav_reader.spec();
if spec.sample_format == hound::SampleFormat::Float {
return Err(Error::FloatingPointSamples);
}
let format = match spec.bits_per_sample {
8 => SampleFormat::U8,
16 => SampleFormat::S16LE,
24 => SampleFormat::S24LE,
32 => SampleFormat::S32LE,
s => return Err(Error::SampleBits(s)),
};
if opts.format.is_some() && Some(format) != opts.format {
eprintln!("Warning: format changed to {:?}", format);
}
let num_channels = spec.channels as usize;
if opts.num_channels.is_some() && Some(num_channels) != opts.num_channels {
eprintln!("Warning: number of channels changed to {}", num_channels);
}
let frame_rate = spec.sample_rate;
if opts.frame_rate.is_some() && Some(frame_rate) != opts.frame_rate {
eprintln!("Warning: frame rate changed to {}", frame_rate);
}
Ok(Self {
wav_reader,
format,
num_channels,
frame_rate,
})
}
fn format(&self) -> SampleFormat {
self.format
}
fn num_channels(&self) -> usize {
self.num_channels
}
fn frame_rate(&self) -> u32 {
self.frame_rate
}
}
impl Read for WavSource {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let frame_size = self.format.sample_bytes() * self.num_channels;
let read_len = buf.len() - buf.len() % frame_size;
let num_samples = read_len / self.format.sample_bytes();
let samples = self.wav_reader.samples::<i32>();
let mut read = 0;
for s in samples.take(num_samples) {
match s {
Ok(sample) => {
let result = match self.format {
SampleFormat::U8 => buf.write_all(&((sample + 128) as u8).to_le_bytes()),
SampleFormat::S16LE => buf.write_all(&(sample as i16).to_le_bytes()),
SampleFormat::S24LE | SampleFormat::S32LE => {
buf.write_all(&sample.to_le_bytes())
}
};
match result {
Ok(()) => read += self.format.sample_bytes(),
Err(_) => return Ok(read),
};
}
Err(_) => return Ok(read),
};
}
Ok(read)
}
}
pub fn playback(opts: AudioOptions) -> Result<()> {
let num_channels;
let frame_rate;
let format;
let mut sample_source: Box<dyn Read> = match opts.file_type {
FileType::Wav => {
let wav_source = WavSource::try_new(&opts)?;
num_channels = wav_source.num_channels();
frame_rate = wav_source.frame_rate();
format = wav_source.format();
Box::new(wav_source)
}
FileType::Raw => {
num_channels = opts.num_channels.unwrap_or(2);
frame_rate = opts.frame_rate.unwrap_or(48000);
format = opts.format.unwrap_or(SampleFormat::S16LE);
Box::new(BufReader::new(
File::open(&opts.file_name).map_err(Error::Io)?,
))
}
};
println!(
"Playing {} '{}' : {}, Rate {} Hz, {}",
opts.file_type,
opts.file_name.display(),
format,
frame_rate,
channel_string(num_channels)
);
let mut cras_client = CrasClient::new().map_err(Error::Libcras)?;
let (_control, mut stream) = cras_client
.new_playback_stream(
num_channels,
format,
frame_rate,
opts.buffer_size.unwrap_or(256),
)
.map_err(Error::CreateStream)?;
set_priority_to_realtime();
add_sigint_handler()?;
while !INTERRUPTED.load(Ordering::Acquire) {
let mut buffer = stream.next_playback_buffer().map_err(Error::FetchStream)?;
let frame_size = num_channels * format.sample_bytes();
let frames = buffer.frame_capacity();
let mut chunk = (&mut sample_source).take((frames * frame_size) as u64);
let transferred = io::copy(&mut chunk, &mut buffer).map_err(Error::Io)?;
if transferred == 0 {
break;
}
}
// Stream and client should gracefully be closed out of this scope
Ok(())
}
struct WavSink {
wav_writer: WavWriter<BufWriter<File>>,
format: SampleFormat,
}
impl WavSink {
fn try_new<P: AsRef<Path>>(
path: P,
num_channels: usize,
format: SampleFormat,
frame_rate: u32,
) -> Result<Self> {
let spec = WavSpec {
channels: num_channels as u16,
sample_rate: frame_rate,
bits_per_sample: (format.sample_bytes() * 8) as u16,
sample_format: hound::SampleFormat::Int,
};
let wav_writer = WavWriter::create(path, spec).map_err(Error::OpenFile)?;
Ok(Self { wav_writer, format })
}
}
impl Write for WavSink {
fn write(&mut self, samples: &[u8]) -> io::Result<usize> {
let sample_bytes = self.format.sample_bytes();
if samples.len() % sample_bytes != 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"u8 samples vector of length {} cannot be interpreted as {:?} samples",
samples.len(),
self.format
),
));
}
let num_samples = samples.len() / sample_bytes;
match self.format {
SampleFormat::U8 => {
for sample in samples {
self.wav_writer.write_sample(*sample as i8).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed to write sample: {}", e),
)
})?;
}
}
SampleFormat::S16LE => {
// hound offers an optimized i16 writer, so special case here.
let mut writer = self.wav_writer.get_i16_writer(num_samples as u32);
for i in 0..num_samples {
let sample = i16::from_le_bytes([
samples[sample_bytes * i],
samples[sample_bytes * i + 1],
]);
writer.write_sample(sample);
}
// I16Writer buffers internally and must be explicitly flushed to write
// samples to the backing writer. Flush is not called automatically
// on drop.
// The flush method only writes data from the i16_writer to the underlying
// WavWriter, it does not actually guarantee a flush to disk.
writer.flush().map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed to flush SampleWriter: {}", e),
)
})?;
}
SampleFormat::S24LE | SampleFormat::S32LE => {
for i in 0..num_samples {
let mut sample = i32::from_le_bytes([
samples[sample_bytes * i],
samples[sample_bytes * i + 1],
samples[sample_bytes * i + 2],
samples[sample_bytes * i + 3],
]);
// Upsample to 32 bit since CRAS doesn't support S24_3LE,
// even though the wav encoder does.
// TODO(fletcherw): add S24_LE support to hound.
if self.format == SampleFormat::S24LE {
sample <<= 8;
}
self.wav_writer.write_sample(sample).map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed to write sample: {}", e),
)
})?;
}
}
}
Ok(samples.len())
}
fn flush(&mut self) -> io::Result<()> {
self.wav_writer.flush().map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed to flush WavWriter: {}", e),
)
})
}
}
pub fn capture(opts: AudioOptions) -> Result<()> {
let num_channels = opts.num_channels.unwrap_or(2);
let format = opts.format.unwrap_or(SampleFormat::S16LE);
let frame_rate = opts.frame_rate.unwrap_or(48000);
let buffer_size = opts.buffer_size.unwrap_or(256);
let mut sample_sink: Box<dyn Write> = match opts.file_type {
FileType::Raw => Box::new(BufWriter::new(
File::create(&opts.file_name).map_err(Error::Io)?,
)),
FileType::Wav => Box::new(WavSink::try_new(
&opts.file_name,
num_channels,
format,
frame_rate,
)?),
};
println!(
"Recording {} '{}' : {}, Rate {} Hz, {}",
opts.file_type,
opts.file_name.display(),
format,
frame_rate,
channel_string(num_channels)
);
let mut cras_client = CrasClient::new().map_err(Error::Libcras)?;
cras_client.enable_cras_capture();
let (_control, mut stream) = match opts.loopback_type {
Some(loopback_type) => {
let node_type = match loopback_type {
LoopbackType::PreDsp => CrasNodeType::CRAS_NODE_TYPE_POST_MIX_PRE_DSP,
LoopbackType::PostDsp => CrasNodeType::CRAS_NODE_TYPE_POST_DSP,
};
let loopback_node = cras_client
.input_nodes()
.find(|node| node.node_type == node_type)
.ok_or(Error::NoLoopbackNode(node_type))?;
cras_client
.new_pinned_capture_stream(
loopback_node.iodev_index,
num_channels,
format,
frame_rate,
buffer_size,
)
.map_err(Error::CreateStream)?
}
None => cras_client
.new_capture_stream(num_channels, format, frame_rate, buffer_size)
.map_err(Error::CreateStream)?,
};
set_priority_to_realtime();
add_sigint_handler()?;
while !INTERRUPTED.load(Ordering::Acquire) {
let mut buf = stream.next_capture_buffer().map_err(Error::FetchStream)?;
io::copy(&mut buf, &mut sample_sink).map_err(Error::Io)?;
}
Ok(())
}