233 lines
6.2 KiB
Rust
233 lines
6.2 KiB
Rust
use anyhow::Result;
|
|
use num::Complex;
|
|
use rustfft::FftPlanner;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use crate::beat_tracking::audio::POINT_BUFFER_SIZE;
|
|
|
|
use super::audio::{AudioCaptureThread, MILLIS_PER_POINT};
|
|
use super::metronome::Metronome;
|
|
|
|
/// Selects how a `BeatTracker` determines the current beat.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackerMode {
    /// Estimate the beat from the autocorrelation of the captured audio.
    AUTO,
    /// Defer to the tap-driven metronome (see `BeatTracker::tap`).
    MANUAL,
}
|
|
|
|
pub struct TrackerConfig {
|
|
pub mode: TrackerMode,
|
|
|
|
pub ac_threshold: f32,
|
|
pub zero_crossing_beat_delay: i64,
|
|
}
|
|
|
|
pub struct BeatTracker {
|
|
metronome: Metronome,
|
|
pub config: TrackerConfig,
|
|
audio_capture_thread: AudioCaptureThread,
|
|
|
|
current_beats: Option<(Instant, Instant)>,
|
|
beat_count: usize,
|
|
}
|
|
|
|
impl BeatTracker {
|
|
pub fn new(metronome_timeout: Duration) -> Self {
|
|
Self {
|
|
metronome: Metronome::new(metronome_timeout),
|
|
config: TrackerConfig {
|
|
mode: TrackerMode::AUTO,
|
|
ac_threshold: 1000.0,
|
|
zero_crossing_beat_delay: 0,
|
|
},
|
|
audio_capture_thread: AudioCaptureThread::new(),
|
|
|
|
current_beats: None,
|
|
beat_count: 0,
|
|
}
|
|
}
|
|
|
|
pub fn tap(&mut self) {
|
|
self.metronome.tap();
|
|
}
|
|
|
|
pub fn get_graph_points(&self) -> (Vec<f32>, Vec<f32>) {
|
|
let points = self.audio_capture_thread.get_points();
|
|
|
|
let mut autocorrelation = points
|
|
.iter()
|
|
.map(|point| Complex {
|
|
re: *point,
|
|
im: 0.0,
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let n = autocorrelation.len();
|
|
|
|
for _ in 0..n {
|
|
autocorrelation.push(Complex { re: 0.0, im: 0.0 });
|
|
}
|
|
|
|
let mut planner = FftPlanner::new();
|
|
let fft = planner.plan_fft_forward(2 * n);
|
|
let ifft = planner.plan_fft_inverse(2 * n);
|
|
|
|
fft.process(&mut autocorrelation);
|
|
|
|
for f in autocorrelation.iter_mut() {
|
|
*f = Complex {
|
|
re: f.norm_sqr(),
|
|
im: 0.0,
|
|
};
|
|
}
|
|
|
|
ifft.process(&mut autocorrelation);
|
|
|
|
let mut autocorrelation: Vec<f32> = autocorrelation[..n].iter().map(|x| x.re).collect();
|
|
|
|
// remove first peak
|
|
|
|
let mut j = 0;
|
|
let i = loop {
|
|
j += 1;
|
|
|
|
if j == n {
|
|
break None;
|
|
}
|
|
|
|
if autocorrelation[j - 1] <= autocorrelation[j] {
|
|
break Some(j);
|
|
}
|
|
};
|
|
|
|
if let Some(i) = i {
|
|
let first_increase = autocorrelation[i];
|
|
|
|
for f in autocorrelation[..i].iter_mut() {
|
|
*f = first_increase;
|
|
}
|
|
}
|
|
|
|
(points, autocorrelation)
|
|
}
|
|
|
|
// (period_length [ms], crossing [point index, timing relative to last_update])
|
|
fn get_correlation_data(&self) -> Option<(usize, usize)> {
|
|
let (points, autocorrelation) = self.get_graph_points();
|
|
|
|
if autocorrelation
|
|
.iter()
|
|
.reduce(|a, b| if a > b { a } else { b })
|
|
.unwrap()
|
|
< &self.config.ac_threshold
|
|
{
|
|
return None;
|
|
}
|
|
|
|
let mut j = 0;
|
|
let i = loop {
|
|
j += 1;
|
|
|
|
if j == autocorrelation.len() {
|
|
break None;
|
|
}
|
|
|
|
if autocorrelation[j - 1] > autocorrelation[j] {
|
|
break Some(j);
|
|
}
|
|
};
|
|
|
|
if i.is_none() {
|
|
return None;
|
|
}
|
|
|
|
let period_length = i.unwrap();
|
|
|
|
// find zero crossing
|
|
|
|
let mut crossing = None;
|
|
|
|
for i in 1..points.len() {
|
|
if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 {
|
|
crossing = Some(i);
|
|
}
|
|
}
|
|
|
|
if crossing.is_none() {
|
|
return None;
|
|
}
|
|
|
|
let crossing = crossing.unwrap();
|
|
|
|
Some((period_length, crossing))
|
|
}
|
|
|
|
pub fn current_beat_progress(&mut self) -> Option<f64> {
|
|
let now = Instant::now();
|
|
|
|
match self.config.mode {
|
|
TrackerMode::AUTO => {
|
|
// println!("reee");
|
|
|
|
let should_update = {
|
|
match self.current_beats {
|
|
None => true,
|
|
Some((_, next)) if now > next => {
|
|
self.beat_count += 1;
|
|
true
|
|
}
|
|
_ => false,
|
|
}
|
|
};
|
|
|
|
if should_update {
|
|
if let Some((period_length, crossing)) = self.get_correlation_data() {
|
|
let last_update_timestamp = self.audio_capture_thread.get_last_update();
|
|
|
|
let mut dt = (now - last_update_timestamp).as_millis() as i64;
|
|
dt += ((POINT_BUFFER_SIZE - crossing) * MILLIS_PER_POINT) as i64;
|
|
|
|
let mut prev = now - Duration::from_millis(dt as u64);
|
|
|
|
let period_millis = (period_length * MILLIS_PER_POINT) as u64;
|
|
|
|
let period = Duration::from_millis(period_millis);
|
|
|
|
let mut next = prev + period;
|
|
|
|
if let Some((_, old_next)) = self.current_beats {
|
|
while next < old_next + Duration::from_millis(period_millis / 2) {
|
|
prev += period;
|
|
next += period;
|
|
}
|
|
}
|
|
|
|
self.current_beats = Some((prev, next));
|
|
} else {
|
|
self.beat_count = 0;
|
|
return None;
|
|
}
|
|
}
|
|
|
|
let (prev, next) = self.current_beats.unwrap();
|
|
|
|
let mut relative_millis = if prev < now {
|
|
(now - prev).as_millis() as f64
|
|
} else {
|
|
-1.0 * (prev - now).as_millis() as f64
|
|
};
|
|
|
|
relative_millis += self.config.zero_crossing_beat_delay as f64;
|
|
|
|
let fractional = relative_millis / (next - prev).as_millis() as f64;
|
|
|
|
Some(self.beat_count as f64 + fractional)
|
|
}
|
|
|
|
TrackerMode::MANUAL => self.metronome.get_beat_progress(now),
|
|
}
|
|
}
|
|
|
|
pub fn stop(&mut self) -> Result<()> {
|
|
self.audio_capture_thread.stop()
|
|
}
|
|
}
|