diff --git a/rust_native_module/index.d.ts b/rust_native_module/index.d.ts index dd2973e..d533a43 100644 --- a/rust_native_module/index.d.ts +++ b/rust_native_module/index.d.ts @@ -30,12 +30,19 @@ declare module rust_native_module { close: () => Result, } + type TrackerConfig = { + mode: "auto" | "manual", + acThreshold: number, + zeroCrossingBeatDelay: number, + } + type GraphPoints = { bassFiltered: Array, autoCorrelated: Array, } type BeatTrackerHandle = { + setConfig: (config: TrackerConfig) => void, tap: () => void, getProgress: () => Option, stop: () => Result, diff --git a/rust_native_module/src/beat_tracking/audio/mod.rs b/rust_native_module/src/beat_tracking/audio/mod.rs index b5440b8..d97cc8c 100644 --- a/rust_native_module/src/beat_tracking/audio/mod.rs +++ b/rust_native_module/src/beat_tracking/audio/mod.rs @@ -52,11 +52,6 @@ impl AudioCaptureThread { handle.join().unwrap() } } - - pub fn read_state(&self) -> (Vec, Instant) { - let state = self.shared_state.lock().unwrap(); - (state.point_buf.to_vec(), state.last_update) - } pub fn get_points(&self) -> Vec { let state = self.shared_state.lock().unwrap(); diff --git a/rust_native_module/src/beat_tracking/metronome.rs b/rust_native_module/src/beat_tracking/metronome.rs index c4dc1f9..a34022a 100644 --- a/rust_native_module/src/beat_tracking/metronome.rs +++ b/rust_native_module/src/beat_tracking/metronome.rs @@ -1,5 +1,4 @@ use std::{ - ops::RangeInclusive, time::{Duration, Instant}, }; diff --git a/rust_native_module/src/beat_tracking/mod.rs b/rust_native_module/src/beat_tracking/mod.rs index 81a4082..a8a9b4f 100644 --- a/rust_native_module/src/beat_tracking/mod.rs +++ b/rust_native_module/src/beat_tracking/mod.rs @@ -8,6 +8,8 @@ use tracker::BeatTracker; use neon::prelude::*; +use self::tracker::TrackerMode; + type BoxedTracker = JsBox>; impl Finalize for BeatTracker {} @@ -18,6 +20,9 @@ pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult { let boxed_tracker = cx.boxed(RefCell::new(BeatTracker::new(Duration::from_secs(2)))); value_obj.set(&mut cx, "_rust_ptr", boxed_tracker)?; + let set_config_function = JsFunction::new(&mut cx, set_config)?; + value_obj.set(&mut cx, "setConfig", set_config_function)?; + let tap_function = JsFunction::new(&mut cx, tap)?; value_obj.set(&mut cx, "tap", tap_function)?; @@ -38,6 +43,46 @@ pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult { Ok(obj) } +pub fn set_config(mut cx: FunctionContext) -> JsResult { + let this = cx.this(); + let boxed_tracker = this + .get(&mut cx, "_rust_ptr")? + .downcast_or_throw::(&mut cx)?; + + let ref mut config = boxed_tracker.borrow_mut().config; + + let arg = cx.argument::(0)?; + + let mode = arg + .get(&mut cx, "mode")? + .downcast_or_throw::(&mut cx)? + .value(&mut cx); + + config.mode = match mode.as_str() { + "auto" => TrackerMode::AUTO, + "manual" => TrackerMode::MANUAL, + s => { + return cx.throw_error(format!("Invalid config mode: '{}'", s.to_string())); + } + }; + + let ac_threshold = arg + .get(&mut cx, "acThreshold")? + .downcast_or_throw::(&mut cx)? + .value(&mut cx); + + config.ac_threshold = ac_threshold as f32; + + let zero_crossing_beat_delay = arg + .get(&mut cx, "zeroCrossingBeatDelay")? + .downcast_or_throw::(&mut cx)? + .value(&mut cx); + + config.zero_crossing_beat_delay = zero_crossing_beat_delay as i64; + + Ok(cx.undefined()) +} + pub fn tap(mut cx: FunctionContext) -> JsResult { let this = cx.this(); diff --git a/rust_native_module/src/beat_tracking/tracker.rs b/rust_native_module/src/beat_tracking/tracker.rs index d4d7b1b..4886fed 100644 --- a/rust_native_module/src/beat_tracking/tracker.rs +++ b/rust_native_module/src/beat_tracking/tracker.rs @@ -6,14 +6,21 @@ use std::time::{Duration, Instant}; use super::audio::{AudioCaptureThread, MILLIS_PER_POINT}; use super::metronome::Metronome; +pub enum TrackerMode { + AUTO, + MANUAL, +} + pub struct TrackerConfig { - ac_threshold: f32, - zero_crossing_beat_delay: i64, + pub mode: TrackerMode, + + pub ac_threshold: f32, + pub zero_crossing_beat_delay: i64, } pub struct BeatTracker { metronome: Metronome, - config: TrackerConfig, + pub config: TrackerConfig, audio_capture_thread: AudioCaptureThread, } @@ -22,6 +29,7 @@ impl BeatTracker { Self { metronome: Metronome::new(metronome_timeout), config: TrackerConfig { + mode: TrackerMode::AUTO, ac_threshold: 1000.0, zero_crossing_beat_delay: 0, }, @@ -94,71 +102,80 @@ impl BeatTracker { } pub fn current_beat_progress(&self) -> Option { - let (points, autocorrelation) = self.get_graph_points(); + match self.config.mode { + TrackerMode::AUTO => { + let (points, autocorrelation) = self.get_graph_points(); - if autocorrelation - .iter() - .reduce(|a, b| if a > b { a } else { b }) - .unwrap() - < &self.config.ac_threshold - { - return None; - } + if autocorrelation + .iter() + .reduce(|a, b| if a > b { a } else { b }) + .unwrap() + < &self.config.ac_threshold + { + return None; + } - let mut j = 0; - let i = loop { - j += 1; + let mut j = 0; + let i = loop { + j += 1; - if j == autocorrelation.len() { - break None; + if j == autocorrelation.len() { + break None; + } + + if autocorrelation[j - 1] > autocorrelation[j] { + break Some(j); + } + }; + + if i.is_none() { + return None; + } + + let period_length = i.unwrap(); + + // println!("predicted period length: {}", period_length); + + // find zero crossing + + let mut crossing = None; + + for i in 1..points.len() { + if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 { + crossing = Some(i); + } + } + + if crossing.is_none() { + return None; + } + + let crossing = crossing.unwrap(); + // println!("index of last positive zero crossing: {}", crossing); + + let last_update_timestamp = self.audio_capture_thread.get_last_update(); + let now = Instant::now(); + + let mut dt = (now - last_update_timestamp).as_millis() as i64; + dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64; + dt += self.config.zero_crossing_beat_delay; + + let dt = dt as f64; + + let period_millis = (period_length * MILLIS_PER_POINT) as f64; + + let relative_time = (dt % period_millis) / period_millis; + + println!("relative time: {}", relative_time); + + Some(relative_time) } - if autocorrelation[j - 1] > autocorrelation[j] { - break Some(j); - } - }; - - if i.is_none() { - return None; - } - - let period_length = i.unwrap(); - - // println!("predicted period length: {}", period_length); - - // find zero crossing - - let mut crossing = None; - - for i in 1..points.len() { - if points[i-1].signum() < 0.0 && points[i].signum() > 0.0 { - crossing = Some(i); + TrackerMode::MANUAL => { + let now = Instant::now(); + self.metronome.get_beat_progress(now) } } - - if crossing.is_none() { - return None; - } - - let crossing = crossing.unwrap(); - // println!("index of last positive zero crossing: {}", crossing); - - let last_update_timestamp = self.audio_capture_thread.get_last_update(); - let now = Instant::now(); - - let mut dt = (now - last_update_timestamp).as_millis() as i64; - dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64; - dt += self.config.zero_crossing_beat_delay; - - let dt = dt as f64; - - let period_millis = (period_length * MILLIS_PER_POINT) as f64; - - let relative_time = (dt % period_millis) / period_millis; - - println!("relative time: {}", relative_time); - - Some(relative_time) } pub fn stop(&mut self) -> Result<()> {