SOUSAphone

A lightweight two-stage pipeline that classifies all 40 PAS International Drum Rudiments from audio. Stage 1 infers rich per-stroke features from raw onset detection output; Stage 2 classifies the rudiment from those features.

Trained on the SOUSA dataset. Try it live at the demo space.

Pipeline Overview

Audio
  │
  ▼
Onset Detection (librosa)
  │  onset_times, onset_strengths, tempo_bpm
  ▼
Feature Inference Model (3 → 12 dims)
  │  Predicts stroke type, sticking, grace notes, etc.
  ▼
Onset Transformer (12 dims → 40 classes)
  │  Classifies rudiment
  ▼
Predicted Rudiment + Confidence

Both models are Transformer encoders sharing the same architecture template (~120K parameters each, <500 KB per weight file).

Models

Feature Inference Model

Bridges the gap between raw onset detection (timing + strength) and the classifier's rich feature space (stroke types, sticking, grace notes, etc.).

Component            Details
Input projection     Linear(3 → 64)
Positional encoding  Learnable Embedding(256, 64)
Encoder              3x TransformerEncoderLayer(d=64, nhead=4, ffn=128)
Output projection    LayerNorm(64) → Linear(64, 12)
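The component table above maps directly onto stock PyTorch modules. A minimal sketch of the shared encoder template, assuming standard defaults (the class name `EncoderTemplate` is hypothetical; the released code may organize this differently):

```python
import torch
import torch.nn as nn

class EncoderTemplate(nn.Module):
    # Hypothetical sketch of the shared architecture template described above.
    def __init__(self, in_dim=3, out_dim=12, d_model=64, nhead=4,
                 num_layers=3, dim_feedforward=128, max_seq_len=256):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, out_dim)

    def forward(self, x):
        # x: (batch, seq_len, in_dim)
        positions = torch.arange(x.size(1), device=x.device)
        h = self.input_proj(x) + self.pos_emb(positions)
        h = self.encoder(h)
        return self.output_proj(self.norm(h))
```

With these hyperparameters the parameter count lands in the ~120K range quoted above; swapping `in_dim`/`out_dim` to 12/40 and adding masked mean pooling yields the classifier variant.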

Input (3 dimensions per stroke):

Idx  Feature         Description
0    ioi_ms          Inter-onset interval in milliseconds
1    onset_strength  Onset strength (0-1 normalized)
2    tempo_bpm       Estimated tempo in BPM

Output (12 dimensions per stroke):

Idx  Feature            Type        Description
0    norm_ioi           continuous  Inter-onset interval / beat duration
1    norm_velocity      continuous  MIDI velocity / 127
2    is_grace           binary      Grace note flag
3    is_accent          binary      Accent flag
4    is_tap             binary      Tap stroke type
5    is_diddle          binary      Diddle stroke type
6    hand_R             binary      Right hand = 1, Left = 0
7    diddle_pos         continuous  Position within diddle pair
8    norm_flam_spacing  continuous  Flam spacing / beat duration
9    position_in_beat   continuous  Metric position within beat
10   is_buzz            binary      Buzz stroke type
11   norm_buzz_count    continuous  Buzz sub-stroke count / 8

The binary outputs (indices 2, 3, 4, 5, 6, 10) are emitted as raw logits and should be passed through a sigmoid during post-processing.
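A minimal sketch of that post-processing step (the 0.5 threshold is an assumption; keeping the raw sigmoid probabilities is equally valid):

```python
import numpy as np

# Indices of the channels documented as binary in the table above
BINARY_IDX = [2, 3, 4, 5, 6, 10]

def postprocess_features(raw, threshold=0.5):
    # Sketch: squash the binary channels with a sigmoid, then threshold.
    feats = np.asarray(raw, dtype=np.float64).copy()
    probs = 1.0 / (1.0 + np.exp(-feats[..., BINARY_IDX]))
    feats[..., BINARY_IDX] = (probs >= threshold).astype(feats.dtype)
    return feats
```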

Files: feature_inference_model.bin, feature_inference_config.json

Onset Transformer (Classifier)

Classifies a sequence of 12-dim per-stroke features into one of 40 PAS rudiments.

Component            Details
Input projection     Linear(12 → 64)
Positional encoding  Learnable Embedding(256, 64)
Encoder              3x TransformerEncoderLayer(d=64, nhead=4, ffn=128)
Pooling              Mean pooling with attention mask
Classifier           LayerNorm(64) → Linear(64, 40)
Total parameters     120,360

Files: pytorch_model.bin, onset_transformer_config.json

Training

  • Dataset: SOUSA, 100,000 synthetic drum rudiment performances across 100 player profiles
  • Split: 68 train / 13 val / 19 test profiles (profile-based to prevent data leakage)
  • Optimizer: AdamW (lr=1e-3, weight_decay=0.01)
  • Scheduler: Cosine with warmup (10% warmup ratio)
  • Label smoothing: 0.1
  • Precision: 16-bit mixed
  • Result: 100% test accuracy (macro F1 = 1.0 across all 40 classes)
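The recipe above can be sketched with stock PyTorch. A hedged sketch, assuming a linear warmup shape (the helper name `build_optimizer` and the exact schedule implementation are assumptions):

```python
import math
import torch

def build_optimizer(model, total_steps, lr=1e-3, weight_decay=0.01, warmup_ratio=0.1):
    # Sketch of the stated recipe: AdamW + cosine schedule with 10% warmup.
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    warmup_steps = int(total_steps * warmup_ratio)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    # Classification loss with the stated label smoothing
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    return opt, sched, criterion
```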

Usage

Full Pipeline (audio β†’ rudiment)

import librosa
from sousa.inference.pipeline import RudimentPipeline

# Load pipeline with both models
pipeline = RudimentPipeline(
    feature_model_path="feature_inference_model.bin",
    classifier_model_path="pytorch_model.bin",
)

# Load audio
audio, sr = librosa.load("rudiment.wav", sr=22050)

# Predict
result = pipeline.predict(audio, sr=sr)
print(f"Rudiment: {result['predicted_rudiment']}")
print(f"Confidence: {result['confidence']:.1%}")
print(f"Tempo: {result['tempo_bpm']:.0f} BPM")
print(f"Top 5: {result['top5']}")

The pipeline returns:

Key                 Type        Description
predicted_rudiment  str         Top-1 rudiment name
confidence          float       Top-1 probability
top5                list[dict]  Top 5 predictions with confidence
onset_times         ndarray     Detected onset times (seconds)
onset_strengths     ndarray     Onset strengths (0-1)
tempo_bpm           float       Estimated tempo
predicted_features  ndarray     Inferred 12-dim features per stroke
attention_mask      ndarray     Stroke attention mask

Classifier Only (pre-extracted features)

import torch
import json
from sousa.models.onset_transformer import OnsetTransformerModel

# Load model
model = OnsetTransformerModel(
    num_classes=40, feature_dim=12, d_model=64,
    nhead=4, num_layers=3, dim_feedforward=128,
    dropout=0.0, max_seq_len=256,
)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Load label mapping
with open("id2label.json") as f:
    id2label = json.load(f)

# Classify a stroke sequence (batch=1, seq_len=N, features=12)
onset_features = torch.randn(1, 50, 12)  # your encoded strokes
attention_mask = torch.ones(1, 50)        # 1 for real strokes, 0 for padding

with torch.no_grad():
    logits = model(onset_features, attention_mask)
    pred = logits.argmax(dim=-1).item()
    print(f"Predicted rudiment: {id2label[str(pred)]}")
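To mirror the pipeline's ranked output from raw logits, a softmax plus `torch.topk` suffices. A sketch (the helper name is hypothetical):

```python
import torch

def top_k_predictions(logits, id2label, k=5):
    # Sketch: turn (1, num_classes) logits into ranked (label, probability) pairs.
    probs = torch.softmax(logits, dim=-1)
    top = torch.topk(probs, k, dim=-1)
    return [
        (id2label[str(i)], float(p))
        for i, p in zip(top.indices[0].tolist(), top.values[0].tolist())
    ]
```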

Repository Contents

File                           Description
pytorch_model.bin              Onset Transformer classifier weights (495 KB)
feature_inference_model.bin    Feature Inference Model weights (486 KB)
onset_transformer_config.json  Classifier architecture config
feature_inference_config.json  Feature inference architecture config
config.json                    HuggingFace model config
id2label.json                  Class ID → rudiment name mapping

Supported Rudiments (40 classes)

All 40 PAS International Drum Rudiments:

ID  Rudiment                  ID  Rudiment
0   double-drag-tap           20  nine-stroke-roll
1   double-paradiddle         21  pataflafla
2   double-ratamacue          22  seven-stroke-roll
3   double-stroke-open-roll   23  seventeen-stroke-roll
4   drag                      24  single-drag-tap
5   drag-paradiddle-1         25  single-dragadiddle
6   drag-paradiddle-2         26  single-flammed-mill
7   eleven-stroke-roll        27  single-paradiddle
8   fifteen-stroke-roll       28  single-paradiddle-diddle
9   five-stroke-roll          29  single-ratamacue
10  flam                      30  single-stroke-four
11  flam-accent               31  single-stroke-roll
12  flam-drag                 32  single-stroke-seven
13  flam-paradiddle           33  six-stroke-roll
14  flam-paradiddle-diddle    34  swiss-army-triplet
15  flam-tap                  35  ten-stroke-roll
16  flamacue                  36  thirteen-stroke-roll
17  inverted-flam-tap         37  triple-paradiddle
18  lesson-25                 38  triple-ratamacue
19  multiple-bounce-roll      39  triple-stroke-roll

Limitations

  • The feature inference model predicts stroke-level attributes (sticking, grace notes, etc.) from timing and strength alone. Accuracy of these inferred features on real-world audio depends on onset detection quality.
  • Trained entirely on synthetic SOUSA data. Real-world generalization is untested.
  • Maximum sequence length of 256 strokes per input.
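To respect the 256-stroke window, inputs must be padded or truncated along with a matching attention mask. A hedged sketch (the helper name is an assumption; the pipeline may handle this internally):

```python
import numpy as np

MAX_LEN = 256  # model's maximum sequence length, per the limitation above

def pad_or_truncate(features, max_len=MAX_LEN):
    # Sketch: fit a (n_strokes, 12) array into the fixed-length window,
    # returning the padded features plus a 1/0 attention mask.
    n = min(len(features), max_len)
    out = np.zeros((max_len, features.shape[1]), dtype=features.dtype)
    mask = np.zeros(max_len, dtype=np.float32)
    out[:n] = features[:n]
    mask[:n] = 1.0
    return out, mask
```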

Citation

@misc{sousaphone,
  title={SOUSAphone: Onset Transformer for Drum Rudiment Classification},
  author={Zak Keown},
  year={2026},
  url={https://huggingface.co/zkeown/sousaphone}
}