use anyhow::Result; use num::Complex; use rustfft::FftPlanner; use std::time::{Duration, Instant}; use crate::beat_tracking::audio::POINT_BUFFER_SIZE; use super::audio::{AudioCaptureThread, MILLIS_PER_POINT}; use super::metronome::Metronome; pub enum TrackerMode { AUTO, MANUAL, } pub struct TrackerConfig { pub mode: TrackerMode, pub ac_threshold: f32, pub zero_crossing_beat_delay: i64, } pub struct BeatTracker { metronome: Metronome, pub config: TrackerConfig, audio_capture_thread: AudioCaptureThread, current_beats: Option<(Instant, Instant)>, beat_count: usize, } impl BeatTracker { pub fn new(metronome_timeout: Duration) -> Self { Self { metronome: Metronome::new(metronome_timeout), config: TrackerConfig { mode: TrackerMode::AUTO, ac_threshold: 1000.0, zero_crossing_beat_delay: 0, }, audio_capture_thread: AudioCaptureThread::new(), current_beats: None, beat_count: 0, } } pub fn tap(&mut self) { self.metronome.tap(); } pub fn get_graph_points(&self) -> (Vec, Vec) { let points = self.audio_capture_thread.get_points(); let mut autocorrelation = points .iter() .map(|point| Complex { re: *point, im: 0.0, }) .collect::>(); let n = autocorrelation.len(); for _ in 0..n { autocorrelation.push(Complex { re: 0.0, im: 0.0 }); } let mut planner = FftPlanner::new(); let fft = planner.plan_fft_forward(2 * n); let ifft = planner.plan_fft_inverse(2 * n); fft.process(&mut autocorrelation); for f in autocorrelation.iter_mut() { *f = Complex { re: f.norm_sqr(), im: 0.0, }; } ifft.process(&mut autocorrelation); let mut autocorrelation: Vec = autocorrelation[..n].iter().map(|x| x.re).collect(); // remove first peak let mut j = 0; let i = loop { j += 1; if j == n { break None; } if autocorrelation[j - 1] <= autocorrelation[j] { break Some(j); } }; if let Some(i) = i { let first_increase = autocorrelation[i]; for f in autocorrelation[..i].iter_mut() { *f = first_increase; } } (points, autocorrelation) } // (period_length [ms], crossing [point index, timing relative to last_update]) fn get_correlation_data(&self) -> Option<(usize, usize)> { let (points, autocorrelation) = self.get_graph_points(); if autocorrelation .iter() .reduce(|a, b| if a > b { a } else { b }) .unwrap() < &self.config.ac_threshold { return None; } let mut j = 0; let i = loop { j += 1; if j == autocorrelation.len() { break None; } if autocorrelation[j - 1] > autocorrelation[j] { break Some(j); } }; if i.is_none() { return None; } let period_length = i.unwrap(); // find zero crossing let mut crossing = None; for i in 1..points.len() { if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 { crossing = Some(i); } } if crossing.is_none() { return None; } let crossing = crossing.unwrap(); Some((period_length, crossing)) } pub fn current_beat_progress(&mut self) -> Option { let now = Instant::now(); match self.config.mode { TrackerMode::AUTO => { // println!("reee"); let should_update = { match self.current_beats { None => true, Some((_, next)) if now > next => { self.beat_count += 1; true } _ => false, } }; if should_update { if let Some((period_length, crossing)) = self.get_correlation_data() { let last_update_timestamp = self.audio_capture_thread.get_last_update(); let mut dt = (now - last_update_timestamp).as_millis() as i64; dt += ((POINT_BUFFER_SIZE - crossing) * MILLIS_PER_POINT) as i64; let mut prev = now - Duration::from_millis(dt as u64); let period_millis = (period_length * MILLIS_PER_POINT) as u64; let period = Duration::from_millis(period_millis); let mut next = prev + period; if let Some((_, old_next)) = self.current_beats { while next < old_next + Duration::from_millis(period_millis / 2) { prev += period; next += period; } } self.current_beats = Some((prev, next)); } else { self.beat_count = 0; return None; } } let (prev, next) = self.current_beats.unwrap(); let mut relative_millis = if prev < now { (now - prev).as_millis() as f64 } else { -1.0 * (prev - now).as_millis() as f64 }; relative_millis += self.config.zero_crossing_beat_delay as f64; let fractional = relative_millis / (next - prev).as_millis() as f64; Some(self.beat_count as f64 + fractional) } TrackerMode::MANUAL => self.metronome.get_beat_progress(now), } } pub fn stop(&mut self) -> Result<()> { self.audio_capture_thread.stop() } }