Implement Autocorrelation
This commit is contained in:
95
rust_native_module/src/beat_tracking/audio/capture.rs
Normal file
95
rust_native_module/src/beat_tracking/audio/capture.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use psimple::Simple;
|
||||
use pulse::context::{Context, FlagSet as ContextFlagSet};
|
||||
use pulse::mainloop::standard::{IterateResult, Mainloop};
|
||||
use pulse::sample::Spec;
|
||||
use pulse::stream::Direction;
|
||||
use std::cell::RefCell;
|
||||
use std::ops::Deref;
|
||||
use std::rc::Rc;
|
||||
|
||||
/*
|
||||
Some manual poking around in PulseAudio to get the name of the default sink
|
||||
*/
|
||||
|
||||
fn poll_mainloop(mainloop: &Rc<RefCell<Mainloop>>) -> Result<()> {
|
||||
match mainloop.borrow_mut().iterate(true) {
|
||||
IterateResult::Quit(_) | IterateResult::Err(_) => {
|
||||
return Err(anyhow!("Iterate state was not success"));
|
||||
}
|
||||
IterateResult::Success(_) => {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pulse_default_sink() -> Result<String> {
|
||||
let mainloop = Rc::new(RefCell::new(
|
||||
Mainloop::new().ok_or(anyhow!("Failed to create mainloop"))?,
|
||||
));
|
||||
let ctx = Rc::new(RefCell::new(
|
||||
Context::new(mainloop.borrow().deref(), "gib_default_sink")
|
||||
.ok_or(anyhow!("Failed to create context"))?,
|
||||
));
|
||||
|
||||
ctx.borrow_mut()
|
||||
.connect(None, ContextFlagSet::NOFLAGS, None)?;
|
||||
|
||||
// Wait for context to be ready
|
||||
loop {
|
||||
poll_mainloop(&mainloop)?;
|
||||
|
||||
match ctx.borrow().get_state() {
|
||||
pulse::context::State::Ready => {
|
||||
break;
|
||||
}
|
||||
pulse::context::State::Failed | pulse::context::State::Terminated => {
|
||||
return Err(anyhow!("Context was in failed/terminated state"));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let result = Rc::new(RefCell::new(None));
|
||||
let cb_result_ref = result.clone();
|
||||
|
||||
ctx.borrow().introspect().get_server_info(move |info| {
|
||||
*cb_result_ref.borrow_mut() = if let Some(ref sink_name) = info.default_sink_name {
|
||||
Some(Ok(sink_name.to_string()))
|
||||
} else {
|
||||
Some(Err(()))
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
if let Some(result) = result.borrow().deref() {
|
||||
return result
|
||||
.to_owned()
|
||||
.map_err(|_| anyhow!("Default sink name was empty"));
|
||||
}
|
||||
|
||||
poll_mainloop(&mainloop)?;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Get a PASimple instance which reads from the default sink monitor
|
||||
*/
|
||||
|
||||
pub fn get_audio_reader(spec: &Spec) -> Result<Simple> {
|
||||
let mut default_sink_name = get_pulse_default_sink()?;
|
||||
default_sink_name.push_str(".monitor");
|
||||
|
||||
let simple = Simple::new(
|
||||
None,
|
||||
"piano_thingy",
|
||||
Direction::Record,
|
||||
Some(&default_sink_name),
|
||||
"sample_yoinker",
|
||||
&spec,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(simple)
|
||||
}
|
||||
71
rust_native_module/src/beat_tracking/audio/dsp.rs
Normal file
71
rust_native_module/src/beat_tracking/audio/dsp.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
Taken from Till's magic Arduino Sketch
|
||||
*/
|
||||
pub trait ZTransformFilter {
|
||||
fn process(&mut self, sample: f32) -> f32;
|
||||
}
|
||||
|
||||
// 20 - 200Hz Single Pole Bandpass IIR Filter
|
||||
#[derive(Default)]
|
||||
pub struct BassFilter {
|
||||
xv: [f32; 3],
|
||||
yv: [f32; 3],
|
||||
}
|
||||
|
||||
impl ZTransformFilter for BassFilter {
|
||||
fn process(&mut self, sample: f32) -> f32 {
|
||||
self.xv[0] = self.xv[1];
|
||||
self.xv[1] = self.xv[2];
|
||||
self.xv[2] = sample / 3.0f32;
|
||||
|
||||
self.yv[0] = self.yv[1];
|
||||
self.yv[1] = self.yv[2];
|
||||
self.yv[2] = (self.xv[2] - self.xv[0])
|
||||
+ (-0.7960060012f32 * self.yv[0])
|
||||
+ (1.7903124146f32 * self.yv[1]);
|
||||
|
||||
self.yv[2]
|
||||
}
|
||||
}
|
||||
|
||||
// 10Hz Single Pole Lowpass IIR Filter
|
||||
#[derive(Default)]
|
||||
pub struct EnvelopeFilter {
|
||||
xv: [f32; 2],
|
||||
yv: [f32; 2],
|
||||
}
|
||||
|
||||
impl ZTransformFilter for EnvelopeFilter {
|
||||
fn process(&mut self, sample: f32) -> f32 {
|
||||
self.xv[0] = self.xv[1];
|
||||
self.xv[1] = sample / 50.0f32;
|
||||
|
||||
self.yv[0] = self.yv[1];
|
||||
self.yv[1] = (self.xv[0] + self.xv[1]) + (0.9875119299f32 * self.yv[0]);
|
||||
|
||||
self.yv[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 1.7 - 3.0Hz Single Pole Bandpass IIR Filter
|
||||
#[derive(Default)]
|
||||
pub struct BeatFilter {
|
||||
xv: [f32; 3],
|
||||
yv: [f32; 3],
|
||||
}
|
||||
|
||||
impl ZTransformFilter for BeatFilter {
|
||||
fn process(&mut self, sample: f32) -> f32 {
|
||||
self.xv[0] = self.xv[1];
|
||||
self.xv[1] = self.xv[2];
|
||||
self.xv[2] = sample / 2.7f32;
|
||||
|
||||
self.yv[0] = self.yv[1];
|
||||
self.yv[1] = self.yv[2];
|
||||
self.yv[2] = (self.xv[2] - self.xv[0])
|
||||
+ (-0.7169861741f32 * self.yv[0])
|
||||
+ (1.4453653501f32 * self.yv[1]);
|
||||
|
||||
self.yv[2]
|
||||
}
|
||||
}
|
||||
141
rust_native_module/src/beat_tracking/audio/mod.rs
Normal file
141
rust_native_module/src/beat_tracking/audio/mod.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use ringbuffer::{ConstGenericRingBuffer, RingBufferExt, RingBufferWrite};
|
||||
use std::{
|
||||
sync::{Arc, Mutex},
|
||||
thread::{self, JoinHandle},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use pulse::sample::{Format, Spec};
|
||||
|
||||
use self::dsp::ZTransformFilter;
|
||||
|
||||
mod capture;
|
||||
mod dsp;
|
||||
|
||||
const SAMPLE_RATE: usize = 5000;
|
||||
const PULSE_UPDATES_PER_SECOND: usize = 50;
|
||||
const BUFFER_SIZE: usize = SAMPLE_RATE / PULSE_UPDATES_PER_SECOND;
|
||||
|
||||
const POINTS_PER_SECOND: usize = SAMPLE_RATE / 200;
|
||||
pub const MILLIS_PER_POINT: usize = 1000 / POINTS_PER_SECOND;
|
||||
const POINT_MIN_COUNT: usize = 2 * POINTS_PER_SECOND;
|
||||
const POINT_BUFFER_SIZE: usize = POINT_MIN_COUNT.next_power_of_two();
|
||||
|
||||
pub struct AudioCaptureThread {
|
||||
join_handle: Option<JoinHandle<Result<()>>>,
|
||||
shared_state: Arc<Mutex<SharedState>>,
|
||||
}
|
||||
|
||||
impl AudioCaptureThread {
|
||||
pub fn new() -> Self {
|
||||
let shared_state = Arc::new(Mutex::new(SharedState::new()));
|
||||
let join_handle = {
|
||||
let shared_state = shared_state.clone();
|
||||
Some(thread::spawn(move || audio_capture_thread(shared_state)))
|
||||
};
|
||||
|
||||
Self {
|
||||
shared_state,
|
||||
join_handle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop(&mut self) -> Result<()> {
|
||||
if self.join_handle.is_none() {
|
||||
return Err(anyhow!(
|
||||
"join_handle was none, was stop called multiple times?"
|
||||
));
|
||||
} else {
|
||||
let handle = self.join_handle.take().unwrap();
|
||||
|
||||
handle.join().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_state(&self) -> (Vec<f32>, Instant) {
|
||||
let state = self.shared_state.lock().unwrap();
|
||||
(state.point_buf.to_vec(), state.last_update)
|
||||
}
|
||||
|
||||
pub fn get_points(&self) -> Vec<f32> {
|
||||
let state = self.shared_state.lock().unwrap();
|
||||
state.point_buf.to_vec()
|
||||
}
|
||||
|
||||
pub fn get_last_update(&self) -> Instant {
|
||||
let state = self.shared_state.lock().unwrap();
|
||||
state.last_update
|
||||
}
|
||||
}
|
||||
|
||||
struct SharedState {
|
||||
running: bool,
|
||||
point_buf: ConstGenericRingBuffer<f32, POINT_BUFFER_SIZE>,
|
||||
last_update: Instant,
|
||||
}
|
||||
|
||||
impl SharedState {
|
||||
pub fn new() -> Self {
|
||||
let mut point_buf: ConstGenericRingBuffer<f32, POINT_BUFFER_SIZE> = Default::default();
|
||||
point_buf.fill_default();
|
||||
|
||||
Self {
|
||||
running: true,
|
||||
point_buf,
|
||||
last_update: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn audio_capture_thread(state: Arc<Mutex<SharedState>>) -> Result<()> {
|
||||
let spec = Spec {
|
||||
format: Format::F32le,
|
||||
rate: SAMPLE_RATE as u32,
|
||||
channels: 1,
|
||||
};
|
||||
|
||||
let reader = capture::get_audio_reader(&spec)?;
|
||||
let mut buffer = [0u8; 4 * BUFFER_SIZE];
|
||||
|
||||
let mut bass_filter = dsp::BassFilter::default();
|
||||
let mut envelope_filter = dsp::EnvelopeFilter::default();
|
||||
let mut beat_filter = dsp::BeatFilter::default();
|
||||
let mut j = 0;
|
||||
|
||||
loop {
|
||||
{
|
||||
if state.lock().unwrap().running == false {
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
reader.read(&mut buffer)?;
|
||||
|
||||
for i in 0..BUFFER_SIZE {
|
||||
let mut float_bytes = [0u8; 4];
|
||||
float_bytes.copy_from_slice(&buffer[4 * i..4 * i + 4]);
|
||||
|
||||
j += 1;
|
||||
let sample = f32::from_le_bytes(float_bytes);
|
||||
let mut value = bass_filter.process(sample);
|
||||
|
||||
if value < 0f32 {
|
||||
value = -value;
|
||||
}
|
||||
|
||||
let envelope = envelope_filter.process(value);
|
||||
|
||||
if j == 200 {
|
||||
let beat = beat_filter.process(envelope);
|
||||
|
||||
let mut state = state.lock().unwrap();
|
||||
|
||||
state.point_buf.push(beat);
|
||||
state.last_update = Instant::now();
|
||||
|
||||
j = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{
|
||||
ops::RangeInclusive,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
pub struct Metronome {
|
||||
taps: Vec<Instant>,
|
||||
@@ -34,14 +37,35 @@ impl Metronome {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn current_beat_progress(&self) -> Option<f64> {
|
||||
pub fn get_beat_progress(&self, now: Instant) -> Option<f64> {
|
||||
if self.beat_interval.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let relative_millis = (now - *self.taps.last().unwrap()).as_millis();
|
||||
|
||||
Some(relative_millis as f64 / self.beat_interval.unwrap() as f64)
|
||||
}
|
||||
|
||||
pub fn get_beats(&self, from: Instant, to: Instant) -> Vec<(u64, Instant)> {
|
||||
if self.beat_interval.is_none() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let beat_interval = self.beat_interval.unwrap();
|
||||
|
||||
let last = *self.taps.last().unwrap();
|
||||
|
||||
let relative_from = (from - last).as_millis() as f64 / beat_interval as f64;
|
||||
let relative_to = (to - last).as_millis() as f64 / beat_interval as f64;
|
||||
|
||||
(relative_from.ceil() as u64..=relative_to.floor() as u64)
|
||||
.map(|beat_number| {
|
||||
(
|
||||
beat_number,
|
||||
last + Duration::from_millis((beat_number as u128 * beat_interval) as u64),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
mod audio;
|
||||
mod metronome;
|
||||
mod tracker;
|
||||
|
||||
use std::{cell::RefCell, time::Duration};
|
||||
|
||||
use metronome::Metronome;
|
||||
use tracker::BeatTracker;
|
||||
|
||||
use neon::prelude::*;
|
||||
|
||||
type BoxedTracker = JsBox<RefCell<Metronome>>;
|
||||
impl Finalize for Metronome {}
|
||||
type BoxedTracker = JsBox<RefCell<BeatTracker>>;
|
||||
impl Finalize for BeatTracker {}
|
||||
|
||||
pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult<JsObject> {
|
||||
let obj = cx.empty_object();
|
||||
let value_obj = cx.empty_object();
|
||||
|
||||
let boxed_tracker = cx.boxed(RefCell::new(Metronome::new(Duration::from_secs(2))));
|
||||
let boxed_tracker = cx.boxed(RefCell::new(BeatTracker::new(Duration::from_secs(2))));
|
||||
value_obj.set(&mut cx, "_rust_ptr", boxed_tracker)?;
|
||||
|
||||
let tap_function = JsFunction::new(&mut cx, tap)?;
|
||||
@@ -22,6 +24,12 @@ pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult<JsObject> {
|
||||
let get_progress_function = JsFunction::new(&mut cx, get_progress)?;
|
||||
value_obj.set(&mut cx, "getProgress", get_progress_function)?;
|
||||
|
||||
let close_function = JsFunction::new(&mut cx, close)?;
|
||||
value_obj.set(&mut cx, "close", close_function)?;
|
||||
|
||||
let get_graph_points_function = JsFunction::new(&mut cx, get_graph_points)?;
|
||||
value_obj.set(&mut cx, "getGraphPoints", get_graph_points_function)?;
|
||||
|
||||
obj.set(&mut cx, "value", value_obj)?;
|
||||
|
||||
let success_string = cx.string("success".to_string());
|
||||
@@ -68,3 +76,57 @@ pub fn get_progress(mut cx: FunctionContext) -> JsResult<JsObject> {
|
||||
|
||||
Ok(obj)
|
||||
}
|
||||
|
||||
fn get_graph_points(mut cx: FunctionContext) -> JsResult<JsObject> {
|
||||
let this = cx.this();
|
||||
|
||||
let boxed_tracker = this
|
||||
.get(&mut cx, "_rust_ptr")?
|
||||
.downcast_or_throw::<BoxedTracker, _>(&mut cx)?;
|
||||
|
||||
let obj = cx.empty_object();
|
||||
|
||||
let (bass_filtered, auto_correlated) = boxed_tracker.borrow().get_graph_points();
|
||||
|
||||
let arr_bf = cx.empty_array();
|
||||
for (i, f) in bass_filtered.iter().enumerate() {
|
||||
let num = cx.number(*f);
|
||||
arr_bf.set(&mut cx, i as u32, num)?;
|
||||
}
|
||||
obj.set(&mut cx, "bassFiltered", arr_bf)?;
|
||||
|
||||
let arr_ac = cx.empty_array();
|
||||
for (i, f) in auto_correlated.iter().enumerate() {
|
||||
let num = cx.number(*f);
|
||||
arr_ac.set(&mut cx, i as u32, num)?;
|
||||
}
|
||||
obj.set(&mut cx, "autoCorrelated", arr_ac)?;
|
||||
|
||||
Ok(obj)
|
||||
}
|
||||
|
||||
pub fn close(mut cx: FunctionContext) -> JsResult<JsObject> {
|
||||
let this = cx.this();
|
||||
|
||||
let boxed_tracker = this
|
||||
.get(&mut cx, "_rust_ptr")?
|
||||
.downcast_or_throw::<BoxedTracker, _>(&mut cx)?;
|
||||
|
||||
let obj = cx.empty_object();
|
||||
let success_string;
|
||||
|
||||
match boxed_tracker.borrow_mut().stop() {
|
||||
Ok(()) => {
|
||||
success_string = cx.string("success");
|
||||
}
|
||||
Err(e) => {
|
||||
success_string = cx.string("error");
|
||||
|
||||
let error_message = cx.string(e.to_string());
|
||||
obj.set(&mut cx, "message", error_message)?;
|
||||
}
|
||||
}
|
||||
obj.set(&mut cx, "type", success_string)?;
|
||||
|
||||
Ok(obj)
|
||||
}
|
||||
|
||||
167
rust_native_module/src/beat_tracking/tracker.rs
Normal file
167
rust_native_module/src/beat_tracking/tracker.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
use anyhow::Result;
|
||||
use num::Complex;
|
||||
use rustfft::FftPlanner;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use super::audio::{AudioCaptureThread, MILLIS_PER_POINT};
|
||||
use super::metronome::Metronome;
|
||||
|
||||
pub struct TrackerConfig {
|
||||
ac_threshold: f32,
|
||||
zero_crossing_beat_delay: i64,
|
||||
}
|
||||
|
||||
pub struct BeatTracker {
|
||||
metronome: Metronome,
|
||||
config: TrackerConfig,
|
||||
audio_capture_thread: AudioCaptureThread,
|
||||
}
|
||||
|
||||
impl BeatTracker {
|
||||
pub fn new(metronome_timeout: Duration) -> Self {
|
||||
Self {
|
||||
metronome: Metronome::new(metronome_timeout),
|
||||
config: TrackerConfig {
|
||||
ac_threshold: 1000.0,
|
||||
zero_crossing_beat_delay: 0,
|
||||
},
|
||||
audio_capture_thread: AudioCaptureThread::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tap(&mut self) {
|
||||
self.metronome.tap();
|
||||
}
|
||||
|
||||
pub fn get_graph_points(&self) -> (Vec<f32>, Vec<f32>) {
|
||||
let points = self.audio_capture_thread.get_points();
|
||||
|
||||
let mut autocorrelation = points
|
||||
.iter()
|
||||
.map(|point| Complex {
|
||||
re: *point,
|
||||
im: 0.0,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let n = autocorrelation.len();
|
||||
|
||||
for _ in 0..n {
|
||||
autocorrelation.push(Complex { re: 0.0, im: 0.0 });
|
||||
}
|
||||
|
||||
let mut planner = FftPlanner::new();
|
||||
let fft = planner.plan_fft_forward(2 * n);
|
||||
let ifft = planner.plan_fft_inverse(2 * n);
|
||||
|
||||
fft.process(&mut autocorrelation);
|
||||
|
||||
for f in autocorrelation.iter_mut() {
|
||||
*f = Complex {
|
||||
re: f.norm_sqr(),
|
||||
im: 0.0,
|
||||
};
|
||||
}
|
||||
|
||||
ifft.process(&mut autocorrelation);
|
||||
|
||||
let mut autocorrelation: Vec<f32> = autocorrelation[..n].iter().map(|x| x.re).collect();
|
||||
|
||||
// remove first peak
|
||||
|
||||
let mut j = 0;
|
||||
let i = loop {
|
||||
j += 1;
|
||||
|
||||
if j == n {
|
||||
break None;
|
||||
}
|
||||
|
||||
if autocorrelation[j - 1] <= autocorrelation[j] {
|
||||
break Some(j);
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(i) = i {
|
||||
let first_increase = autocorrelation[i];
|
||||
|
||||
for f in autocorrelation[..i].iter_mut() {
|
||||
*f = first_increase;
|
||||
}
|
||||
}
|
||||
|
||||
(points, autocorrelation)
|
||||
}
|
||||
|
||||
pub fn current_beat_progress(&self) -> Option<f64> {
|
||||
let (points, autocorrelation) = self.get_graph_points();
|
||||
|
||||
if autocorrelation
|
||||
.iter()
|
||||
.reduce(|a, b| if a > b { a } else { b })
|
||||
.unwrap()
|
||||
< &self.config.ac_threshold
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut j = 0;
|
||||
let i = loop {
|
||||
j += 1;
|
||||
|
||||
if j == autocorrelation.len() {
|
||||
break None;
|
||||
}
|
||||
|
||||
if autocorrelation[j - 1] > autocorrelation[j] {
|
||||
break Some(j);
|
||||
}
|
||||
};
|
||||
|
||||
if i.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let period_length = i.unwrap();
|
||||
|
||||
// println!("predicted period length: {}", period_length);
|
||||
|
||||
// find zero crossing
|
||||
|
||||
let mut crossing = None;
|
||||
|
||||
for i in 1..points.len() {
|
||||
if points[i-1].signum() < 0.0 && points[i].signum() > 0.0 {
|
||||
crossing = Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
if crossing.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let crossing = crossing.unwrap();
|
||||
// println!("index of last positive zero crossing: {}", crossing);
|
||||
|
||||
let last_update_timestamp = self.audio_capture_thread.get_last_update();
|
||||
let now = Instant::now();
|
||||
|
||||
let mut dt = (now - last_update_timestamp).as_millis() as i64;
|
||||
dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64;
|
||||
dt += self.config.zero_crossing_beat_delay;
|
||||
|
||||
let dt = dt as f64;
|
||||
|
||||
let period_millis = (period_length * MILLIS_PER_POINT) as f64;
|
||||
|
||||
let relative_time = (dt % period_millis) / period_millis;
|
||||
|
||||
println!("relative time: {}", relative_time);
|
||||
|
||||
Some(relative_time)
|
||||
}
|
||||
|
||||
pub fn stop(&mut self) -> Result<()> {
|
||||
self.audio_capture_thread.stop()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user