# SOUSAphone
A lightweight two-stage pipeline that classifies all 40 PAS International Drum Rudiments from audio. Stage 1 infers rich per-stroke features from raw onset detection output; Stage 2 classifies the rudiment from those features.
Trained on the SOUSA dataset. Try it live at the demo space.
## Pipeline Overview

```
Audio
  │
  ▼
Onset Detection (librosa)
  │  onset_times, onset_strengths, tempo_bpm
  ▼
Feature Inference Model (3 → 12 dims)
  │  predicts stroke type, sticking, grace notes, etc.
  ▼
Onset Transformer (12 dims → 40 classes)
  │  classifies rudiment
  ▼
Predicted Rudiment + Confidence
```
Both models are Transformer encoders sharing the same architecture template (~120K parameters each, <500 KB per weight file).
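The shared architecture template and its parameter count can be sanity-checked with a short PyTorch sketch. The module layout below is illustrative (built from the component tables in the Models section), not the actual `sousa` implementation:

```python
import torch.nn as nn

def build_encoder(in_dim, out_dim, d=64, nhead=4, ffn=128, layers=3, max_len=256):
    """Shared template: input projection, learned positions, encoder, output head."""
    layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead,
                                       dim_feedforward=ffn, batch_first=True)
    return nn.ModuleDict({
        "input_proj": nn.Linear(in_dim, d),
        "pos_emb": nn.Embedding(max_len, d),
        "encoder": nn.TransformerEncoder(layer, num_layers=layers),
        "head": nn.Sequential(nn.LayerNorm(d), nn.Linear(d, out_dim)),
    })

def count_params(module):
    return sum(p.numel() for p in module.parameters())

print(count_params(build_encoder(3, 12)))   # 117964 (feature inference model)
print(count_params(build_encoder(12, 40)))  # 120360 (classifier)
```

The classifier count matches the 120,360 figure reported below; the two models differ only in their input and output projections.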
## Models

### Feature Inference Model
Bridges the gap between raw onset detection (timing + strength) and the classifier's rich feature space (stroke types, sticking, grace notes, etc.).
| Component | Details |
|---|---|
| Input projection | Linear(3 → 64) |
| Positional encoding | Learnable Embedding(256, 64) |
| Encoder | 3 × TransformerEncoderLayer(d=64, nhead=4, ffn=128) |
| Output projection | LayerNorm(64) → Linear(64, 12) |
Input (3 dimensions per stroke):

| Idx | Feature | Description |
|---|---|---|
| 0 | `ioi_ms` | Inter-onset interval in milliseconds |
| 1 | `onset_strength` | Onset strength (0-1 normalized) |
| 2 | `tempo_bpm` | Estimated tempo in BPM |
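Packing the onset-detection output into this 3-dim layout is straightforward; here is a minimal sketch (the helper name and exact normalization are assumptions, not the actual `sousa` extraction code):

```python
import numpy as np

def build_onset_inputs(onset_times, onset_strengths, tempo_bpm):
    """Pack onset-detection output into the model's (N, 3) input:
    [ioi_ms, onset_strength, tempo_bpm] per stroke."""
    onset_times = np.asarray(onset_times, dtype=np.float64)
    # Inter-onset intervals in milliseconds; the first stroke has no
    # predecessor, so its IOI is 0.
    ioi_ms = np.diff(onset_times, prepend=onset_times[0]) * 1000.0
    # Normalize strengths to [0, 1]
    strengths = np.asarray(onset_strengths, dtype=np.float64)
    strengths = strengths / (strengths.max() + 1e-8)
    # Broadcast the single tempo estimate to every stroke
    tempo = np.full_like(ioi_ms, tempo_bpm)
    return np.stack([ioi_ms, strengths, tempo], axis=-1)

feats = build_onset_inputs([0.0, 0.25, 0.5, 0.75], [2.0, 1.0, 2.0, 1.0], 120.0)
print(feats.shape)  # (4, 3)
```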
Output (12 dimensions per stroke):

| Idx | Feature | Type | Description |
|---|---|---|---|
| 0 | `norm_ioi` | continuous | Inter-onset interval / beat duration |
| 1 | `norm_velocity` | continuous | MIDI velocity / 127 |
| 2 | `is_grace` | binary | Grace note flag |
| 3 | `is_accent` | binary | Accent flag |
| 4 | `is_tap` | binary | Tap stroke type |
| 5 | `is_diddle` | binary | Diddle stroke type |
| 6 | `hand_R` | binary | Right hand = 1, left = 0 |
| 7 | `diddle_pos` | continuous | Position within diddle pair |
| 8 | `norm_flam_spacing` | continuous | Flam spacing / beat duration |
| 9 | `position_in_beat` | continuous | Metric position within beat |
| 10 | `is_buzz` | binary | Buzz stroke type |
| 11 | `norm_buzz_count` | continuous | Buzz sub-stroke count / 8 |
The binary outputs (indices 2, 3, 4, 5, 6, and 10) are emitted as raw scores and should be passed through a sigmoid (and thresholded) in post-processing.
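A minimal post-processing step might look like this (the index list comes from the table above; the helper name and 0.5 threshold are assumptions, not the pipeline's own code):

```python
import numpy as np

# Binary feature dims: is_grace, is_accent, is_tap, is_diddle, hand_R, is_buzz
BINARY_IDX = [2, 3, 4, 5, 6, 10]

def postprocess_features(raw):
    """Apply sigmoid + 0.5 threshold to the binary dims of (N, 12) outputs.

    Continuous dims pass through unchanged."""
    out = raw.astype(np.float64)
    probs = 1.0 / (1.0 + np.exp(-out[:, BINARY_IDX]))
    out[:, BINARY_IDX] = (probs >= 0.5).astype(np.float64)
    return out

raw = np.zeros((2, 12))
raw[0, 2] = 3.0   # strong grace-note score
raw[1, 2] = -3.0  # strong non-grace score
clean = postprocess_features(raw)
print(clean[:, 2])  # [1. 0.]
```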
Files: `feature_inference_model.bin`, `feature_inference_config.json`
### Onset Transformer (Classifier)
Classifies a sequence of 12-dim per-stroke features into one of 40 PAS rudiments.
| Component | Details |
|---|---|
| Input projection | Linear(12 → 64) |
| Positional encoding | Learnable Embedding(256, 64) |
| Encoder | 3 × TransformerEncoderLayer(d=64, nhead=4, ffn=128) |
| Pooling | Mean pooling with attention mask |
| Classifier | LayerNorm(64) → Linear(64, 40) |
| Total parameters | 120,360 |
Files: `pytorch_model.bin`, `onset_transformer_config.json`
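The masked mean pooling step averages encoder outputs over real strokes only, so padded positions never influence the classification. A minimal sketch of the idea (not necessarily the exact `sousa` implementation):

```python
import torch

def masked_mean_pool(hidden, attention_mask):
    """Average (B, T, D) encoder outputs over unmasked positions.

    attention_mask: (B, T) with 1 for real strokes, 0 for padding."""
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1)
    summed = (hidden * mask).sum(dim=1)                   # (B, D)
    counts = mask.sum(dim=1).clamp(min=1.0)               # avoid divide-by-zero
    return summed / counts

hidden = torch.ones(1, 4, 64)
hidden[0, 2:] = 5.0  # padded positions hold junk values
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
pooled = masked_mean_pool(hidden, mask)
print(pooled[0, 0].item())  # 1.0 — the junk in padded positions is ignored
```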
## Training

- Dataset: SOUSA (100,000 synthetic drum rudiment performances across 100 player profiles)
- Split: 68 train / 13 val / 19 test profiles (profile-based to prevent data leakage)
- Optimizer: AdamW (lr=1e-3, weight_decay=0.01)
- Scheduler: Cosine with warmup (10% warmup ratio)
- Label smoothing: 0.1
- Precision: 16-bit mixed
- Result: 100% test accuracy (macro F1 = 1.0 across all 40 classes)
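The optimizer and scheduler settings above correspond to a setup like the following sketch. The stand-in model, `total_steps`, and the hand-rolled warmup lambda are illustrative assumptions, not the actual training script:

```python
import math
import torch

model = torch.nn.Linear(12, 40)  # stand-in for the classifier
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

total_steps = 1000
warmup_steps = int(0.10 * total_steps)  # 10% warmup ratio

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)             # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
```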
## Usage

### Full Pipeline (audio → rudiment)
```python
import librosa
from sousa.inference.pipeline import RudimentPipeline

# Load pipeline with both models
pipeline = RudimentPipeline(
    feature_model_path="feature_inference_model.bin",
    classifier_model_path="pytorch_model.bin",
)

# Load audio
audio, sr = librosa.load("rudiment.wav", sr=22050)

# Predict
result = pipeline.predict(audio, sr=sr)
print(f"Rudiment: {result['predicted_rudiment']}")
print(f"Confidence: {result['confidence']:.1%}")
print(f"Tempo: {result['tempo_bpm']:.0f} BPM")
print(f"Top 5: {result['top5']}")
```
The pipeline returns:

| Key | Type | Description |
|---|---|---|
| `predicted_rudiment` | str | Top-1 rudiment name |
| `confidence` | float | Top-1 probability |
| `top5` | list[dict] | Top 5 predictions with confidences |
| `onset_times` | ndarray | Detected onset times (seconds) |
| `onset_strengths` | ndarray | Onset strengths (0-1) |
| `tempo_bpm` | float | Estimated tempo |
| `predicted_features` | ndarray | Inferred 12-dim features per stroke |
| `attention_mask` | ndarray | Stroke attention mask |
### Classifier Only (pre-extracted features)
```python
import json
import torch
from sousa.models.onset_transformer import OnsetTransformerModel

# Load model
model = OnsetTransformerModel(
    num_classes=40, feature_dim=12, d_model=64,
    nhead=4, num_layers=3, dim_feedforward=128,
    dropout=0.0, max_seq_len=256,
)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Load label mapping
with open("id2label.json") as f:
    id2label = json.load(f)

# Classify a stroke sequence (batch=1, seq_len=N, features=12)
onset_features = torch.randn(1, 50, 12)  # your encoded strokes
attention_mask = torch.ones(1, 50)       # 1 for real strokes, 0 for padding

with torch.no_grad():
    logits = model(onset_features, attention_mask)

pred = logits.argmax(dim=-1).item()
print(f"Predicted rudiment: {id2label[str(pred)]}")
```
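To mirror the pipeline's `top5` output when using the classifier directly, the logits can be converted to probabilities with a softmax. A self-contained sketch with stand-in logits:

```python
import torch

# Stand-in logits for one sequence over the 40 rudiment classes;
# in practice these come from the model's forward pass.
logits = torch.randn(1, 40)

probs = torch.softmax(logits, dim=-1)      # (1, 40), sums to 1
top5 = torch.topk(probs, k=5, dim=-1)      # highest-probability classes
for rank, (p, idx) in enumerate(zip(top5.values[0], top5.indices[0]), 1):
    print(f"{rank}. class {idx.item()}: {p.item():.1%}")
```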
## Repository Contents
| File | Description |
|---|---|
| `pytorch_model.bin` | Onset Transformer classifier weights (495 KB) |
| `feature_inference_model.bin` | Feature Inference Model weights (486 KB) |
| `onset_transformer_config.json` | Classifier architecture config |
| `feature_inference_config.json` | Feature inference architecture config |
| `config.json` | Hugging Face model config |
| `id2label.json` | Class ID → rudiment name mapping |
## Supported Rudiments (40 classes)
All 40 PAS International Drum Rudiments:
| ID | Rudiment | ID | Rudiment |
|---|---|---|---|
| 0 | double-drag-tap | 20 | nine-stroke-roll |
| 1 | double-paradiddle | 21 | pataflafla |
| 2 | double-ratamacue | 22 | seven-stroke-roll |
| 3 | double-stroke-open-roll | 23 | seventeen-stroke-roll |
| 4 | drag | 24 | single-drag-tap |
| 5 | drag-paradiddle-1 | 25 | single-dragadiddle |
| 6 | drag-paradiddle-2 | 26 | single-flammed-mill |
| 7 | eleven-stroke-roll | 27 | single-paradiddle |
| 8 | fifteen-stroke-roll | 28 | single-paradiddle-diddle |
| 9 | five-stroke-roll | 29 | single-ratamacue |
| 10 | flam | 30 | single-stroke-four |
| 11 | flam-accent | 31 | single-stroke-roll |
| 12 | flam-drag | 32 | single-stroke-seven |
| 13 | flam-paradiddle | 33 | six-stroke-roll |
| 14 | flam-paradiddle-diddle | 34 | swiss-army-triplet |
| 15 | flam-tap | 35 | ten-stroke-roll |
| 16 | flamacue | 36 | thirteen-stroke-roll |
| 17 | inverted-flam-tap | 37 | triple-paradiddle |
| 18 | lesson-25 | 38 | triple-ratamacue |
| 19 | multiple-bounce-roll | 39 | triple-stroke-roll |
## Limitations
- The feature inference model predicts stroke-level attributes (sticking, grace notes, etc.) from timing and strength alone. Accuracy of these inferred features on real-world audio depends on onset detection quality.
- Trained entirely on synthetic SOUSA data. Real-world generalization is untested.
- Maximum sequence length of 256 strokes per input.
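Given the 256-stroke limit, longer recordings must be truncated and shorter ones padded before classification. A sketch of the preprocessing (the helper name and zero-padding scheme are assumptions; the pipeline handles this internally):

```python
import numpy as np

MAX_SEQ_LEN = 256

def pad_or_truncate(features, max_len=MAX_SEQ_LEN):
    """Pad (N, 12) stroke features with zeros, or truncate, to max_len rows.

    Returns the fixed-size features plus a matching attention mask
    (1 for real strokes, 0 for padding)."""
    n = min(len(features), max_len)
    out = np.zeros((max_len, features.shape[1]), dtype=np.float32)
    out[:n] = features[:n]
    mask = np.zeros(max_len, dtype=np.float32)
    mask[:n] = 1.0
    return out, mask

feats, mask = pad_or_truncate(np.ones((50, 12), dtype=np.float32))
print(feats.shape, int(mask.sum()))  # (256, 12) 50
```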
## Citation

```bibtex
@misc{sousaphone,
  title={SOUSAphone: Onset Transformer for Drum Rudiment Classification},
  author={Zak Keown},
  year={2026},
  url={https://huggingface.co/zkeown/sousaphone}
}
```