Implement Autocorrelation

This commit is contained in:
Kai Vogelgesang 2021-11-12 05:44:37 +01:00
parent abb0476a17
commit 41b54f6209
Signed by: kai
GPG Key ID: 0A95D3B6E62C0879
16 changed files with 833 additions and 17 deletions

View File

@ -6,11 +6,15 @@ import { TestPattern } from '../patterns/test';
import rust, { BeatTrackerHandle, MovingHeadState, OutputHandle } from 'rust_native_module'; import rust, { BeatTrackerHandle, MovingHeadState, OutputHandle } from 'rust_native_module';
import { ChaserPattern } from '../patterns/chaser'; import { ChaserPattern } from '../patterns/chaser';
type AppState = { export type AppState = {
patterns: { [key: string]: PatternOutput }, patterns: { [key: string]: PatternOutput },
selectedPattern: string | null, selectedPattern: string | null,
beatProgress: number | null, beatProgress: number | null,
graphData: {
bassFiltered: Array<number>,
autoCorrelated: Array<number>,
} | null,
}; };
class Backend { class Backend {
@ -61,8 +65,13 @@ class Backend {
patterns: {}, patterns: {},
selectedPattern: null, selectedPattern: null,
beatProgress: null, beatProgress: null,
graphData: null,
} }
ipcMain.on('pattern-select', async (_, arg) => {
this.state.selectedPattern = arg;
});
let time: Time = { let time: Time = {
absolute: 0, absolute: 0,
beatRelative: this.state.beatProgress, beatRelative: this.state.beatProgress,
@ -93,6 +102,8 @@ class Backend {
this.state.beatProgress = null; this.state.beatProgress = null;
} }
this.state.graphData = this.beatTracker.getGraphPoints();
let date = new Date(); let date = new Date();
let time: Time = { let time: Time = {
absolute: date.getTime() / 1000, absolute: date.getTime() / 1000,

View File

@ -40,7 +40,7 @@ const createWindow = async () => {
}, },
}); });
mainWindow.removeMenu(); // mainWindow.removeMenu();
mainWindow.loadURL(resolveHtmlPath('index.html')); mainWindow.loadURL(resolveHtmlPath('index.html'));

View File

