From bf62ceb98f2ed60b71bbcb5625311d452c2e3396 Mon Sep 17 00:00:00 2001 From: Kai Vogelgesang Date: Fri, 12 Nov 2021 16:37:04 +0100 Subject: [PATCH] Improve Beat Detection --- boilerbloat/src/patterns/chaser.ts | 19 +-- .../src/beat_tracking/audio/mod.rs | 2 +- rust_native_module/src/beat_tracking/mod.rs | 2 +- .../src/beat_tracking/tracker.rs | 156 +++++++++++------- 4 files changed, 104 insertions(+), 75 deletions(-) diff --git a/boilerbloat/src/patterns/chaser.ts b/boilerbloat/src/patterns/chaser.ts index c5d926f..4f73eb5 100644 --- a/boilerbloat/src/patterns/chaser.ts +++ b/boilerbloat/src/patterns/chaser.ts @@ -3,30 +3,15 @@ import { Pattern, PatternOutput, Time } from './proto'; export class ChaserPattern implements Pattern { - lastBeat: number; - lastTime: number; - - constructor() { - this.lastBeat = 0; - this.lastTime = 0; - } - render(time: Time): PatternOutput { if (time.beatRelative === null) { - this.lastBeat = 0; return null; } let t = time.beatRelative; - if (t < this.lastTime) { - this.lastBeat += 1; - } - - this.lastTime = t; - - let head_number = this.lastBeat % 4; + let head_number = Math.floor(t % 4); let template: MovingHeadState = { startAddress: 0, @@ -44,7 +29,7 @@ export class ChaserPattern implements Pattern { let result = []; for (let [i, startAddress] of [1, 15, 29, 43].entries()) { - result[i] = { ...template}; + result[i] = { ...template }; result[i].startAddress = startAddress; if (i === head_number) { diff --git a/rust_native_module/src/beat_tracking/audio/mod.rs b/rust_native_module/src/beat_tracking/audio/mod.rs index d97cc8c..78f17ce 100644 --- a/rust_native_module/src/beat_tracking/audio/mod.rs +++ b/rust_native_module/src/beat_tracking/audio/mod.rs @@ -20,7 +20,7 @@ const BUFFER_SIZE: usize = SAMPLE_RATE / PULSE_UPDATES_PER_SECOND; const POINTS_PER_SECOND: usize = SAMPLE_RATE / 200; pub const MILLIS_PER_POINT: usize = 1000 / POINTS_PER_SECOND; const POINT_MIN_COUNT: usize = 2 * POINTS_PER_SECOND; -const POINT_BUFFER_SIZE: usize = POINT_MIN_COUNT.next_power_of_two(); +pub const POINT_BUFFER_SIZE: usize = POINT_MIN_COUNT.next_power_of_two(); pub struct AudioCaptureThread { join_handle: Option>>, diff --git a/rust_native_module/src/beat_tracking/mod.rs b/rust_native_module/src/beat_tracking/mod.rs index a8a9b4f..7ad9a4a 100644 --- a/rust_native_module/src/beat_tracking/mod.rs +++ b/rust_native_module/src/beat_tracking/mod.rs @@ -105,7 +105,7 @@ pub fn get_progress(mut cx: FunctionContext) -> JsResult { let obj = cx.empty_object(); let type_string; - match boxed_tracker.borrow().current_beat_progress() { + match boxed_tracker.borrow_mut().current_beat_progress() { Some(progress) => { type_string = cx.string("some".to_string()); diff --git a/rust_native_module/src/beat_tracking/tracker.rs b/rust_native_module/src/beat_tracking/tracker.rs index 4886fed..265fd6a 100644 --- a/rust_native_module/src/beat_tracking/tracker.rs +++ b/rust_native_module/src/beat_tracking/tracker.rs @@ -3,6 +3,8 @@ use num::Complex; use rustfft::FftPlanner; use std::time::{Duration, Instant}; +use crate::beat_tracking::audio::POINT_BUFFER_SIZE; + use super::audio::{AudioCaptureThread, MILLIS_PER_POINT}; use super::metronome::Metronome; @@ -22,6 +24,9 @@ pub struct BeatTracker { metronome: Metronome, pub config: TrackerConfig, audio_capture_thread: AudioCaptureThread, + + current_beats: Option<(Instant, Instant)>, + beat_count: usize, } impl BeatTracker { @@ -34,6 +39,9 @@ impl BeatTracker { zero_crossing_beat_delay: 0, }, audio_capture_thread: AudioCaptureThread::new(), + + current_beats: None, + beat_count: 0, } } @@ -101,80 +109,116 @@ impl BeatTracker { (points, autocorrelation) } - pub fn current_beat_progress(&self) -> Option { + // (period_length [ms], crossing [point index, timing relative to last_update]) + fn get_correlation_data(&self) -> Option<(usize, usize)> { + let (points, autocorrelation) = self.get_graph_points(); + + if autocorrelation + .iter() + .reduce(|a, b| if a > b { a } else { b }) + .unwrap() + < &self.config.ac_threshold + { + return None; + } + + let mut j = 0; + let i = loop { + j += 1; + + if j == autocorrelation.len() { + break None; + } + + if autocorrelation[j - 1] > autocorrelation[j] { + break Some(j); + } + }; + + if i.is_none() { + return None; + } + + let period_length = i.unwrap(); + + // find zero crossing + + let mut crossing = None; + + for i in 1..points.len() { + if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 { + crossing = Some(i); + } + } + + if crossing.is_none() { + return None; + } + + let crossing = crossing.unwrap(); + + Some((period_length, crossing)) + } + + pub fn current_beat_progress(&mut self) -> Option { + let now = Instant::now(); + match self.config.mode { TrackerMode::AUTO => { - let (points, autocorrelation) = self.get_graph_points(); + // println!("reee"); - if autocorrelation - .iter() - .reduce(|a, b| if a > b { a } else { b }) - .unwrap() - < &self.config.ac_threshold - { - return None; - } - - let mut j = 0; - let i = loop { - j += 1; - - if j == autocorrelation.len() { - break None; - } - - if autocorrelation[j - 1] > autocorrelation[j] { - break Some(j); + let should_update = { + match self.current_beats { + None => true, + Some((_, next)) if now > next => { + self.beat_count += 1; + true + } + _ => false, } }; - if i.is_none() { - return None; - } + if should_update { + if let Some((period_length, crossing)) = self.get_correlation_data() { + let last_update_timestamp = self.audio_capture_thread.get_last_update(); - let period_length = i.unwrap(); + let mut dt = (now - last_update_timestamp).as_millis() as i64; + dt += ((POINT_BUFFER_SIZE - crossing) * MILLIS_PER_POINT) as i64; - // println!("predicted period length: {}", period_length); + let mut prev = now - Duration::from_millis(dt as u64); - // find zero crossing + let period_millis = (period_length * MILLIS_PER_POINT) as u64; + + let period = Duration::from_millis(period_millis); - let mut crossing = None; + let mut next = prev + period; + + if let Some((_, old_next)) = self.current_beats { + while next < old_next + Duration::from_millis(period_millis / 2) { + prev += period; + next += period; + } + } - for i in 1..points.len() { - if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 { - crossing = Some(i); + self.current_beats = Some((prev, next)); + } else { + self.beat_count = 0; + return None; } } - if crossing.is_none() { - return None; - } + let (prev, next) = self.current_beats.unwrap(); - let crossing = crossing.unwrap(); - // println!("index of last positive zero crossing: {}", crossing); + let fractional = if prev < now { + (now - prev).as_millis() as f64 / (next - prev).as_millis() as f64 + } else { + -1.0 * (prev - now).as_millis() as f64 / (next - prev).as_millis() as f64 + }; - let last_update_timestamp = self.audio_capture_thread.get_last_update(); - let now = Instant::now(); - - let mut dt = (now - last_update_timestamp).as_millis() as i64; - dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64; - dt += self.config.zero_crossing_beat_delay; - - let dt = dt as f64; - - let period_millis = (period_length * MILLIS_PER_POINT) as f64; - - let relative_time = (dt % period_millis) / period_millis; - - println!("relative time: {}", relative_time); - - Some(relative_time) + Some(self.beat_count as f64 + fractional) } - TrackerMode::MANUAL => { - let now = Instant::now(); - self.metronome.get_beat_progress(now) - } + TrackerMode::MANUAL => self.metronome.get_beat_progress(now), } }