partylights/rust_native_module/src/beat_tracking/tracker.rs
2021-11-13 00:22:24 +01:00

233 lines
6.2 KiB
Rust

use anyhow::Result;
use num::Complex;
use rustfft::FftPlanner;
use std::time::{Duration, Instant};
use crate::beat_tracking::audio::POINT_BUFFER_SIZE;
use super::audio::{AudioCaptureThread, MILLIS_PER_POINT};
use super::metronome::Metronome;
/// Selects how [`BeatTracker`] derives beat timing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackerMode {
    /// Estimate beats from captured audio (autocorrelation period + zero crossing).
    AUTO,
    /// Follow the tap-driven metronome.
    MANUAL,
}

/// Runtime-tunable settings of a [`BeatTracker`].
#[derive(Debug, Clone, PartialEq)]
pub struct TrackerConfig {
    /// Which timing source the tracker reports.
    pub mode: TrackerMode,
    /// Minimum peak value of the (unnormalized) autocorrelation required before
    /// an automatic estimate is accepted.
    pub ac_threshold: f32,
    /// Offset in milliseconds added to the time elapsed since the previous beat
    /// before computing beat progress (automatic mode only).
    pub zero_crossing_beat_delay: i64,
}
/// Tracks musical beats, either estimated from captured audio
/// ([`TrackerMode::AUTO`]) or driven by manual taps ([`TrackerMode::MANUAL`]).
pub struct BeatTracker {
    /// Tap-driven metronome; consulted only in manual mode and fed by `tap()`.
    metronome: Metronome,
    /// Runtime-tunable settings (mode, autocorrelation threshold, beat delay).
    pub config: TrackerConfig,
    /// Source of the sampled points and of the timestamp of their last update.
    audio_capture_thread: AudioCaptureThread,
    /// `(previous beat, next beat)` instants of the current automatic estimate;
    /// `None` until the first successful estimate.
    current_beats: Option<(Instant, Instant)>,
    /// Number of predicted beats that have passed in automatic mode; reset to 0
    /// when a re-estimate fails.
    beat_count: usize,
}
impl BeatTracker {
pub fn new(metronome_timeout: Duration) -> Self {
Self {
metronome: Metronome::new(metronome_timeout),
config: TrackerConfig {
mode: TrackerMode::AUTO,
ac_threshold: 1000.0,
zero_crossing_beat_delay: 0,
},
audio_capture_thread: AudioCaptureThread::new(),
current_beats: None,
beat_count: 0,
}
}
pub fn tap(&mut self) {
self.metronome.tap();
}
pub fn get_graph_points(&self) -> (Vec<f32>, Vec<f32>) {
let points = self.audio_capture_thread.get_points();
let mut autocorrelation = points
.iter()
.map(|point| Complex {
re: *point,
im: 0.0,
})
.collect::<Vec<_>>();
let n = autocorrelation.len();
for _ in 0..n {
autocorrelation.push(Complex { re: 0.0, im: 0.0 });
}
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(2 * n);
let ifft = planner.plan_fft_inverse(2 * n);
fft.process(&mut autocorrelation);
for f in autocorrelation.iter_mut() {
*f = Complex {
re: f.norm_sqr(),
im: 0.0,
};
}
ifft.process(&mut autocorrelation);
let mut autocorrelation: Vec<f32> = autocorrelation[..n].iter().map(|x| x.re).collect();
// remove first peak
let mut j = 0;
let i = loop {
j += 1;
if j == n {
break None;
}
if autocorrelation[j - 1] <= autocorrelation[j] {
break Some(j);
}
};
if let Some(i) = i {
let first_increase = autocorrelation[i];
for f in autocorrelation[..i].iter_mut() {
*f = first_increase;
}
}
(points, autocorrelation)
}
// (period_length [ms], crossing [point index, timing relative to last_update])
fn get_correlation_data(&self) -> Option<(usize, usize)> {
let (points, autocorrelation) = self.get_graph_points();
if autocorrelation
.iter()
.reduce(|a, b| if a > b { a } else { b })
.unwrap()
< &self.config.ac_threshold
{
return None;
}
let mut j = 0;
let i = loop {
j += 1;
if j == autocorrelation.len() {
break None;
}
if autocorrelation[j - 1] > autocorrelation[j] {
break Some(j);
}
};
if i.is_none() {
return None;
}
let period_length = i.unwrap();
// find zero crossing
let mut crossing = None;
for i in 1..points.len() {
if points[i - 1].signum() < 0.0 && points[i].signum() > 0.0 {
crossing = Some(i);
}
}
if crossing.is_none() {
return None;
}
let crossing = crossing.unwrap();
Some((period_length, crossing))
}
pub fn current_beat_progress(&mut self) -> Option<f64> {
let now = Instant::now();
match self.config.mode {
TrackerMode::AUTO => {
// println!("reee");
let should_update = {
match self.current_beats {
None => true,
Some((_, next)) if now > next => {
self.beat_count += 1;
true
}
_ => false,
}
};
if should_update {
if let Some((period_length, crossing)) = self.get_correlation_data() {
let last_update_timestamp = self.audio_capture_thread.get_last_update();
let mut dt = (now - last_update_timestamp).as_millis() as i64;
dt += ((POINT_BUFFER_SIZE - crossing) * MILLIS_PER_POINT) as i64;
let mut prev = now - Duration::from_millis(dt as u64);
let period_millis = (period_length * MILLIS_PER_POINT) as u64;
let period = Duration::from_millis(period_millis);
let mut next = prev + period;
if let Some((_, old_next)) = self.current_beats {
while next < old_next + Duration::from_millis(period_millis / 2) {
prev += period;
next += period;
}
}
self.current_beats = Some((prev, next));
} else {
self.beat_count = 0;
return None;
}
}
let (prev, next) = self.current_beats.unwrap();
let mut relative_millis = if prev < now {
(now - prev).as_millis() as f64
} else {
-1.0 * (prev - now).as_millis() as f64
};
relative_millis += self.config.zero_crossing_beat_delay as f64;
let fractional = relative_millis / (next - prev).as_millis() as f64;
Some(self.beat_count as f64 + fractional)
}
TrackerMode::MANUAL => self.metronome.get_beat_progress(now),
}
}
pub fn stop(&mut self) -> Result<()> {
self.audio_capture_thread.stop()
}
}