diff --git a/boilerbloat/src/main/backend.ts b/boilerbloat/src/main/backend.ts index d897f1c..dc458c0 100644 --- a/boilerbloat/src/main/backend.ts +++ b/boilerbloat/src/main/backend.ts @@ -6,11 +6,15 @@ import { TestPattern } from '../patterns/test'; import rust, { BeatTrackerHandle, MovingHeadState, OutputHandle } from 'rust_native_module'; import { ChaserPattern } from '../patterns/chaser'; -type AppState = { +export type AppState = { patterns: { [key: string]: PatternOutput }, selectedPattern: string | null, beatProgress: number | null, + graphData: { + bassFiltered: Array, + autoCorrelated: Array, + } | null, }; class Backend { @@ -61,8 +65,13 @@ class Backend { patterns: {}, selectedPattern: null, beatProgress: null, + graphData: null, } + ipcMain.on('pattern-select', async (_, arg) => { + this.state.selectedPattern = arg; + }); + let time: Time = { absolute: 0, beatRelative: this.state.beatProgress, @@ -93,6 +102,8 @@ class Backend { this.state.beatProgress = null; } + this.state.graphData = this.beatTracker.getGraphPoints(); + let date = new Date(); let time: Time = { absolute: date.getTime() / 1000, diff --git a/boilerbloat/src/main/main.ts b/boilerbloat/src/main/main.ts index 70af051..2f47bbb 100644 --- a/boilerbloat/src/main/main.ts +++ b/boilerbloat/src/main/main.ts @@ -40,7 +40,7 @@ const createWindow = async () => { }, }); - mainWindow.removeMenu(); + // mainWindow.removeMenu(); mainWindow.loadURL(resolveHtmlPath('index.html')); diff --git a/boilerbloat/src/patterns/chaser.ts b/boilerbloat/src/patterns/chaser.ts index 6f7ce88..c5d926f 100644 --- a/boilerbloat/src/patterns/chaser.ts +++ b/boilerbloat/src/patterns/chaser.ts @@ -2,13 +2,31 @@ import { MovingHeadState } from 'rust_native_module'; import { Pattern, PatternOutput, Time } from './proto'; export class ChaserPattern implements Pattern { + + lastBeat: number; + lastTime: number; + + constructor() { + this.lastBeat = 0; + this.lastTime = 0; + } + render(time: Time): PatternOutput { if (time.beatRelative === null) { + this.lastBeat = 0; return null; } - let head_number = Math.ceil(time.beatRelative) % 4; + let t = time.beatRelative; + + if (t < this.lastTime) { + this.lastBeat += 1; + } + + this.lastTime = t; + + let head_number = this.lastBeat % 4; let template: MovingHeadState = { startAddress: 0, diff --git a/boilerbloat/src/renderer/App.css b/boilerbloat/src/renderer/App.css index e69de29..22156d5 100644 --- a/boilerbloat/src/renderer/App.css +++ b/boilerbloat/src/renderer/App.css @@ -0,0 +1,4 @@ +body { + color: #c9c9c9; + background-color: #222222; +} diff --git a/boilerbloat/src/renderer/App.tsx b/boilerbloat/src/renderer/App.tsx index f0d89fc..28aa9f5 100644 --- a/boilerbloat/src/renderer/App.tsx +++ b/boilerbloat/src/renderer/App.tsx @@ -4,15 +4,19 @@ import { IpcRenderer } from 'electron/renderer'; import './App.css'; import { useEffect, useState } from 'react'; +import { AppState } from '../main/backend'; +import PatternPreview from './PatternPreview'; +import GraphVisualization from './Graph'; + const ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer; function tap() { ipcRenderer.send("beat-tracking", "tap"); } -const Frontend: React.FC = () => { +const FrontendRoot: React.FC = () => { - const [state, setState] = useState(); + const [state, setState] = useState(); const pollMain = async () => { const reply = await ipcRenderer.invoke("poll"); @@ -25,20 +29,52 @@ const Frontend: React.FC = () => { }); return <> -
- State: {state ? JSON.stringify(state) : "undef"} -
+ { + state + ? + :
oops
+ } + ; +}; + + +const Frontend: React.FC<{ state: AppState }> = ({ state }) => { + + return <>
+
+ {Object.entries(state.patterns).map(([patternId, output]) => ( + + ))} +
+ { + state.graphData + ?
+

Bass Filtered

+

+ +

+

Autocorrelation

+

+ +

+
+ :
no graph data
+ } +
+ {JSON.stringify(state)} +
; -}; +} + export default function App() { return ( - + ); diff --git a/boilerbloat/src/renderer/Graph.tsx b/boilerbloat/src/renderer/Graph.tsx new file mode 100644 index 0000000..52b9057 --- /dev/null +++ b/boilerbloat/src/renderer/Graph.tsx @@ -0,0 +1,45 @@ +import { useEffect, useRef } from "react"; + +const GraphVisualization: React.FC<{ points: Array, min?: number, max?: number }> = ({ points, min, max }) => { + const canvasRef = useRef(null); + + const minY = min ? min : Math.min(...points); + const maxY = max ? max : Math.max(...points); + + useEffect(() => { + + const canvas = canvasRef.current; + if (!canvas) { + return; + } + + const ctx = canvas.getContext('2d'); + if (!ctx) { + return; + } + + const backgroundColor = '#333333'; + const foregroundColor = '#FFA500'; + + // clear + + ctx.fillStyle = backgroundColor; + ctx.fillRect(0, 0, ctx.canvas.width, ctx.canvas.height); + + // draw points + + ctx.fillStyle = foregroundColor; + for (const [x, f] of points.entries()) { + const yFract = 1 - (f - minY) / (maxY - minY); + + const y = yFract * ctx.canvas.height; + + ctx.fillRect(x, y, 1, 1); + } + + }, [points]); + + return <>{`${minY.toFixed(2)} - ${maxY.toFixed(2)}`} +} + +export default GraphVisualization; diff --git a/boilerbloat/src/renderer/PatternPreview.tsx b/boilerbloat/src/renderer/PatternPreview.tsx new file mode 100644 index 0000000..dc9152c --- /dev/null +++ b/boilerbloat/src/renderer/PatternPreview.tsx @@ -0,0 +1,14 @@ +import { IpcRenderer } from "electron/renderer"; +import { PatternOutput } from "../patterns/proto"; + +const ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer; + +const PatternPreview: React.FC<{ patternId: string, output: PatternOutput }> = ({ patternId }) => { + return ; +} + +export default PatternPreview; diff --git a/rust_native_module/Cargo.lock b/rust_native_module/Cargo.lock index 31e5d68..6346fe0 100644 --- a/rust_native_module/Cargo.lock +++ b/rust_native_module/Cargo.lock @@ -38,6 +38,12 @@ version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee10e43ae4a853c0a3591d4e2ada1719e553be18199d9da9d4a83f5927c2f5c7" +[[package]] +name = "array-init" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6945cc5422176fc5e602e590c2878d2c2acd9a4fe20a4baa7c28022521698ec6" + [[package]] name = "autocfg" version = "1.0.1" @@ -90,6 +96,54 @@ dependencies = [ "winapi", ] +[[package]] +name = "libpulse-binding" +version = "2.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86835d7763ded6bc16b6c0061ec60214da7550dfcd4ef93745f6f0096129676a" +dependencies = [ + "bitflags", + "libc", + "libpulse-sys", + "num-derive", + "num-traits", + "winapi", +] + +[[package]] +name = "libpulse-simple-binding" +version = "2.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6a22538257c4d522bea6089d6478507f5d2589ea32150e20740aaaaaba44590" +dependencies = [ + "libpulse-binding", + "libpulse-simple-sys", + "libpulse-sys", +] + +[[package]] +name = "libpulse-simple-sys" +version = "1.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b8b0fcb9665401cc7c156c337c8edc7eb4e797b9d3ae1667e1e9e17b29e0c7c" +dependencies = [ + "libpulse-sys", + "pkg-config", +] + +[[package]] +name = "libpulse-sys" +version = "1.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f12950b69c1b66233a900414befde36c8d4ea49deec1e1f34e4cd2f586e00c7d" +dependencies = [ + "libc", + "num-derive", + "num-traits", + "pkg-config", + "winapi", +] + [[package]] name = "libudev" version = "0.2.0" @@ -222,6 +276,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -270,6 +335,15 @@ version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f" +[[package]] +name = "primal-check" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01419cee72c1a1ca944554e23d83e483e1bccf378753344e881de28b5487511d" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro2" version = "1.0.32" @@ -305,16 +379,43 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "ringbuffer" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20e49d3a791d79aa7683f8798b274073140865ffb9d65767ace44229d34299e3" +dependencies = [ + "array-init", +] + [[package]] name = "rust_native_module" version = "0.1.0" dependencies = [ "anyhow", + "libpulse-binding", + "libpulse-simple-binding", "neon", "num", + "ringbuffer", + "rustfft", "serialport", ] +[[package]] +name = "rustfft" +version = "6.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1d089e5c57521629a59f5f39bca7434849ff89bd6873b521afe389c1c602543" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "semver" version = "0.9.0" @@ -353,6 +454,12 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" +[[package]] +name = "strength_reduce" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ff2f71c82567c565ba4b3009a9350a96a7269eaa4001ebedae926230bc2254" + [[package]] name = "syn" version = "1.0.81" @@ -364,6 +471,16 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "transpose" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95f9c900aa98b6ea43aee227fd680550cdec726526aab8ac801549eadb25e39f" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "unicode-xid" version = "0.2.2" diff --git a/rust_native_module/Cargo.toml b/rust_native_module/Cargo.toml index f22bd2a..a015624 100644 --- a/rust_native_module/Cargo.toml +++ b/rust_native_module/Cargo.toml @@ -15,6 +15,10 @@ crate-type = ["cdylib"] serialport = "4.0.1" anyhow = "1.0.45" num = "0.4.0" +pulse = { version = "2.25.0", package = "libpulse-binding" } +psimple = { version = "2.24.1", package = "libpulse-simple-binding" } +ringbuffer = "0.8.3" +rustfft = "6.0.1" [dependencies.neon] version = "0.9" diff --git a/rust_native_module/index.d.ts b/rust_native_module/index.d.ts index 99a7363..dd2973e 100644 --- a/rust_native_module/index.d.ts +++ b/rust_native_module/index.d.ts @@ -30,9 +30,16 @@ declare module rust_native_module { close: () => Result, } + type GraphPoints = { + bassFiltered: Array, + autoCorrelated: Array, + } + type BeatTrackerHandle = { tap: () => void, getProgress: () => Option, + stop: () => Result, + getGraphPoints: () => GraphPoints, } function listPorts(): Array; diff --git a/rust_native_module/src/beat_tracking/audio/capture.rs b/rust_native_module/src/beat_tracking/audio/capture.rs new file mode 100644 index 0000000..d7056e3 --- /dev/null +++ b/rust_native_module/src/beat_tracking/audio/capture.rs @@ -0,0 +1,95 @@ +use anyhow::{anyhow, Result}; +use psimple::Simple; +use pulse::context::{Context, FlagSet as ContextFlagSet}; +use pulse::mainloop::standard::{IterateResult, Mainloop}; +use pulse::sample::Spec; +use pulse::stream::Direction; +use std::cell::RefCell; +use std::ops::Deref; +use std::rc::Rc; + +/* + Some manual poking around in PulseAudio to get the name of the default sink +*/ + +fn poll_mainloop(mainloop: &Rc>) -> Result<()> { + match mainloop.borrow_mut().iterate(true) { + IterateResult::Quit(_) | IterateResult::Err(_) => { + return Err(anyhow!("Iterate state was not success")); + } + IterateResult::Success(_) => { + return Ok(()); + } + } +} + +fn get_pulse_default_sink() -> Result { + let mainloop = Rc::new(RefCell::new( + Mainloop::new().ok_or(anyhow!("Failed to create mainloop"))?, + )); + let ctx = Rc::new(RefCell::new( + Context::new(mainloop.borrow().deref(), "gib_default_sink") + .ok_or(anyhow!("Failed to create context"))?, + )); + + ctx.borrow_mut() + .connect(None, ContextFlagSet::NOFLAGS, None)?; + + // Wait for context to be ready + loop { + poll_mainloop(&mainloop)?; + + match ctx.borrow().get_state() { + pulse::context::State::Ready => { + break; + } + pulse::context::State::Failed | pulse::context::State::Terminated => { + return Err(anyhow!("Context was in failed/terminated state")); + } + _ => {} + } + } + + let result = Rc::new(RefCell::new(None)); + let cb_result_ref = result.clone(); + + ctx.borrow().introspect().get_server_info(move |info| { + *cb_result_ref.borrow_mut() = if let Some(ref sink_name) = info.default_sink_name { + Some(Ok(sink_name.to_string())) + } else { + Some(Err(())) + } + }); + + loop { + if let Some(result) = result.borrow().deref() { + return result + .to_owned() + .map_err(|_| anyhow!("Default sink name was empty")); + } + + poll_mainloop(&mainloop)?; + } +} + +/* + Get a PASimple instance which reads from the default sink monitor +*/ + +pub fn get_audio_reader(spec: &Spec) -> Result { + let mut default_sink_name = get_pulse_default_sink()?; + default_sink_name.push_str(".monitor"); + + let simple = Simple::new( + None, + "piano_thingy", + Direction::Record, + Some(&default_sink_name), + "sample_yoinker", + &spec, + None, + None, + )?; + + Ok(simple) +} diff --git a/rust_native_module/src/beat_tracking/audio/dsp.rs b/rust_native_module/src/beat_tracking/audio/dsp.rs new file mode 100644 index 0000000..1c46c30 --- /dev/null +++ b/rust_native_module/src/beat_tracking/audio/dsp.rs @@ -0,0 +1,71 @@ +/* + Taken from Till's magic Arduino Sketch +*/ +pub trait ZTransformFilter { + fn process(&mut self, sample: f32) -> f32; +} + +// 20 - 200Hz Single Pole Bandpass IIR Filter +#[derive(Default)] +pub struct BassFilter { + xv: [f32; 3], + yv: [f32; 3], +} + +impl ZTransformFilter for BassFilter { + fn process(&mut self, sample: f32) -> f32 { + self.xv[0] = self.xv[1]; + self.xv[1] = self.xv[2]; + self.xv[2] = sample / 3.0f32; + + self.yv[0] = self.yv[1]; + self.yv[1] = self.yv[2]; + self.yv[2] = (self.xv[2] - self.xv[0]) + + (-0.7960060012f32 * self.yv[0]) + + (1.7903124146f32 * self.yv[1]); + + self.yv[2] + } +} + +// 10Hz Single Pole Lowpass IIR Filter +#[derive(Default)] +pub struct EnvelopeFilter { + xv: [f32; 2], + yv: [f32; 2], +} + +impl ZTransformFilter for EnvelopeFilter { + fn process(&mut self, sample: f32) -> f32 { + self.xv[0] = self.xv[1]; + self.xv[1] = sample / 50.0f32; + + self.yv[0] = self.yv[1]; + self.yv[1] = (self.xv[0] + self.xv[1]) + (0.9875119299f32 * self.yv[0]); + + self.yv[1] + } +} + +// 1.7 - 3.0Hz Single Pole Bandpass IIR Filter +#[derive(Default)] +pub struct BeatFilter { + xv: [f32; 3], + yv: [f32; 3], +} + +impl ZTransformFilter for BeatFilter { + fn process(&mut self, sample: f32) -> f32 { + self.xv[0] = self.xv[1]; + self.xv[1] = self.xv[2]; + self.xv[2] = sample / 2.7f32; + + self.yv[0] = self.yv[1]; + self.yv[1] = self.yv[2]; + self.yv[2] = (self.xv[2] - self.xv[0]) + + (-0.7169861741f32 * self.yv[0]) + + (1.4453653501f32 * self.yv[1]); + + self.yv[2] + } +} diff --git a/rust_native_module/src/beat_tracking/audio/mod.rs b/rust_native_module/src/beat_tracking/audio/mod.rs new file mode 100644 index 0000000..b5440b8 --- /dev/null +++ b/rust_native_module/src/beat_tracking/audio/mod.rs @@ -0,0 +1,141 @@ +use anyhow::{anyhow, Result}; +use ringbuffer::{ConstGenericRingBuffer, RingBufferExt, RingBufferWrite}; +use std::{ + sync::{Arc, Mutex}, + thread::{self, JoinHandle}, + time::Instant, +}; + +use pulse::sample::{Format, Spec}; + +use self::dsp::ZTransformFilter; + +mod capture; +mod dsp; + +const SAMPLE_RATE: usize = 5000; +const PULSE_UPDATES_PER_SECOND: usize = 50; +const BUFFER_SIZE: usize = SAMPLE_RATE / PULSE_UPDATES_PER_SECOND; + +const POINTS_PER_SECOND: usize = SAMPLE_RATE / 200; +pub const MILLIS_PER_POINT: usize = 1000 / POINTS_PER_SECOND; +const POINT_MIN_COUNT: usize = 2 * POINTS_PER_SECOND; +const POINT_BUFFER_SIZE: usize = POINT_MIN_COUNT.next_power_of_two(); + +pub struct AudioCaptureThread { + join_handle: Option>>, + shared_state: Arc>, +} + +impl AudioCaptureThread { + pub fn new() -> Self { + let shared_state = Arc::new(Mutex::new(SharedState::new())); + let join_handle = { + let shared_state = shared_state.clone(); + Some(thread::spawn(move || audio_capture_thread(shared_state))) + }; + + Self { + shared_state, + join_handle, + } + } + + pub fn stop(&mut self) -> Result<()> { + if self.join_handle.is_none() { + return Err(anyhow!( + "join_handle was none, was stop called multiple times?" + )); + } else { + let handle = self.join_handle.take().unwrap(); + + handle.join().unwrap() + } + } + + pub fn read_state(&self) -> (Vec, Instant) { + let state = self.shared_state.lock().unwrap(); + (state.point_buf.to_vec(), state.last_update) + } + + pub fn get_points(&self) -> Vec { + let state = self.shared_state.lock().unwrap(); + state.point_buf.to_vec() + } + + pub fn get_last_update(&self) -> Instant { + let state = self.shared_state.lock().unwrap(); + state.last_update + } +} + +struct SharedState { + running: bool, + point_buf: ConstGenericRingBuffer, + last_update: Instant, +} + +impl SharedState { + pub fn new() -> Self { + let mut point_buf: ConstGenericRingBuffer = Default::default(); + point_buf.fill_default(); + + Self { + running: true, + point_buf, + last_update: Instant::now(), + } + } +} + +fn audio_capture_thread(state: Arc>) -> Result<()> { + let spec = Spec { + format: Format::F32le, + rate: SAMPLE_RATE as u32, + channels: 1, + }; + + let reader = capture::get_audio_reader(&spec)?; + let mut buffer = [0u8; 4 * BUFFER_SIZE]; + + let mut bass_filter = dsp::BassFilter::default(); + let mut envelope_filter = dsp::EnvelopeFilter::default(); + let mut beat_filter = dsp::BeatFilter::default(); + let mut j = 0; + + loop { + { + if state.lock().unwrap().running == false { + break Ok(()); + } + } + + reader.read(&mut buffer)?; + + for i in 0..BUFFER_SIZE { + let mut float_bytes = [0u8; 4]; + float_bytes.copy_from_slice(&buffer[4 * i..4 * i + 4]); + + j += 1; + let sample = f32::from_le_bytes(float_bytes); + let mut value = bass_filter.process(sample); + + if value < 0f32 { + value = -value; + } + + let envelope = envelope_filter.process(value); + + if j == 200 { + let beat = beat_filter.process(envelope); + + let mut state = state.lock().unwrap(); + + state.point_buf.push(beat); + state.last_update = Instant::now(); + + j = 0; + } + } + } +} diff --git a/rust_native_module/src/beat_tracking/metronome.rs b/rust_native_module/src/beat_tracking/metronome.rs index f60e3d8..c4dc1f9 100644 --- a/rust_native_module/src/beat_tracking/metronome.rs +++ b/rust_native_module/src/beat_tracking/metronome.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, Instant}; +use std::{ + ops::RangeInclusive, + time::{Duration, Instant}, +}; pub struct Metronome { taps: Vec, @@ -34,14 +37,35 @@ impl Metronome { } } - pub fn current_beat_progress(&self) -> Option { + pub fn get_beat_progress(&self, now: Instant) -> Option { if self.beat_interval.is_none() { return None; } - let now = Instant::now(); let relative_millis = (now - *self.taps.last().unwrap()).as_millis(); Some(relative_millis as f64 / self.beat_interval.unwrap() as f64) } + + pub fn get_beats(&self, from: Instant, to: Instant) -> Vec<(u64, Instant)> { + if self.beat_interval.is_none() { + return Vec::new(); + } + + let beat_interval = self.beat_interval.unwrap(); + + let last = *self.taps.last().unwrap(); + + let relative_from = (from - last).as_millis() as f64 / beat_interval as f64; + let relative_to = (to - last).as_millis() as f64 / beat_interval as f64; + + (relative_from.ceil() as u64..=relative_to.floor() as u64) + .map(|beat_number| { + ( + beat_number, + last + Duration::from_millis((beat_number as u128 * beat_interval) as u64), + ) + }) + .collect() + } } diff --git a/rust_native_module/src/beat_tracking/mod.rs b/rust_native_module/src/beat_tracking/mod.rs index ca3e6cd..81a4082 100644 --- a/rust_native_module/src/beat_tracking/mod.rs +++ b/rust_native_module/src/beat_tracking/mod.rs @@ -1,19 +1,21 @@ +mod audio; mod metronome; +mod tracker; use std::{cell::RefCell, time::Duration}; -use metronome::Metronome; +use tracker::BeatTracker; use neon::prelude::*; -type BoxedTracker = JsBox>; -impl Finalize for Metronome {} +type BoxedTracker = JsBox>; +impl Finalize for BeatTracker {} pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult { let obj = cx.empty_object(); let value_obj = cx.empty_object(); - let boxed_tracker = cx.boxed(RefCell::new(Metronome::new(Duration::from_secs(2)))); + let boxed_tracker = cx.boxed(RefCell::new(BeatTracker::new(Duration::from_secs(2)))); value_obj.set(&mut cx, "_rust_ptr", boxed_tracker)?; let tap_function = JsFunction::new(&mut cx, tap)?; @@ -22,6 +24,12 @@ pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult { let get_progress_function = JsFunction::new(&mut cx, get_progress)?; value_obj.set(&mut cx, "getProgress", get_progress_function)?; + let close_function = JsFunction::new(&mut cx, close)?; + value_obj.set(&mut cx, "close", close_function)?; + + let get_graph_points_function = JsFunction::new(&mut cx, get_graph_points)?; + value_obj.set(&mut cx, "getGraphPoints", get_graph_points_function)?; + obj.set(&mut cx, "value", value_obj)?; let success_string = cx.string("success".to_string()); @@ -68,3 +76,57 @@ pub fn get_progress(mut cx: FunctionContext) -> JsResult { Ok(obj) } + +fn get_graph_points(mut cx: FunctionContext) -> JsResult { + let this = cx.this(); + + let boxed_tracker = this + .get(&mut cx, "_rust_ptr")? + .downcast_or_throw::(&mut cx)?; + + let obj = cx.empty_object(); + + let (bass_filtered, auto_correlated) = boxed_tracker.borrow().get_graph_points(); + + let arr_bf = cx.empty_array(); + for (i, f) in bass_filtered.iter().enumerate() { + let num = cx.number(*f); + arr_bf.set(&mut cx, i as u32, num)?; + } + obj.set(&mut cx, "bassFiltered", arr_bf)?; + + let arr_ac = cx.empty_array(); + for (i, f) in auto_correlated.iter().enumerate() { + let num = cx.number(*f); + arr_ac.set(&mut cx, i as u32, num)?; + } + obj.set(&mut cx, "autoCorrelated", arr_ac)?; + + Ok(obj) +} + +pub fn close(mut cx: FunctionContext) -> JsResult { + let this = cx.this(); + + let boxed_tracker = this + .get(&mut cx, "_rust_ptr")? + .downcast_or_throw::(&mut cx)?; + + let obj = cx.empty_object(); + let success_string; + + match boxed_tracker.borrow_mut().stop() { + Ok(()) => { + success_string = cx.string("success"); + } + Err(e) => { + success_string = cx.string("error"); + + let error_message = cx.string(e.to_string()); + obj.set(&mut cx, "message", error_message)?; + } + } + obj.set(&mut cx, "type", success_string)?; + + Ok(obj) +} diff --git a/rust_native_module/src/beat_tracking/tracker.rs b/rust_native_module/src/beat_tracking/tracker.rs new file mode 100644 index 0000000..d4d7b1b --- /dev/null +++ b/rust_native_module/src/beat_tracking/tracker.rs @@ -0,0 +1,167 @@ +use anyhow::Result; +use num::Complex; +use rustfft::FftPlanner; +use std::time::{Duration, Instant}; + +use super::audio::{AudioCaptureThread, MILLIS_PER_POINT}; +use super::metronome::Metronome; + +pub struct TrackerConfig { + ac_threshold: f32, + zero_crossing_beat_delay: i64, +} + +pub struct BeatTracker { + metronome: Metronome, + config: TrackerConfig, + audio_capture_thread: AudioCaptureThread, +} + +impl BeatTracker { + pub fn new(metronome_timeout: Duration) -> Self { + Self { + metronome: Metronome::new(metronome_timeout), + config: TrackerConfig { + ac_threshold: 1000.0, + zero_crossing_beat_delay: 0, + }, + audio_capture_thread: AudioCaptureThread::new(), + } + } + + pub fn tap(&mut self) { + self.metronome.tap(); + } + + pub fn get_graph_points(&self) -> (Vec, Vec) { + let points = self.audio_capture_thread.get_points(); + + let mut autocorrelation = points + .iter() + .map(|point| Complex { + re: *point, + im: 0.0, + }) + .collect::>(); + + let n = autocorrelation.len(); + + for _ in 0..n { + autocorrelation.push(Complex { re: 0.0, im: 0.0 }); + } + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(2 * n); + let ifft = planner.plan_fft_inverse(2 * n); + + fft.process(&mut autocorrelation); + + for f in autocorrelation.iter_mut() { + *f = Complex { + re: f.norm_sqr(), + im: 0.0, + }; + } + + ifft.process(&mut autocorrelation); + + let mut autocorrelation: Vec = autocorrelation[..n].iter().map(|x| x.re).collect(); + + // remove first peak + + let mut j = 0; + let i = loop { + j += 1; + + if j == n { + break None; + } + + if autocorrelation[j - 1] <= autocorrelation[j] { + break Some(j); + } + }; + + if let Some(i) = i { + let first_increase = autocorrelation[i]; + + for f in autocorrelation[..i].iter_mut() { + *f = first_increase; + } + } + + (points, autocorrelation) + } + + pub fn current_beat_progress(&self) -> Option { + let (points, autocorrelation) = self.get_graph_points(); + + if autocorrelation + .iter() + .reduce(|a, b| if a > b { a } else { b }) + .unwrap() + < &self.config.ac_threshold + { + return None; + } + + let mut j = 0; + let i = loop { + j += 1; + + if j == autocorrelation.len() { + break None; + } + + if autocorrelation[j - 1] > autocorrelation[j] { + break Some(j); + } + }; + + if i.is_none() { + return None; + } + + let period_length = i.unwrap(); + + // println!("predicted period length: {}", period_length); + + // find zero crossing + + let mut crossing = None; + + for i in 1..points.len() { + if points[i-1].signum() < 0.0 && points[i].signum() > 0.0 { + crossing = Some(i); + } + } + + if crossing.is_none() { + return None; + } + + let crossing = crossing.unwrap(); + // println!("index of last positive zero crossing: {}", crossing); + + let last_update_timestamp = self.audio_capture_thread.get_last_update(); + let now = Instant::now(); + + let mut dt = (now - last_update_timestamp).as_millis() as i64; + dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64; + dt += self.config.zero_crossing_beat_delay; + + let dt = dt as f64; + + let period_millis = (period_length * MILLIS_PER_POINT) as f64; + + let relative_time = (dt % period_millis) / period_millis; + + println!("relative time: {}", relative_time); + + Some(relative_time) + } + + pub fn stop(&mut self) -> Result<()> { + self.audio_capture_thread.stop() + } +}