@ -2,13 +2,31 @@ import { MovingHeadState } from 'rust_native_module';
import { Pattern, PatternOutput, Time } from './proto'; import { Pattern, PatternOutput, Time } from './proto';
export class ChaserPattern implements Pattern { export class ChaserPattern implements Pattern {
lastBeat: number;
lastTime: number;
constructor() {
this.lastBeat = 0;
this.lastTime = 0;
}
render(time: Time): PatternOutput { render(time: Time): PatternOutput {
if (time.beatRelative === null) { if (time.beatRelative === null) {
this.lastBeat = 0;
return null; return null;
} }
let head_number = Math.ceil(time.beatRelative) % 4; let t = time.beatRelative;
if (t < this.lastTime) {
this.lastBeat += 1;
}
this.lastTime = t;
let head_number = this.lastBeat % 4;
let template: MovingHeadState = { let template: MovingHeadState = {
startAddress: 0, startAddress: 0,

View File

@ -0,0 +1,4 @@
body {
color: #c9c9c9;
background-color: #222222;
}

View File

@ -4,15 +4,19 @@ import { IpcRenderer } from 'electron/renderer';
import './App.css'; import './App.css';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { AppState } from '../main/backend';
import PatternPreview from './PatternPreview';
import GraphVisualization from './Graph';
const ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer; const ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer;
function tap() { function tap() {
ipcRenderer.send("beat-tracking", "tap"); ipcRenderer.send("beat-tracking", "tap");
} }
const Frontend: React.FC = () => { const FrontendRoot: React.FC = () => {
const [state, setState] = useState<any>(); const [state, setState] = useState<AppState>();
const pollMain = async () => { const pollMain = async () => {
const reply = await ipcRenderer.invoke("poll"); const reply = await ipcRenderer.invoke("poll");
@ -25,20 +29,52 @@ const Frontend: React.FC = () => {
}); });
return <> return <>
<div> {
State: {state ? JSON.stringify(state) : "undef"} state
</div> ? <Frontend state={state} />
: <div> oops </div>
}
</>;
};
const Frontend: React.FC<{ state: AppState }> = ({ state }) => {
return <>
<div> <div>
<button onClick={tap}>Tap</button> <button onClick={tap}>Tap</button>
</div> </div>
<div>
{Object.entries(state.patterns).map(([patternId, output]) => (
<PatternPreview key={patternId} patternId={patternId} output={output} />
))}
</div>
{
state.graphData
? <div>
<p> Bass Filtered </p>
<p>
<GraphVisualization points={state.graphData.bassFiltered} />
</p>
<p> Autocorrelation </p>
<p>
<GraphVisualization points={state.graphData.autoCorrelated} />
</p>
</div>
: <div> no graph data </div>
}
<div>
{JSON.stringify(state)}
</div>
</>; </>;
}; }
export default function App() { export default function App() {
return ( return (
<Router> <Router>
<Switch> <Switch>
<Route path="/" component={Frontend} /> <Route path="/" component={FrontendRoot} />
</Switch> </Switch>
</Router> </Router>
); );

View File

@ -0,0 +1,45 @@
import { useEffect, useRef } from "react";
const GraphVisualization: React.FC<{ points: Array<number>, min?: number, max?: number }> = ({ points, min, max }) => {
const canvasRef = useRef<HTMLCanvasElement>(null);
const minY = min ? min : Math.min(...points);
const maxY = max ? max : Math.max(...points);
useEffect(() => {
const canvas = canvasRef.current;
if (!canvas) {
return;
}
const ctx = canvas.getContext('2d');
if (!ctx) {
return;
}
const backgroundColor = '#333333';
const foregroundColor = '#FFA500';
// clear
ctx.fillStyle = backgroundColor;
ctx.fillRect(0, 0, ctx.canvas.width, ctx.canvas.height);
// draw points
ctx.fillStyle = foregroundColor;
for (const [x, f] of points.entries()) {
const yFract = 1 - (f - minY) / (maxY - minY);
const y = yFract * ctx.canvas.height;
ctx.fillRect(x, y, 1, 1);
}
}, [points]);
return <><canvas ref={canvasRef} width={points.length} height={100} />{`${minY.toFixed(2)} - ${maxY.toFixed(2)}`}</>
}
export default GraphVisualization;

View File

@ -0,0 +1,14 @@
import { IpcRenderer } from "electron/renderer";
import { PatternOutput } from "../patterns/proto";
const ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer;
const PatternPreview: React.FC<{ patternId: string, output: PatternOutput }> = ({ patternId }) => {
return <button onClick={() => {
ipcRenderer.send("pattern-select", patternId);
}}>
{patternId}
</button>;
}
export default PatternPreview;

View File

@ -38,6 +38,12 @@ version = "1.0.45"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee10e43ae4a853c0a3591d4e2ada1719e553be18199d9da9d4a83f5927c2f5c7" checksum = "ee10e43ae4a853c0a3591d4e2ada1719e553be18199d9da9d4a83f5927c2f5c7"
[[package]]
name = "array-init"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6945cc5422176fc5e602e590c2878d2c2acd9a4fe20a4baa7c28022521698ec6"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.0.1" version = "1.0.1"
@ -90,6 +96,54 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "libpulse-binding"
version = "2.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86835d7763ded6bc16b6c0061ec60214da7550dfcd4ef93745f6f0096129676a"
dependencies = [
"bitflags",
"libc",
"libpulse-sys",
"num-derive",
"num-traits",
"winapi",
]
[[package]]
name = "libpulse-simple-binding"
version = "2.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6a22538257c4d522bea6089d6478507f5d2589ea32150e20740aaaaaba44590"
dependencies = [
"libpulse-binding",
"libpulse-simple-sys",
"libpulse-sys",
]
[[package]]
name = "libpulse-simple-sys"
version = "1.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b8b0fcb9665401cc7c156c337c8edc7eb4e797b9d3ae1667e1e9e17b29e0c7c"
dependencies = [
"libpulse-sys",
"pkg-config",
]
[[package]]
name = "libpulse-sys"
version = "1.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f12950b69c1b66233a900414befde36c8d4ea49deec1e1f34e4cd2f586e00c7d"
dependencies = [
"libc",
"num-derive",
"num-traits",
"pkg-config",
"winapi",
]
[[package]] [[package]]
name = "libudev" name = "libudev"
version = "0.2.0" version = "0.2.0"
@ -222,6 +276,17 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "num-derive"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "num-integer" name = "num-integer"
version = "0.1.44" version = "0.1.44"
@ -270,6 +335,15 @@ version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f" checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f"
[[package]]
name = "primal-check"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01419cee72c1a1ca944554e23d83e483e1bccf378753344e881de28b5487511d"
dependencies = [
"num-integer",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.32" version = "1.0.32"
@ -305,16 +379,43 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "ringbuffer"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20e49d3a791d79aa7683f8798b274073140865ffb9d65767ace44229d34299e3"
dependencies = [
"array-init",
]
[[package]] [[package]]
name = "rust_native_module" name = "rust_native_module"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"libpulse-binding",
"libpulse-simple-binding",
"neon", "neon",
"num", "num",
"ringbuffer",
"rustfft",
"serialport", "serialport",
] ]
[[package]]
name = "rustfft"
version = "6.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1d089e5c57521629a59f5f39bca7434849ff89bd6873b521afe389c1c602543"
dependencies = [
"num-complex",
"num-integer",
"num-traits",
"primal-check",
"strength_reduce",
"transpose",
]
[[package]] [[package]]
name = "semver" name = "semver"
version = "0.9.0" version = "0.9.0"
@ -353,6 +454,12 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
[[package]]
name = "strength_reduce"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3ff2f71c82567c565ba4b3009a9350a96a7269eaa4001ebedae926230bc2254"
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.81" version = "1.0.81"
@ -364,6 +471,16 @@ dependencies = [
"unicode-xid", "unicode-xid",
] ]
[[package]]
name = "transpose"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95f9c900aa98b6ea43aee227fd680550cdec726526aab8ac801549eadb25e39f"
dependencies = [
"num-integer",
"strength_reduce",
]
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.2" version = "0.2.2"

View File

@ -15,6 +15,10 @@ crate-type = ["cdylib"]
serialport = "4.0.1" serialport = "4.0.1"
anyhow = "1.0.45" anyhow = "1.0.45"
num = "0.4.0" num = "0.4.0"
pulse = { version = "2.25.0", package = "libpulse-binding" }
psimple = { version = "2.24.1", package = "libpulse-simple-binding" }
ringbuffer = "0.8.3"
rustfft = "6.0.1"
[dependencies.neon] [dependencies.neon]
version = "0.9" version = "0.9"

View File

@ -30,9 +30,16 @@ declare module rust_native_module {
close: () => Result<never>, close: () => Result<never>,
} }
type GraphPoints = {
bassFiltered: Array<number>,
autoCorrelated: Array<number>,
}
type BeatTrackerHandle = { type BeatTrackerHandle = {
tap: () => void, tap: () => void,
getProgress: () => Option<number>, getProgress: () => Option<number>,
stop: () => Result<never>,
getGraphPoints: () => GraphPoints,
} }
function listPorts(): Array<string>; function listPorts(): Array<string>;

View File

@ -0,0 +1,95 @@
use anyhow::{anyhow, Result};
use psimple::Simple;
use pulse::context::{Context, FlagSet as ContextFlagSet};
use pulse::mainloop::standard::{IterateResult, Mainloop};
use pulse::sample::Spec;
use pulse::stream::Direction;
use std::cell::RefCell;
use std::ops::Deref;
use std::rc::Rc;
/*
Some manual poking around in PulseAudio to get the name of the default sink
*/
fn poll_mainloop(mainloop: &Rc<RefCell<Mainloop>>) -> Result<()> {
match mainloop.borrow_mut().iterate(true) {
IterateResult::Quit(_) | IterateResult::Err(_) => {
return Err(anyhow!("Iterate state was not success"));
}
IterateResult::Success(_) => {
return Ok(());
}
}
}
fn get_pulse_default_sink() -> Result<String> {
let mainloop = Rc::new(RefCell::new(
Mainloop::new().ok_or(anyhow!("Failed to create mainloop"))?,
));
let ctx = Rc::new(RefCell::new(
Context::new(mainloop.borrow().deref(), "gib_default_sink")
.ok_or(anyhow!("Failed to create context"))?,
));
ctx.borrow_mut()
.connect(None, ContextFlagSet::NOFLAGS, None)?;
// Wait for context to be ready
loop {
poll_mainloop(&mainloop)?;
match ctx.borrow().get_state() {
pulse::context::State::Ready => {
break;
}
pulse::context::State::Failed | pulse::context::State::Terminated => {
return Err(anyhow!("Context was in failed/terminated state"));
}
_ => {}
}
}
let result = Rc::new(RefCell::new(None));
let cb_result_ref = result.clone();
ctx.borrow().introspect().get_server_info(move |info| {
*cb_result_ref.borrow_mut() = if let Some(ref sink_name) = info.default_sink_name {
Some(Ok(sink_name.to_string()))
} else {
Some(Err(()))
}
});
loop {
if let Some(result) = result.borrow().deref() {
return result
.to_owned()
.map_err(|_| anyhow!("Default sink name was empty"));
}
poll_mainloop(&mainloop)?;
}
}
/*
Get a PASimple instance which reads from the default sink monitor
*/
pub fn get_audio_reader(spec: &Spec) -> Result<Simple> {
let mut default_sink_name = get_pulse_default_sink()?;
default_sink_name.push_str(".monitor");
let simple = Simple::new(
None,
"piano_thingy",
Direction::Record,
Some(&default_sink_name),
"sample_yoinker",
&spec,
None,
None,
)?;
Ok(simple)
}

View File

@ -0,0 +1,71 @@
/*
Taken from Till's magic Arduino Sketch
*/
pub trait ZTransformFilter {
fn process(&mut self, sample: f32) -> f32;
}
// 20 - 200Hz Single Pole Bandpass IIR Filter
#[derive(Default)]
pub struct BassFilter {
xv: [f32; 3],
yv: [f32; 3],
}
impl ZTransformFilter for BassFilter {
fn process(&mut self, sample: f32) -> f32 {
self.xv[0] = self.xv[1];
self.xv[1] = self.xv[2];
self.xv[2] = sample / 3.0f32;
self.yv[0] = self.yv[1];
self.yv[1] = self.yv[2];
self.yv[2] = (self.xv[2] - self.xv[0])
+ (-0.7960060012f32 * self.yv[0])
+ (1.7903124146f32 * self.yv[1]);
self.yv[2]
}
}
// 10Hz Single Pole Lowpass IIR Filter
#[derive(Default)]
pub struct EnvelopeFilter {
xv: [f32; 2],
yv: [f32; 2],
}
impl ZTransformFilter for EnvelopeFilter {
fn process(&mut self, sample: f32) -> f32 {
self.xv[0] = self.xv[1];
self.xv[1] = sample / 50.0f32;
self.yv[0] = self.yv[1];
self.yv[1] = (self.xv[0] + self.xv[1]) + (0.9875119299f32 * self.yv[0]);
self.yv[1]
}
}
// 1.7 - 3.0Hz Single Pole Bandpass IIR Filter
#[derive(Default)]
pub struct BeatFilter {
xv: [f32; 3],
yv: [f32; 3],
}
impl ZTransformFilter for BeatFilter {
fn process(&mut self, sample: f32) -> f32 {
self.xv[0] = self.xv[1];
self.xv[1] = self.xv[2];
self.xv[2] = sample / 2.7f32;
self.yv[0] = self.yv[1];
self.yv[1] = self.yv[2];
self.yv[2] = (self.xv[2] - self.xv[0])
+ (-0.7169861741f32 * self.yv[0])
+ (1.4453653501f32 * self.yv[1]);
self.yv[2]
}
}

View File

@ -0,0 +1,141 @@
use anyhow::{anyhow, Result};
use ringbuffer::{ConstGenericRingBuffer, RingBufferExt, RingBufferWrite};
use std::{
sync::{Arc, Mutex},
thread::{self, JoinHandle},
time::Instant,
};
use pulse::sample::{Format, Spec};
use self::dsp::ZTransformFilter;
mod capture;
mod dsp;
const SAMPLE_RATE: usize = 5000;
const PULSE_UPDATES_PER_SECOND: usize = 50;
const BUFFER_SIZE: usize = SAMPLE_RATE / PULSE_UPDATES_PER_SECOND;
const POINTS_PER_SECOND: usize = SAMPLE_RATE / 200;
pub const MILLIS_PER_POINT: usize = 1000 / POINTS_PER_SECOND;
const POINT_MIN_COUNT: usize = 2 * POINTS_PER_SECOND;
const POINT_BUFFER_SIZE: usize = POINT_MIN_COUNT.next_power_of_two();
pub struct AudioCaptureThread {
join_handle: Option<JoinHandle<Result<()>>>,
shared_state: Arc<Mutex<SharedState>>,
}
impl AudioCaptureThread {
pub fn new() -> Self {
let shared_state = Arc::new(Mutex::new(SharedState::new()));
let join_handle = {
let shared_state = shared_state.clone();
Some(thread::spawn(move || audio_capture_thread(shared_state)))
};
Self {
shared_state,
join_handle,
}
}
pub fn stop(&mut self) -> Result<()> {
if self.join_handle.is_none() {
return Err(anyhow!(
"join_handle was none, was stop called multiple times?"
));
} else {
let handle = self.join_handle.take().unwrap();
handle.join().unwrap()
}
}
pub fn read_state(&self) -> (Vec<f32>, Instant) {
let state = self.shared_state.lock().unwrap();
(state.point_buf.to_vec(), state.last_update)
}
pub fn get_points(&self) -> Vec<f32> {
let state = self.shared_state.lock().unwrap();
state.point_buf.to_vec()
}
pub fn get_last_update(&self) -> Instant {
let state = self.shared_state.lock().unwrap();
state.last_update
}
}
struct SharedState {
running: bool,
point_buf: ConstGenericRingBuffer<f32, POINT_BUFFER_SIZE>,
last_update: Instant,
}
impl SharedState {
pub fn new() -> Self {
let mut point_buf: ConstGenericRingBuffer<f32, POINT_BUFFER_SIZE> = Default::default();
point_buf.fill_default();
Self {
running: true,
point_buf,
last_update: Instant::now(),
}
}
}
fn audio_capture_thread(state: Arc<Mutex<SharedState>>) -> Result<()> {
let spec = Spec {
format: Format::F32le,
rate: SAMPLE_RATE as u32,
channels: 1,
};
let reader = capture::get_audio_reader(&spec)?;
let mut buffer = [0u8; 4 * BUFFER_SIZE];
let mut bass_filter = dsp::BassFilter::default();
let mut envelope_filter = dsp::EnvelopeFilter::default();
let mut beat_filter = dsp::BeatFilter::default();
let mut j = 0;
loop {
{
if state.lock().unwrap().running == false {
break Ok(());
}
}
reader.read(&mut buffer)?;
for i in 0..BUFFER_SIZE {
let mut float_bytes = [0u8; 4];
float_bytes.copy_from_slice(&buffer[4 * i..4 * i + 4]);
j += 1;
let sample = f32::from_le_bytes(float_bytes);
let mut value = bass_filter.process(sample);
if value < 0f32 {
value = -value;
}
let envelope = envelope_filter.process(value);
if j == 200 {
let beat = beat_filter.process(envelope);
let mut state = state.lock().unwrap();
state.point_buf.push(beat);
state.last_update = Instant::now();
j = 0;
}
}
}
}

View File

@ -1,4 +1,7 @@
use std::time::{Duration, Instant}; use std::{
ops::RangeInclusive,
time::{Duration, Instant},
};
pub struct Metronome { pub struct Metronome {
taps: Vec<Instant>, taps: Vec<Instant>,
@ -34,14 +37,35 @@ impl Metronome {
} }
} }
pub fn current_beat_progress(&self) -> Option<f64> { pub fn get_beat_progress(&self, now: Instant) -> Option<f64> {
if self.beat_interval.is_none() { if self.beat_interval.is_none() {
return None; return None;
} }
let now = Instant::now();
let relative_millis = (now - *self.taps.last().unwrap()).as_millis(); let relative_millis = (now - *self.taps.last().unwrap()).as_millis();
Some(relative_millis as f64 / self.beat_interval.unwrap() as f64) Some(relative_millis as f64 / self.beat_interval.unwrap() as f64)
} }
pub fn get_beats(&self, from: Instant, to: Instant) -> Vec<(u64, Instant)> {
if self.beat_interval.is_none() {
return Vec::new();
}
let beat_interval = self.beat_interval.unwrap();
let last = *self.taps.last().unwrap();
let relative_from = (from - last).as_millis() as f64 / beat_interval as f64;
let relative_to = (to - last).as_millis() as f64 / beat_interval as f64;
(relative_from.ceil() as u64..=relative_to.floor() as u64)
.map(|beat_number| {
(
beat_number,
last + Duration::from_millis((beat_number as u128 * beat_interval) as u64),
)
})
.collect()
}
} }

View File

@ -1,19 +1,21 @@
mod audio;
mod metronome; mod metronome;
mod tracker;
use std::{cell::RefCell, time::Duration}; use std::{cell::RefCell, time::Duration};
use metronome::Metronome; use tracker::BeatTracker;
use neon::prelude::*; use neon::prelude::*;
type BoxedTracker = JsBox<RefCell<Metronome>>; type BoxedTracker = JsBox<RefCell<BeatTracker>>;
impl Finalize for Metronome {} impl Finalize for BeatTracker {}
pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult<JsObject> { pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult<JsObject> {
let obj = cx.empty_object(); let obj = cx.empty_object();
let value_obj = cx.empty_object(); let value_obj = cx.empty_object();
let boxed_tracker = cx.boxed(RefCell::new(Metronome::new(Duration::from_secs(2)))); let boxed_tracker = cx.boxed(RefCell::new(BeatTracker::new(Duration::from_secs(2))));
value_obj.set(&mut cx, "_rust_ptr", boxed_tracker)?; value_obj.set(&mut cx, "_rust_ptr", boxed_tracker)?;
let tap_function = JsFunction::new(&mut cx, tap)?; let tap_function = JsFunction::new(&mut cx, tap)?;
@ -22,6 +24,12 @@ pub fn get_beat_tracker(mut cx: FunctionContext) -> JsResult<JsObject> {
let get_progress_function = JsFunction::new(&mut cx, get_progress)?; let get_progress_function = JsFunction::new(&mut cx, get_progress)?;
value_obj.set(&mut cx, "getProgress", get_progress_function)?; value_obj.set(&mut cx, "getProgress", get_progress_function)?;
let close_function = JsFunction::new(&mut cx, close)?;
value_obj.set(&mut cx, "close", close_function)?;
let get_graph_points_function = JsFunction::new(&mut cx, get_graph_points)?;
value_obj.set(&mut cx, "getGraphPoints", get_graph_points_function)?;
obj.set(&mut cx, "value", value_obj)?; obj.set(&mut cx, "value", value_obj)?;
let success_string = cx.string("success".to_string()); let success_string = cx.string("success".to_string());
@ -68,3 +76,57 @@ pub fn get_progress(mut cx: FunctionContext) -> JsResult<JsObject> {
Ok(obj) Ok(obj)
} }
fn get_graph_points(mut cx: FunctionContext) -> JsResult<JsObject> {
let this = cx.this();
let boxed_tracker = this
.get(&mut cx, "_rust_ptr")?
.downcast_or_throw::<BoxedTracker, _>(&mut cx)?;
let obj = cx.empty_object();
let (bass_filtered, auto_correlated) = boxed_tracker.borrow().get_graph_points();
let arr_bf = cx.empty_array();
for (i, f) in bass_filtered.iter().enumerate() {
let num = cx.number(*f);
arr_bf.set(&mut cx, i as u32, num)?;
}
obj.set(&mut cx, "bassFiltered", arr_bf)?;
let arr_ac = cx.empty_array();
for (i, f) in auto_correlated.iter().enumerate() {
let num = cx.number(*f);
arr_ac.set(&mut cx, i as u32, num)?;
}
obj.set(&mut cx, "autoCorrelated", arr_ac)?;
Ok(obj)
}
pub fn close(mut cx: FunctionContext) -> JsResult<JsObject> {
let this = cx.this();
let boxed_tracker = this
.get(&mut cx, "_rust_ptr")?
.downcast_or_throw::<BoxedTracker, _>(&mut cx)?;
let obj = cx.empty_object();
let success_string;
match boxed_tracker.borrow_mut().stop() {
Ok(()) => {
success_string = cx.string("success");
}
Err(e) => {
success_string = cx.string("error");
let error_message = cx.string(e.to_string());
obj.set(&mut cx, "message", error_message)?;
}
}
obj.set(&mut cx, "type", success_string)?;
Ok(obj)
}

View File

@ -0,0 +1,167 @@
use anyhow::Result;
use num::Complex;
use rustfft::FftPlanner;
use std::time::{Duration, Instant};
use super::audio::{AudioCaptureThread, MILLIS_PER_POINT};
use super::metronome::Metronome;
pub struct TrackerConfig {
ac_threshold: f32,
zero_crossing_beat_delay: i64,
}
pub struct BeatTracker {
metronome: Metronome,
config: TrackerConfig,
audio_capture_thread: AudioCaptureThread,
}
impl BeatTracker {
pub fn new(metronome_timeout: Duration) -> Self {
Self {
metronome: Metronome::new(metronome_timeout),
config: TrackerConfig {
ac_threshold: 1000.0,
zero_crossing_beat_delay: 0,
},
audio_capture_thread: AudioCaptureThread::new(),
}
}
pub fn tap(&mut self) {
self.metronome.tap();
}
pub fn get_graph_points(&self) -> (Vec<f32>, Vec<f32>) {
let points = self.audio_capture_thread.get_points();
let mut autocorrelation = points
.iter()
.map(|point| Complex {
re: *point,
im: 0.0,
})
.collect::<Vec<_>>();
let n = autocorrelation.len();
for _ in 0..n {
autocorrelation.push(Complex { re: 0.0, im: 0.0 });
}
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(2 * n);
let ifft = planner.plan_fft_inverse(2 * n);
fft.process(&mut autocorrelation);
for f in autocorrelation.iter_mut() {
*f = Complex {
re: f.norm_sqr(),
im: 0.0,
};
}
ifft.process(&mut autocorrelation);
let mut autocorrelation: Vec<f32> = autocorrelation[..n].iter().map(|x| x.re).collect();
// remove first peak
let mut j = 0;
let i = loop {
j += 1;
if j == n {
break None;
}
if autocorrelation[j - 1] <= autocorrelation[j] {
break Some(j);
}
};
if let Some(i) = i {
let first_increase = autocorrelation[i];
for f in autocorrelation[..i].iter_mut() {
*f = first_increase;
}
}
(points, autocorrelation)
}
pub fn current_beat_progress(&self) -> Option<f64> {
let (points, autocorrelation) = self.get_graph_points();
if autocorrelation
.iter()
.reduce(|a, b| if a > b { a } else { b })
.unwrap()
< &self.config.ac_threshold
{
return None;
}
let mut j = 0;
let i = loop {
j += 1;
if j == autocorrelation.len() {
break None;
}
if autocorrelation[j - 1] > autocorrelation[j] {
break Some(j);
}
};
if i.is_none() {
return None;
}
let period_length = i.unwrap();
// println!("predicted period length: {}", period_length);
// find zero crossing
let mut crossing = None;
for i in 1..points.len() {
if points[i-1].signum() < 0.0 && points[i].signum() > 0.0 {
crossing = Some(i);
}
}
if crossing.is_none() {
return None;
}
let crossing = crossing.unwrap();
// println!("index of last positive zero crossing: {}", crossing);
let last_update_timestamp = self.audio_capture_thread.get_last_update();
let now = Instant::now();
let mut dt = (now - last_update_timestamp).as_millis() as i64;
dt += ((points.len() - crossing) * MILLIS_PER_POINT) as i64;
dt += self.config.zero_crossing_beat_delay;
let dt = dt as f64;
let period_millis = (period_length * MILLIS_PER_POINT) as f64;
let relative_time = (dt % period_millis) / period_millis;
println!("relative time: {}", relative_time);
Some(relative_time)
}
pub fn stop(&mut self) -> Result<()> {
self.audio_capture_thread.stop()
}
}