AmrYassinIsFree committed on
Commit
db0da0a
·
1 Parent(s): a9bc1f8

replace matplot with plotly, add more evals, UI re-org

Files changed (7)
  1. README.md +70 -86
  2. app.py +483 -159
  3. corpus.py +6 -2
  4. dataset_config.py +3 -1
  5. evals/llm_judge.py +194 -0
  6. evals/quality.py +59 -20
  7. requirements.txt +1 -1
README.md CHANGED
@@ -12,22 +12,18 @@ license: mit
12
 
13
  # embedding-bench
14
 
15
- Compare text embedding models across retrieval quality, inference speed, and memory footprint. Everything runs locally, with no external API calls.
16
 
17
- ## Models
18
 
19
- | Key | Model | Backend | Role |
20
- |-----|-------|---------|------|
21
- | `mpnet` | `sentence-transformers/all-mpnet-base-v2` | sbert | Baseline |
22
- | `bge-small` | `BAAI/bge-small-en-v1.5` | sbert | |
23
- | `bge-small-fe` | `BAAI/bge-small-en-v1.5` | fastembed | |
24
- | `all-minilm-fe` | `sentence-transformers/all-MiniLM-L6-v2` | fastembed | |
25
-
26
- Three backends are supported:
27
-
28
- - **sbert** — [sentence-transformers](https://www.sbert.net/) (PyTorch). Default.
29
- - **fastembed** — [qdrant/fastembed](https://github.com/qdrant/fastembed) (ONNX Runtime). Lighter and often faster.
30
- - **gguf** — [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) for quantised GGUF models.
31
 
32
  ## Setup
33
 
@@ -37,9 +33,40 @@ source .venv/bin/activate
37
  pip install -r requirements.txt
38
  ```
39
 
40
- ## Usage
41
 
42
- ### Basic
43
 
44
  ```bash
45
  # Full benchmark (quality + speed + memory)
@@ -48,78 +75,39 @@ python bench.py
48
  # Specific models
49
  python bench.py --models mpnet bge-small
50
 
51
- # Compare the same model across backends
52
  python bench.py --models bge-small bge-small-fe
53
 
54
  # Skip expensive evals
55
  python bench.py --skip-quality
56
  python bench.py --skip-memory
57
 
58
- # Tune corpus size and batch size
59
- python bench.py --corpus-size 500 --batch-size 32 --num-runs 5
60
- ```
61
-
62
- ### Datasets
63
-
64
- By default, quality is evaluated on the STS Benchmark. You can evaluate on multiple HuggingFace datasets using built-in presets:
65
-
66
- | Preset | HF Dataset | Type | Pairs |
67
- |--------|-----------|------|-------|
68
- | `sts` | `mteb/stsbenchmark-sts` | Scored (Spearman) | 1,379 |
69
- | `natural-questions` | `sentence-transformers/natural-questions` | Retrieval (MRR/Recall) | 100,231 |
70
- | `msmarco` | `sentence-transformers/msmarco-bm25` | Retrieval | 503,000 |
71
- | `squad` | `sentence-transformers/squad` | Retrieval | 87,599 |
72
- | `trivia-qa` | `sentence-transformers/trivia-qa` | Retrieval | 73,346 |
73
- | `gooaq` | `sentence-transformers/gooaq` | Retrieval | 3,012,496 |
74
- | `hotpotqa` | `sentence-transformers/hotpotqa` | Retrieval | 84,500 |
75
-
76
- ```bash
77
- # Evaluate on multiple datasets
78
  python bench.py --models mpnet bge-small \
79
  --datasets sts natural-questions squad \
80
- --skip-speed --skip-memory
81
-
82
- # Limit pairs for large datasets
83
- python bench.py --datasets msmarco gooaq --max-pairs 1000
84
 
85
- # Use a custom HF dataset (overrides --datasets)
86
  python bench.py --dataset my-org/my-pairs \
87
  --query-col query --passage-col passage --score-col none
88
- ```
89
-
90
- Scored datasets (with `--score-col`) report **Spearman correlation**. Pair-only datasets (`--score-col none`) report **MRR**, **Recall@1**, **Recall@5**, and **Recall@10**.
91
-
92
- ### Export results
93
-
94
- ```bash
95
- # Export to CSV
96
- python bench.py --csv results.csv
97
-
98
- # Save charts as PNG
99
- python bench.py --charts ./results
100
 
101
- # Both
102
- python bench.py --models mpnet bge-small \
103
- --datasets sts squad natural-questions \
104
- --max-pairs 1000 \
105
- --csv results.csv --charts ./results
106
  ```
107
 
108
- Charts generated:
109
- - `quality_<dataset>.png` — Spearman bar chart (scored) or grouped MRR/Recall bars (retrieval)
110
- - `speed.png` — sentences/second comparison
111
- - `memory.png` — peak memory usage comparison
112
-
113
- ## Metrics
114
 
115
- | Dimension | Metric | Method |
116
- |-----------|--------|--------|
117
- | Quality (scored) | Spearman rho | Cosine similarity vs gold scores |
118
- | Quality (pairs) | MRR, Recall@k | Retrieval ranking of positive passages |
119
- | Speed | Median encode time | Wall-clock over N runs with warmup |
120
- | Memory | Peak RSS delta | Isolated subprocess via `psutil` |
 
 
 
121
 
122
- ## CLI reference
123
 
124
  ```
125
  --models Models to benchmark (default: all)
@@ -144,36 +132,32 @@ Charts generated:
144
 
145
  ## Adding a model
146
 
147
- Edit `models.py` and add an entry to `REGISTRY`:
 
 
148
 
149
  ```python
150
- # sentence-transformers backend (default)
151
  "e5-small": ModelConfig(
152
  name="e5-small-v2",
153
  model_id="intfloat/e5-small-v2",
154
  ),
155
-
156
- # fastembed backend
157
- "e5-small-fe": ModelConfig(
158
- name="e5-small-v2 (fastembed)",
159
- model_id="intfloat/e5-small-v2",
160
- backend="fastembed",
161
- ),
162
  ```
163
 
164
  ## Project structure
165
 
166
  ```
167
  embedding-bench/
 
168
  ├── bench.py # CLI entry point
169
- ├── models.py # Model registry
170
- ├── wrapper.py # Backend wrappers (sbert, fastembed, gguf)
171
  ├── corpus.py # Sentence corpus builder
172
  ├── dataset_config.py # Dataset presets and configuration
173
- ├── report.py # Table formatting, CSV export, charts
174
  ├── evals/
175
- │ ├── quality.py # STS + retrieval evaluation
176
  │ ├── speed.py # Latency measurement
177
- │ └── memory.py # Memory measurement
 
178
  └── requirements.txt
179
  ```
 
12
 
13
  # embedding-bench
14
 
15
+ Compare text embedding models on quality, speed, and memory. Includes a Streamlit web UI and a CLI.
16
 
17
+ ## Features
18
 
19
+ - **40+ pre-configured models**: sentence-transformers, BGE, E5, GTE, Nomic, Jina, Arctic, and more
20
+ - **4 backends** — sbert (PyTorch), fastembed (ONNX), gguf (llama-cpp), libembedding
21
+ - **7 built-in datasets**: STS Benchmark, Natural Questions, MS MARCO, SQuAD, TriviaQA, GooAQ, HotpotQA
22
+ - **Custom datasets** — upload your own CSV/TSV or load any HuggingFace dataset
23
+ - **Custom models**: add any HuggingFace embedding model from the UI
24
+ - **11 retrieval metrics**: MRR, MAP@k, NDCG@k, Precision@k, Recall@k (all configurable)
25
+ - **LLM as a Judge** — use OpenAI or Anthropic to rate retrieval relevance
26
+ - **Interactive charts** — Plotly-powered, with hover, zoom, and PNG export
 
 
 
 
27
 
28
  ## Setup
29
 
 
33
  pip install -r requirements.txt
34
  ```
35
 
36
+ ## Web UI
37
+
38
+ ```bash
39
+ streamlit run app.py
40
+ ```
41
+
42
+ The sidebar has three sections:
43
+
44
+ 1. **Models** — select from the registry or add a custom HuggingFace model
45
+ 2. **Datasets** — pick built-in presets, upload a CSV/TSV, or add any HuggingFace dataset
46
+ 3. **Evaluation** — configure metrics, speed/memory benchmarks, LLM judge, and max pairs
47
+
48
+ ### Custom datasets
49
+
50
+ You can add datasets two ways from the sidebar:
51
 
52
+ - **Upload file** — CSV or TSV (max 50 MB, 50k rows) with a query column and a passage column. Optionally include a numeric score column for Spearman correlation; otherwise retrieval metrics (MRR, Recall@k, etc.) are used.
53
+ - **HuggingFace Hub** — provide the dataset ID (e.g. `mteb/stsbenchmark-sts`), config, split, and column names. The dataset is validated on add.
54
+
55
+ ### LLM as a Judge
56
+
57
+ Enable in the Evaluation section. Provide your OpenAI or Anthropic API key. For each sampled query, the top-5 retrieved passages are rated for relevance (1–5) by the LLM. Reports judge_avg@1, judge_avg@5, and judge_nDCG@5.
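The three judge metrics named above can be illustrated per query with a small sketch. This is not the code in `evals/llm_judge.py`, which is not shown in this view: the function name `judge_aggregates` is hypothetical, and the nDCG formula (linear gain, log2 discount, ideal order taken from the same five ratings) is an assumption.

```python
import math


def judge_aggregates(ratings: list[int]) -> dict[str, float]:
    """Aggregate LLM relevance ratings (1-5) for one query.

    `ratings` lists the judge's scores for the retrieved passages in rank
    order (best-ranked passage first). Only the top 5 are used.
    """
    if not ratings:
        raise ValueError("need at least one rating")
    top = ratings[:5]

    def dcg(scores: list[int]) -> float:
        # Assumed formula: linear gain with a log2 rank discount.
        return sum(s / math.log2(rank + 2) for rank, s in enumerate(scores))

    # Assumed ideal ordering: the same ratings sorted from best to worst.
    ideal = dcg(sorted(top, reverse=True))
    return {
        "judge_avg@1": float(top[0]),
        "judge_avg@5": sum(top) / len(top),
        "judge_nDCG@5": dcg(top) / ideal,
    }


print(judge_aggregates([5, 3, 4, 1, 2]))
```

A benchmark run would presumably average these per-query values over all sampled queries before reporting them.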
58
+
59
+ ### Metrics
60
+
61
+ | Dimension | Metrics | Method |
62
+ |-----------|---------|--------|
63
+ | Quality (scored) | Spearman | Cosine similarity vs gold scores |
64
+ | Quality (pairs) | MRR, MAP@5/10, NDCG@5/10, Precision@1/5/10, Recall@1/5/10 | Retrieval ranking of positive passages |
65
+ | LLM Judge | Avg@1, Avg@5, nDCG@5 | LLM relevance ratings on retrieved passages |
66
+ | Speed | Median encode time, sent/s | Wall-clock over N runs with warmup |
67
+ | Memory | Peak RSS delta (MB) | Isolated subprocess via `psutil` |
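As a rough illustration of the "Quality (pairs)" row, the sketch below computes MRR and Recall@k from a query-by-passage similarity matrix. It assumes that query `i` has exactly one positive passage, passage `i`, and that ties are broken in the positive's favour. The function names are hypothetical and this is not the implementation in `evals/quality.py`.

```python
def positive_ranks(sims: list[list[float]]) -> list[int]:
    """Return the 1-based rank of each query's positive passage.

    sims[i][j] is the similarity of query i to passage j; the positive
    passage for query i is assumed to be passage i.
    """
    ranks = []
    for i, row in enumerate(sims):
        # Rank = 1 + number of passages scored strictly higher than the positive.
        ranks.append(1 + sum(score > row[i] for score in row))
    return ranks


def mrr_and_recall(ranks: list[int], ks: tuple[int, ...] = (1, 5, 10)) -> dict[str, float]:
    """Mean reciprocal rank and Recall@k over all queries."""
    n = len(ranks)
    metrics = {"mrr": sum(1.0 / r for r in ranks) / n}
    for k in ks:
        metrics[f"recall@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


sims = [
    [0.9, 0.1, 0.2],  # positive (passage 0) ranked 1st
    [0.8, 0.5, 0.1],  # positive (passage 1) ranked 2nd
    [0.3, 0.7, 0.6],  # positive (passage 2) ranked 2nd
]
print(mrr_and_recall(positive_ranks(sims)))
```

With a single positive per query, Recall@k reduces to the fraction of queries whose positive lands in the top k, which is why it can be computed from ranks alone.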
68
+
69
+ ## CLI
70
 
71
  ```bash
72
  # Full benchmark (quality + speed + memory)
 
75
  # Specific models
76
  python bench.py --models mpnet bge-small
77
 
78
+ # Compare backends
79
  python bench.py --models bge-small bge-small-fe
80
 
81
  # Skip expensive evals
82
  python bench.py --skip-quality
83
  python bench.py --skip-memory
84
 
85
+ # Multiple datasets with pair limit
86
  python bench.py --models mpnet bge-small \
87
  --datasets sts natural-questions squad \
88
+ --max-pairs 1000 --skip-speed --skip-memory
 
 
 
89
 
90
+ # Custom HF dataset
91
  python bench.py --dataset my-org/my-pairs \
92
  --query-col query --passage-col passage --score-col none
93
 
94
+ # Export
95
+ python bench.py --csv results.csv --charts ./results
 
 
 
96
  ```
97
 
98
+ ### Built-in dataset presets
 
 
 
 
 
99
 
100
+ | Preset | HF Dataset | Type |
101
+ |--------|-----------|------|
102
+ | `sts` | `mteb/stsbenchmark-sts` | Scored (Spearman) |
103
+ | `natural-questions` | `sentence-transformers/natural-questions` | Retrieval |
104
+ | `msmarco` | `sentence-transformers/msmarco-bm25` | Retrieval |
105
+ | `squad` | `sentence-transformers/squad` | Retrieval |
106
+ | `trivia-qa` | `sentence-transformers/trivia-qa` | Retrieval |
107
+ | `gooaq` | `sentence-transformers/gooaq` | Retrieval |
108
+ | `hotpotqa` | `sentence-transformers/hotpotqa` | Retrieval |
109
 
110
+ ### CLI flags
111
 
112
  ```
113
  --models Models to benchmark (default: all)
 
132
 
133
  ## Adding a model
134
 
135
+ From the web UI, click **Add Custom Model** in the sidebar — just provide a display name and a HuggingFace model ID.
136
+
137
+ Or edit `models.py` directly:
138
 
139
  ```python
 
140
  "e5-small": ModelConfig(
141
  name="e5-small-v2",
142
  model_id="intfloat/e5-small-v2",
143
  ),
 
 
 
 
 
 
 
144
  ```
145
 
146
  ## Project structure
147
 
148
  ```
149
  embedding-bench/
150
+ ├── app.py # Streamlit web UI
151
  ├── bench.py # CLI entry point
152
+ ├── models.py # Model registry (40+ models)
153
+ ├── wrapper.py # Backend wrappers (sbert, fastembed, gguf, libembedding)
154
  ├── corpus.py # Sentence corpus builder
155
  ├── dataset_config.py # Dataset presets and configuration
156
+ ├── report.py # Table formatting, CSV export, charts (CLI)
157
  ├── evals/
158
+ │ ├── quality.py # Quality evaluation (Spearman + retrieval metrics)
159
  │ ├── speed.py # Latency measurement
160
+ │ ├── memory.py # Memory measurement
161
+ │ └── llm_judge.py # LLM-as-a-Judge evaluation
162
  └── requirements.txt
163
  ```
app.py CHANGED
@@ -2,17 +2,19 @@ from __future__ import annotations
2
 
3
  import io
4
  import csv
 
5
  import time
6
 
7
- import matplotlib.pyplot as plt
8
  import numpy as np
 
 
9
  import streamlit as st
10
 
11
  from datasets import load_dataset
12
 
13
  from corpus import build_corpus
14
  from dataset_config import DATASET_PRESETS, DatasetConfig
15
- from evals.quality import evaluate_quality
16
  from evals.speed import evaluate_speed
17
  from models import (
18
  REGISTRY,
@@ -109,11 +111,22 @@ with col_badge:
109
 
110
  st.markdown("<hr class='section-divider'>", unsafe_allow_html=True)
111
 
 
 
 
 
 
 
 
 
 
 
112
  # ---------------------------------------------------------------------------
113
  # Sidebar — configuration
114
  # ---------------------------------------------------------------------------
115
  st.sidebar.markdown("### ⚙️ Configuration")
116
 
 
117
  st.sidebar.markdown("**Models**")
118
  available_models = list(REGISTRY.keys())
119
  selected_models = st.sidebar.multiselect(
@@ -125,22 +138,34 @@ selected_models = st.sidebar.multiselect(
125
 
126
  with st.sidebar.expander("➕ Add Custom Model"):
127
  with st.form("add_model_form", clear_on_submit=True):
128
- new_key = st.text_input("Registry key", placeholder="my-model")
129
  new_name = st.text_input("Display name", placeholder="My Custom Model")
130
  new_model_id = st.text_input("HuggingFace model ID", placeholder="org/model-name")
131
  new_backend = st.selectbox("Backend", sorted(VALID_BACKENDS))
132
  new_gguf_file = st.text_input(
133
- "GGUF filename (gguf backend only)", value="", placeholder="model.gguf"
 
134
  )
135
- new_is_baseline = st.checkbox("Mark as baseline", value=False)
136
- new_persist = st.checkbox("Save to disk", value=False,
137
- help="Persist to custom_models.json so it loads next session")
 
138
  submitted = st.form_submit_button("Add Model", use_container_width=True)
139
  if submitted:
140
- if not new_key or not new_name or not new_model_id:
141
- st.sidebar.error("Key, name, and model ID are required.")
142
- elif new_backend == "gguf" and not new_gguf_file:
143
- st.sidebar.error("GGUF filename is required for gguf backend.")
 
 
 
 
 
 
 
 
 
 
 
144
  else:
145
  cfg = ModelConfig(
146
  name=new_name,
@@ -157,58 +182,280 @@ with st.sidebar.expander("➕ Add Custom Model"):
157
  except ValueError as e:
158
  st.sidebar.error(str(e))
159
 
 
160
  st.sidebar.markdown("**Datasets**")
161
- available_datasets = list(DATASET_PRESETS.keys())
 
 
 
 
 
162
  selected_datasets = st.sidebar.multiselect(
163
- "Select dataset presets",
164
  available_datasets,
165
- default=["sts"],
166
  label_visibility="collapsed",
167
  )
168
 
169
- max_pairs = st.sidebar.number_input(
170
- "Max pairs per dataset",
171
- min_value=100,
172
- max_value=50000,
173
- value=1000,
174
- step=100,
175
- help="Limits the number of pairs evaluated. Keep low for large datasets.",
176
- )
177
 
178
- st.sidebar.markdown("---")
179
- st.sidebar.markdown("**Speed & Memory**")
180
- run_speed = st.sidebar.checkbox("Speed benchmark", value=False)
181
- run_memory = st.sidebar.checkbox("Memory benchmark", value=False)
182
-
183
- corpus_size = 500
184
- num_runs = 3
185
- batch_size = 64
186
- if run_speed or run_memory:
187
- corpus_size = st.sidebar.number_input("Corpus size", 100, 10000, 500, step=100)
188
- batch_size = st.sidebar.number_input("Batch size", 8, 512, 64, step=8)
189
- if run_speed:
190
- num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
191
-
192
- st.sidebar.markdown("---")
193
- st.sidebar.markdown("**Cache**")
194
- _cache_c1, _cache_c2 = st.sidebar.columns(2)
195
- with _cache_c1:
196
- if st.button("🗑️ Clear All", use_container_width=True,
197
- help="Clear cached models, datasets, and results"):
198
- st.cache_resource.clear()
199
- st.cache_data.clear()
200
- for key in list(st.session_state.keys()):
201
- del st.session_state[key]
202
- st.rerun()
203
- with _cache_c2:
204
- if st.button("🔄 Results", use_container_width=True,
205
- help="Clear eval results but keep models loaded"):
206
- st.cache_data.clear()
207
- for key in ["results", "selected_datasets"]:
208
- st.session_state.pop(key, None)
209
- st.rerun()
210
-
211
- st.sidebar.markdown("---")
212
 
213
  # ---------------------------------------------------------------------------
214
  # Cached functions
@@ -239,8 +486,9 @@ def cached_evaluate_quality(
239
  score_col: str | None,
240
  score_scale: float,
241
  max_pairs: int | None,
 
242
  ) -> dict[str, float]:
243
- """Cache quality results keyed by (model, dataset, max_pairs).
244
 
245
  The _model arg is excluded from the hash (underscore prefix).
246
  model_key is used as a hashable stand-in.
@@ -250,7 +498,10 @@ def cached_evaluate_quality(
250
  query_col=query_col, passage_col=passage_col,
251
  score_col=score_col, score_scale=score_scale,
252
  )
253
- return evaluate_quality(_model, ds_cfg, max_pairs=max_pairs)
 
 
 
254
 
255
 
256
  @st.cache_data(show_spinner="Building corpus...", ttl=3600)
@@ -274,6 +525,9 @@ def flatten_result(r: dict) -> dict:
274
  for ds_key, metrics in r.get("quality", {}).items():
275
  for metric_name, value in metrics.items():
276
  flat[f"{ds_key}/{metric_name}"] = value
 
 
 
277
  speed = r.get("speed")
278
  if speed:
279
  flat["Speed (sent/s)"] = speed["sentences_per_second"]
@@ -311,23 +565,19 @@ def render_metric_card(label: str, value: str, sub: str = "", css_class: str = "
311
 
312
 
313
  # ---------------------------------------------------------------------------
314
- # Chart style helper
315
  # ---------------------------------------------------------------------------
316
  CHART_BG = "#0E1117"
317
- CHART_TEXT = "#CCCCCC"
318
-
319
- def style_chart(fig, ax):
320
- """Apply dark theme to a matplotlib chart."""
321
- fig.patch.set_facecolor(CHART_BG)
322
- ax.set_facecolor(CHART_BG)
323
- ax.spines["top"].set_visible(False)
324
- ax.spines["right"].set_visible(False)
325
- ax.spines["left"].set_color("#444")
326
- ax.spines["bottom"].set_color("#444")
327
- ax.tick_params(colors=CHART_TEXT, labelsize=7)
328
- ax.yaxis.label.set_color(CHART_TEXT)
329
- ax.xaxis.label.set_color(CHART_TEXT)
330
- ax.title.set_color("#FAFAFA")
331
 
332
 
333
  # ---------------------------------------------------------------------------
@@ -341,13 +591,20 @@ if not selected_datasets:
341
  st.warning("Select at least one dataset from the sidebar.")
342
  st.stop()
343
 
344
- run_btn = st.sidebar.button("🚀 Run Benchmark", type="primary", use_container_width=True)
 
 
 
 
345
 
346
  if run_btn:
347
- ds_configs = [DATASET_PRESETS[k] for k in selected_datasets]
348
  results = []
349
  progress = st.progress(0, text="Starting...")
350
- total_steps = len(selected_models) * (len(ds_configs) + int(run_speed) + int(run_memory))
 
 
 
351
  step = 0
352
 
353
  for model_key in selected_models:
@@ -363,23 +620,58 @@ if run_btn:
363
  step / total_steps,
364
  text=f"Evaluating **{cfg.name}** on *{ds_key}*...",
365
  )
366
- quality_results[ds_key] = cached_evaluate_quality(
367
- model, model_key,
368
- ds_cfg.name, ds_cfg.config, ds_cfg.split,
369
- ds_cfg.query_col, ds_cfg.passage_col,
370
- ds_cfg.score_col, ds_cfg.score_scale,
371
- max_pairs,
372
- )
 
 
 
 
 
 
 
373
  result["quality"] = quality_results
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  if run_speed:
376
  step += 1
377
  progress.progress(step / total_steps, text=f"Speed benchmark: **{cfg.name}**...")
378
  ds0 = ds_configs[0]
379
- corpus = cached_build_corpus(
380
- corpus_size, ds0.name, ds0.config, ds0.split,
381
- ds0.query_col, ds0.passage_col,
382
- )
 
 
 
383
  result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
384
 
385
  if run_memory:
@@ -387,10 +679,13 @@ if run_btn:
387
  progress.progress(step / total_steps, text=f"Memory benchmark: **{cfg.name}**...")
388
  from evals.memory import evaluate_memory
389
  ds0 = ds_configs[0]
390
- corpus = cached_build_corpus(
391
- corpus_size, ds0.name, ds0.config, ds0.split,
392
- ds0.query_col, ds0.passage_col,
393
- )
 
 
 
394
  result["memory_mb"] = evaluate_memory(
395
  cfg.model_id, corpus, batch_size=batch_size, backend=cfg.backend,
396
  )
@@ -412,7 +707,7 @@ if "results" not in st.session_state:
412
  "<div style='text-align:center; padding:3rem 0; color:#666;'>"
413
  "<p style='font-size:2.5rem; margin-bottom:0.5rem;'>📐</p>"
414
  "<p style='font-size:1.1rem;'>Configure models &amp; datasets in the sidebar,<br>"
415
- "then hit <b>Run Benchmark</b>.</p></div>",
416
  unsafe_allow_html=True,
417
  )
418
  st.stop()
@@ -434,8 +729,13 @@ for r in results:
434
  if ds_keys:
435
  first_ds = ds_keys[0]
436
  first_metrics_sample = results[0].get("quality", {}).get(first_ds, {})
437
- primary_metric = "spearman" if "spearman" in first_metrics_sample else "mrr"
438
- primary_label = "Spearman" if primary_metric == "spearman" else "MRR"
 
 
 
 
 
439
 
440
  scores = [
441
  (r["name"], r.get("quality", {}).get(first_ds, {}).get(primary_metric, 0))
@@ -524,47 +824,73 @@ for ds_key in ds_keys:
524
 
525
  if "spearman" in first_metrics:
526
  values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
527
- fig, ax = plt.subplots(figsize=(4, 2.4))
528
- style_chart(fig, ax)
529
- bars = ax.bar(models, values, color="#4C72B0", edgecolor="#5a82c0", linewidth=0.5)
530
- ax.set_ylabel("Spearman", fontsize=8)
531
- ax.set_title(f"Quality — {ds_key}", fontsize=9, pad=8)
532
- ax.set_ylim(0, 1.08)
533
- for bar, v in zip(bars, values):
534
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
535
- f"{v:.4f}", ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
536
- plt.xticks(rotation=30, ha="right")
537
- plt.tight_layout()
538
- st.pyplot(fig, use_container_width=False)
539
- plt.close(fig)
540
  else:
541
- metric_names = ["mrr", "recall@1", "recall@5", "recall@10"]
542
- x = np.arange(len(models))
543
- width = 0.18
544
- colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
545
-
546
- fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4), 3.0))
547
- style_chart(fig, ax)
548
- for i, (metric, color) in enumerate(zip(metric_names, colors)):
 
549
  values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
550
- offset = (i - 1.5) * width
551
- bars = ax.bar(x + offset, values, width, label=metric, color=color,
552
- edgecolor=color, linewidth=0.3, alpha=0.9)
553
- for bar, v in zip(bars, values):
554
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
555
- f"{v:.2f}", ha="center", va="bottom", fontsize=6, color=CHART_TEXT)
556
- ax.set_ylabel("Score", fontsize=8)
557
- ax.set_title(f"Retrieval Quality — {ds_key}", fontsize=9, pad=8)
558
- ax.set_ylim(0, 1.12)
559
- ax.set_xticks(x)
560
- ax.set_xticklabels(models, rotation=30, ha="right", fontsize=7)
561
- ax.legend(fontsize=6, ncol=4, loc="upper center",
562
- bbox_to_anchor=(0.5, -0.22),
563
- facecolor=CHART_BG, edgecolor="#444", labelcolor=CHART_TEXT)
564
- plt.tight_layout()
565
- fig.subplots_adjust(bottom=0.28)
566
- st.pyplot(fig, use_container_width=False)
567
- plt.close(fig)
568
 
569
  # Speed & Memory side by side
570
  speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
@@ -577,36 +903,34 @@ if has_speed or has_memory:
577
 
578
  if has_speed:
579
  with cols[0]:
580
- fig, ax = plt.subplots(figsize=(3.5, 2.4))
581
- style_chart(fig, ax)
582
- bars = ax.bar(models, speed_values, color="#55A868", edgecolor="#65b878", linewidth=0.5)
583
- ax.set_ylabel("Sent / s", fontsize=8)
584
- ax.set_title("Encoding Speed", fontsize=9, pad=8)
585
- for bar, v in zip(bars, speed_values):
586
- if v > 0:
587
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
588
- str(v), ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
589
- plt.xticks(rotation=30, ha="right")
590
- plt.tight_layout()
591
- st.pyplot(fig, use_container_width=False)
592
- plt.close(fig)
593
 
594
  if has_memory:
595
  col_idx = 1 if has_speed else 0
596
  with cols[col_idx]:
597
- fig, ax = plt.subplots(figsize=(3.5, 2.4))
598
- style_chart(fig, ax)
599
- bars = ax.bar(models, mem_values, color="#C44E52", edgecolor="#d45e62", linewidth=0.5)
600
- ax.set_ylabel("MB", fontsize=8)
601
- ax.set_title("Memory Usage", fontsize=9, pad=8)
602
- for bar, v in zip(bars, mem_values):
603
- if v > 0:
604
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
605
- str(v), ha="center", va="bottom", fontsize=7, color=CHART_TEXT)
606
- plt.xticks(rotation=30, ha="right")
607
- plt.tight_layout()
608
- st.pyplot(fig, use_container_width=False)
609
- plt.close(fig)
610
 
611
  # ---------------------------------------------------------------------------
612
  # Footer
 
2
 
3
  import io
4
  import csv
5
+ import re
6
  import time
7
 
 
8
  import numpy as np
9
+ import pandas as pd
10
+ import plotly.graph_objects as go
11
  import streamlit as st
12
 
13
  from datasets import load_dataset
14
 
15
  from corpus import build_corpus
16
  from dataset_config import DATASET_PRESETS, DatasetConfig
17
+ from evals.quality import ALL_RETRIEVAL_METRICS, DEFAULT_RETRIEVAL_METRICS, evaluate_quality
18
  from evals.speed import evaluate_speed
19
  from models import (
20
  REGISTRY,
 
111
 
112
  st.markdown("<hr class='section-divider'>", unsafe_allow_html=True)
113
 
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Helper: slugify a display name into a registry key
117
+ # ---------------------------------------------------------------------------
118
+ def _slugify(name: str) -> str:
119
+ s = name.strip().lower()
120
+ s = re.sub(r"[^a-z0-9]+", "-", s)
121
+ return s.strip("-")
122
+
123
+
124
  # ---------------------------------------------------------------------------
125
  # Sidebar — configuration
126
  # ---------------------------------------------------------------------------
127
  st.sidebar.markdown("### ⚙️ Configuration")
128
 
129
+ # ---- Models ---------------------------------------------------------------
130
  st.sidebar.markdown("**Models**")
131
  available_models = list(REGISTRY.keys())
132
  selected_models = st.sidebar.multiselect(
 
138
 
139
  with st.sidebar.expander("➕ Add Custom Model"):
140
  with st.form("add_model_form", clear_on_submit=True):
 
141
  new_name = st.text_input("Display name", placeholder="My Custom Model")
142
  new_model_id = st.text_input("HuggingFace model ID", placeholder="org/model-name")
143
  new_backend = st.selectbox("Backend", sorted(VALID_BACKENDS))
144
  new_gguf_file = st.text_input(
145
+ "GGUF filename", value="", placeholder="model.gguf",
146
+ help="Only needed for the gguf backend.",
147
  )
148
+ _adv_c1, _adv_c2 = st.columns(2)
149
+ new_is_baseline = _adv_c1.checkbox("Baseline", value=False)
150
+ new_persist = _adv_c2.checkbox("Save to disk", value=False,
151
+ help="Persist across sessions")
152
  submitted = st.form_submit_button("Add Model", use_container_width=True)
153
  if submitted:
154
+ new_key = _slugify(new_name) if new_name else ""
155
+ errors: list[str] = []
156
+ if not new_name:
157
+ errors.append("Display name is required.")
158
+ elif new_key in REGISTRY:
159
+ errors.append(f"A model named '{new_name}' already exists.")
160
+ if not new_model_id:
161
+ errors.append("HuggingFace model ID is required.")
162
+ elif "/" not in new_model_id:
163
+ errors.append("Model ID should be in `org/model-name` format.")
164
+ if new_backend == "gguf" and not new_gguf_file:
165
+ errors.append("GGUF filename is required for gguf backend.")
166
+ if errors:
167
+ for err in errors:
168
+ st.sidebar.error(err)
169
  else:
170
  cfg = ModelConfig(
171
  name=new_name,
 
182
  except ValueError as e:
183
  st.sidebar.error(str(e))
184
 
185
+ # ---- Datasets -------------------------------------------------------------
186
  st.sidebar.markdown("**Datasets**")
187
+
188
+ # Merge preset + user datasets (need this before the multiselect)
189
+ user_datasets: dict[str, DatasetConfig] = st.session_state.get("user_datasets", {})
190
+ all_datasets = {**DATASET_PRESETS, **user_datasets}
191
+
192
+ available_datasets = list(all_datasets.keys())
193
  selected_datasets = st.sidebar.multiselect(
194
+ "Select datasets",
195
  available_datasets,
196
+ default=["sts"] if "sts" in available_datasets else available_datasets[:1],
197
  label_visibility="collapsed",
198
  )
199
 
200
+ _MAX_UPLOAD_ROWS = 50_000
201
+ _MAX_UPLOAD_MB = 50
202
+
203
+ with st.sidebar.expander("➕ Add Dataset"):
204
+ ds_source = st.radio(
205
+ "Source", ["Upload file", "HuggingFace Hub"],
206
+ horizontal=True, label_visibility="collapsed",
207
+ )
208
+
209
+ if ds_source == "Upload file":
210
+ st.caption(
211
+ "CSV or TSV with query and passage columns. "
212
+ "Optional numeric score column enables Spearman correlation; "
213
+ "otherwise MRR & Recall@k are used. Max 50 MB / 50 k rows."
214
+ )
215
+ uploaded_file = st.file_uploader(
216
+ "Upload CSV or TSV", type=["csv", "tsv"], label_visibility="collapsed",
217
+ )
218
+ if uploaded_file is not None:
219
+ file_size_mb = uploaded_file.size / (1024 * 1024)
220
+ if file_size_mb > _MAX_UPLOAD_MB:
221
+ st.error(f"File too large ({file_size_mb:.1f} MB). Max {_MAX_UPLOAD_MB} MB.")
222
+ else:
223
+ sep = "\t" if uploaded_file.name.endswith(".tsv") else ","
224
+ try:
225
+ user_df = pd.read_csv(uploaded_file, sep=sep)
226
+ except Exception as e:
227
+ st.error(f"Failed to parse: {e}")
228
+ user_df = None
229
+
230
+ if user_df is not None:
231
+ errs: list[str] = []
232
+ if len(user_df.columns) < 2:
233
+ errs.append("Need at least 2 columns.")
234
+ if len(user_df) == 0:
235
+ errs.append("File is empty.")
236
+ if len(user_df) > _MAX_UPLOAD_ROWS:
237
+ errs.append(f"Too many rows ({len(user_df):,}). Max {_MAX_UPLOAD_ROWS:,}.")
238
+ if user_df.columns.duplicated().any():
239
+ errs.append("Duplicate column names.")
240
+ if errs:
241
+ for e in errs:
242
+ st.error(e)
243
+ else:
244
+ cols = list(user_df.columns)
245
+ st.dataframe(user_df.head(5), use_container_width=True, hide_index=True)
246
+
247
+ with st.form("add_dataset_form", clear_on_submit=False):
248
+ ds_label = st.text_input(
249
+ "Dataset name",
250
+ value=uploaded_file.name.rsplit(".", 1)[0],
251
+ )
252
+ user_query_col = st.selectbox("Query column", cols, index=0)
253
+ user_passage_col = st.selectbox(
254
+ "Passage column", cols, index=min(1, len(cols) - 1),
255
+ )
256
+ has_score = st.checkbox("Has score column")
257
+ user_score_col = st.selectbox(
258
+ "Score column", cols,
259
+ index=min(2, len(cols) - 1),
260
+ disabled=not has_score,
261
+ )
262
+ user_score_scale = st.number_input(
263
+ "Score scale (max value)",
264
+ min_value=1.0, value=5.0, step=1.0,
265
+ disabled=not has_score,
266
+ help="Scores divided by this to normalise to 0-1.",
267
+ )
268
+ ds_submitted = st.form_submit_button(
269
+ "Add Dataset", use_container_width=True,
270
+ )
271
+
272
+ if ds_submitted:
273
+ sub_errs: list[str] = []
274
+ if not ds_label:
275
+ sub_errs.append("Name is required.")
276
+ if user_query_col == user_passage_col:
277
+ sub_errs.append("Query and passage columns must differ.")
278
+ if has_score and user_score_col in (
279
+ user_query_col, user_passage_col,
280
+ ):
281
+ sub_errs.append("Score column must differ from query/passage.")
282
+ if user_df[user_query_col].astype(str).str.strip().eq("").all():
283
+ sub_errs.append(f"Query column '{user_query_col}' is empty.")
284
+ if user_df[user_passage_col].astype(str).str.strip().eq("").all():
285
+ sub_errs.append(f"Passage column '{user_passage_col}' is empty.")
286
+ if has_score:
287
+ try:
288
+ pd.to_numeric(user_df[user_score_col], errors="raise")
289
+ except (ValueError, TypeError):
290
+ sub_errs.append(f"Score column '{user_score_col}' must be numeric.")
291
+ if sub_errs:
292
+ for e in sub_errs:
293
+ st.error(e)
294
+ else:
295
+ data_dict = {c: user_df[c].astype(str).tolist() for c in cols}
296
+ if has_score:
297
+ data_dict[user_score_col] = [
298
+ float(v) for v in user_df[user_score_col]
299
+ ]
300
+ user_ds_cfg = DatasetConfig(
301
+ name=f"user/{ds_label}",
302
+ query_col=user_query_col,
303
+ passage_col=user_passage_col,
304
+ score_col=user_score_col if has_score else None,
305
+ score_scale=user_score_scale if has_score else 1.0,
306
+ data=data_dict,
307
+ )
308
+ if "user_datasets" not in st.session_state:
309
+ st.session_state["user_datasets"] = {}
310
+ st.session_state["user_datasets"][ds_label] = user_ds_cfg
311
+ st.success(f"Added **{ds_label}** ({len(user_df):,} rows)")
312
+
313
+ else: # HuggingFace Hub
314
+ st.caption("Load any dataset from [huggingface.co/datasets](https://huggingface.co/datasets).")
315
+ with st.form("add_hf_dataset_form", clear_on_submit=True):
316
+ hf_ds_label = st.text_input("Dataset name", placeholder="my-dataset")
317
+ hf_ds_id = st.text_input("HuggingFace ID", placeholder="org/dataset-name")
318
+ _hf_c1, _hf_c2 = st.columns(2)
319
+ hf_ds_config = _hf_c1.text_input("Config", value="", help="Leave blank if none.")
320
+ hf_ds_split = _hf_c2.text_input("Split", value="test")
321
+ hf_query_col = st.text_input("Query column", placeholder="query")
322
+ hf_passage_col = st.text_input("Passage column", placeholder="passage")
323
+ hf_has_score = st.checkbox("Has score column")
324
+ hf_score_col = st.text_input(
325
+ "Score column", placeholder="score", disabled=not hf_has_score,
326
+ )
327
+ hf_score_scale = st.number_input(
328
+ "Score scale (max value)", min_value=1.0, value=5.0, step=1.0,
329
+ disabled=not hf_has_score,
330
+ help="Scores are divided by this value to normalise them to 0-1.",
331
+ )
332
+ hf_submitted = st.form_submit_button("Add Dataset", use_container_width=True)
333
+ if hf_submitted:
334
+ hf_errors: list[str] = []
335
+ if not hf_ds_label:
336
+ hf_errors.append("Dataset name is required.")
337
+ if not hf_ds_id:
338
+ hf_errors.append("HuggingFace ID is required.")
339
+ if not hf_query_col:
340
+ hf_errors.append("Query column is required.")
341
+ if not hf_passage_col:
342
+ hf_errors.append("Passage column is required.")
343
+ if hf_query_col and hf_passage_col and hf_query_col == hf_passage_col:
344
+ hf_errors.append("Query and passage columns must differ.")
345
+ if hf_has_score and not hf_score_col:
346
+ hf_errors.append("Score column is required when enabled.")
347
+ if hf_has_score and hf_score_col in (hf_query_col, hf_passage_col):
348
+ hf_errors.append("Score column must differ from query/passage.")
349
+
350
+ if hf_errors:
351
+ for err in hf_errors:
352
+ st.error(err)
353
+ else:
354
+ try:
355
+ _cfg_arg = hf_ds_config or None
356
+ _test_ds = load_dataset(hf_ds_id, _cfg_arg, split=hf_ds_split)
357
+ _ds_cols = _test_ds.column_names
358
+ _missing = [
359
+ c for c in [hf_query_col, hf_passage_col]
360
+ + ([hf_score_col] if hf_has_score else [])
361
+ if c not in _ds_cols
362
+ ]
363
+ if _missing:
364
+ st.error(
365
+ f"Column(s) not found: {', '.join(_missing)}. "
366
+ f"Available: {', '.join(_ds_cols)}"
367
+ )
368
+ else:
369
+ hf_ds_cfg = DatasetConfig(
370
+ name=hf_ds_id,
371
+ config=_cfg_arg,
372
+ split=hf_ds_split,
373
+ query_col=hf_query_col,
374
+ passage_col=hf_passage_col,
375
+ score_col=hf_score_col if hf_has_score else None,
376
+ score_scale=hf_score_scale if hf_has_score else 1.0,
377
+ )
378
+ if "user_datasets" not in st.session_state:
379
+ st.session_state["user_datasets"] = {}
380
+ st.session_state["user_datasets"][hf_ds_label] = hf_ds_cfg
381
+ st.success(f"Added **{hf_ds_label}**")
382
+ st.rerun()
383
+ except Exception as e:
384
+ st.error(f"Failed to load: {e}")
385
+
386
+ # ---- Evaluation options ---------------------------------------------------
387
+ _LLM_PROVIDERS = {"openai": "OpenAI", "anthropic": "Anthropic"}
388
+ _DEFAULT_MODELS = {"openai": "gpt-4o-mini", "anthropic": "claude-haiku-4-5-20251001"}
389
+
390
+ with st.sidebar.expander("⚙️ Evaluation"):
391
+ max_pairs = st.number_input(
392
+ "Max pairs per dataset",
393
+ min_value=100, max_value=50000, value=1000, step=100,
394
+ help="Caps the number of pairs evaluated per dataset.",
395
+ )
396
+
397
+ selected_metrics = st.multiselect(
398
+ "Retrieval metrics",
399
+ ALL_RETRIEVAL_METRICS,
400
+ default=DEFAULT_RETRIEVAL_METRICS,
401
+ help="Metrics for pair-based datasets (no score column). Scored datasets always use Spearman.",
402
+ )
403
+
404
+ st.markdown("---")
405
+ run_speed = st.checkbox("Speed benchmark")
406
+ run_memory = st.checkbox("Memory benchmark")
407
+
408
+ corpus_size = 500
409
+ num_runs = 3
410
+ batch_size = 64
411
+ if run_speed or run_memory:
412
+ _sp_c1, _sp_c2 = st.columns(2)
413
+ corpus_size = _sp_c1.number_input("Corpus size", 100, 10000, 500, step=100)
414
+ batch_size = _sp_c2.number_input("Batch size", 8, 512, 64, step=8)
415
+ if run_speed:
416
+ num_runs = st.number_input("Speed runs", 1, 10, 3)
417
+
418
+ st.markdown("---")
419
+ run_llm_judge = st.checkbox("LLM as a Judge")
420
+
421
+ llm_provider = "openai"
422
+ llm_api_key = ""
423
+ llm_model = ""
424
+ llm_max_samples = 50
425
+
426
+ if run_llm_judge:
427
+ st.caption(
428
+ "An LLM rates how relevant retrieved passages are to each query (1-5). "
429
+ "API charges apply."
430
+ )
431
+ llm_provider = st.selectbox(
432
+ "Provider", list(_LLM_PROVIDERS.keys()),
433
+ format_func=lambda k: _LLM_PROVIDERS[k],
434
+ )
435
+ llm_api_key = st.text_input(
436
+ "API key", type="password", placeholder="sk-...",
437
+ )
438
+ llm_model = st.text_input("Model", value=_DEFAULT_MODELS[llm_provider])
439
+ llm_max_samples = st.number_input(
440
+ "Samples to judge", min_value=5, max_value=500, value=50, step=5,
441
+ help="Number of queries sampled. Each query costs 5 API calls (one per top-5 passage).",
442
+ )
443
 
444
+ st.markdown("---")
445
+ _cache_c1, _cache_c2 = st.columns(2)
446
+ with _cache_c1:
447
+ if st.button("🗑 Clear All", use_container_width=True):
448
+ st.cache_resource.clear()
449
+ st.cache_data.clear()
450
+ for key in list(st.session_state.keys()):
451
+ del st.session_state[key]
452
+ st.rerun()
453
+ with _cache_c2:
454
+ if st.button("🔄 Results", use_container_width=True):
455
+ st.cache_data.clear()
456
+ for key in ["results", "selected_datasets"]:
457
+ st.session_state.pop(key, None)
458
+ st.rerun()
459
 
460
  # ---------------------------------------------------------------------------
461
  # Cached functions
 
486
  score_col: str | None,
487
  score_scale: float,
488
  max_pairs: int | None,
489
+ metrics: tuple[str, ...] | None = None,
490
  ) -> dict[str, float]:
491
+ """Cache quality results keyed by (model, dataset, max_pairs, metrics).
492
 
493
  The _model arg is excluded from the hash (underscore prefix).
494
  model_key is used as a hashable stand-in.
 
498
  query_col=query_col, passage_col=passage_col,
499
  score_col=score_col, score_scale=score_scale,
500
  )
501
+ return evaluate_quality(
502
+ _model, ds_cfg, max_pairs=max_pairs,
503
+ metrics=list(metrics) if metrics else None,
504
+ )
505
 
506
 
507
  @st.cache_data(show_spinner="Building corpus...", ttl=3600)
 
525
  for ds_key, metrics in r.get("quality", {}).items():
526
  for metric_name, value in metrics.items():
527
  flat[f"{ds_key}/{metric_name}"] = value
528
+ for ds_key, metrics in r.get("llm_judge", {}).items():
529
+ for metric_name, value in metrics.items():
530
+ flat[f"{ds_key}/{metric_name}"] = value
531
  speed = r.get("speed")
532
  if speed:
533
  flat["Speed (sent/s)"] = speed["sentences_per_second"]
 
565
 
566
 
567
  # ---------------------------------------------------------------------------
568
+ # Chart helpers
569
  # ---------------------------------------------------------------------------
570
  CHART_BG = "#0E1117"
571
+
572
+ _PLOTLY_LAYOUT = dict(
573
+ paper_bgcolor=CHART_BG,
574
+ plot_bgcolor=CHART_BG,
575
+ font=dict(color="#CCCCCC", size=11),
576
+ margin=dict(l=50, r=20, t=40, b=60),
577
+ bargap=0.25,
578
+ xaxis=dict(gridcolor="#2a2d35", zerolinecolor="#2a2d35"),
579
+ yaxis=dict(gridcolor="#2a2d35", zerolinecolor="#2a2d35"),
580
+ )
 
 
 
 
581
 
582
 
583
  # ---------------------------------------------------------------------------
 
591
  st.warning("Select at least one dataset from the sidebar.")
592
  st.stop()
593
 
594
+ if run_llm_judge and not llm_api_key:
595
+ st.warning("Enter an API key in the sidebar to use LLM judge evaluation.")
596
+ run_llm_judge = False
597
+
598
+ run_btn = st.sidebar.button("🚀 Run", type="primary", use_container_width=True)
599
 
600
  if run_btn:
601
+ ds_configs = [all_datasets[k] for k in selected_datasets]
602
  results = []
603
  progress = st.progress(0, text="Starting...")
604
+ total_steps = len(selected_models) * (
605
+ len(ds_configs) + int(run_speed) + int(run_memory)
606
+ + (len(ds_configs) if run_llm_judge else 0)
607
+ )
608
  step = 0
609
 
610
  for model_key in selected_models:
 
620
  step / total_steps,
621
  text=f"Evaluating **{cfg.name}** on *{ds_key}*...",
622
  )
623
+ _metrics = selected_metrics or None
624
+ if ds_cfg.data is not None:
625
+ quality_results[ds_key] = evaluate_quality(
626
+ model, ds_cfg, max_pairs=max_pairs, metrics=_metrics,
627
+ )
628
+ else:
629
+ quality_results[ds_key] = cached_evaluate_quality(
630
+ model, model_key,
631
+ ds_cfg.name, ds_cfg.config, ds_cfg.split,
632
+ ds_cfg.query_col, ds_cfg.passage_col,
633
+ ds_cfg.score_col, ds_cfg.score_scale,
634
+ max_pairs,
635
+ metrics=tuple(_metrics) if _metrics else None,
636
+ )
637
  result["quality"] = quality_results
638
 
639
+ if run_llm_judge:
640
+ from evals.llm_judge import LLMJudgeConfig, evaluate_llm_judge
641
+ judge_cfg = LLMJudgeConfig(
642
+ provider=llm_provider,
643
+ api_key=llm_api_key,
644
+ model=llm_model,
645
+ max_samples=llm_max_samples,
646
+ )
647
+ judge_results = {}
648
+ for ds_cfg in ds_configs:
649
+ ds_key = ds_cfg.name.split("/")[-1]
650
+ step += 1
651
+ progress.progress(
652
+ step / total_steps,
653
+ text=f"LLM judge: **{cfg.name}** on *{ds_key}*...",
654
+ )
655
+ try:
656
+ judge_results[ds_key] = evaluate_llm_judge(
657
+ model, ds_cfg, judge_cfg, max_pairs=max_pairs,
658
+ )
659
+ except Exception as e:
660
+ st.warning(f"LLM judge failed for {cfg.name}/{ds_key}: {e}")
661
+ judge_results[ds_key] = {}
662
+ result["llm_judge"] = judge_results
663
+
664
  if run_speed:
665
  step += 1
666
  progress.progress(step / total_steps, text=f"Speed benchmark: **{cfg.name}**...")
667
  ds0 = ds_configs[0]
668
+ if ds0.data is not None:
669
+ corpus = build_corpus(corpus_size, ds0)
670
+ else:
671
+ corpus = cached_build_corpus(
672
+ corpus_size, ds0.name, ds0.config, ds0.split,
673
+ ds0.query_col, ds0.passage_col,
674
+ )
675
  result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
676
 
677
  if run_memory:
 
679
  progress.progress(step / total_steps, text=f"Memory benchmark: **{cfg.name}**...")
680
  from evals.memory import evaluate_memory
681
  ds0 = ds_configs[0]
682
+ if ds0.data is not None:
683
+ corpus = build_corpus(corpus_size, ds0)
684
+ else:
685
+ corpus = cached_build_corpus(
686
+ corpus_size, ds0.name, ds0.config, ds0.split,
687
+ ds0.query_col, ds0.passage_col,
688
+ )
689
  result["memory_mb"] = evaluate_memory(
690
  cfg.model_id, corpus, batch_size=batch_size, backend=cfg.backend,
691
  )
 
707
  "<div style='text-align:center; padding:3rem 0; color:#666;'>"
708
  "<p style='font-size:2.5rem; margin-bottom:0.5rem;'>📐</p>"
709
  "<p style='font-size:1.1rem;'>Configure models &amp; datasets in the sidebar,<br>"
710
+ "then hit <b>Run</b>.</p></div>",
711
  unsafe_allow_html=True,
712
  )
713
  st.stop()
 
729
  if ds_keys:
730
  first_ds = ds_keys[0]
731
  first_metrics_sample = results[0].get("quality", {}).get(first_ds, {})
732
+ if "spearman" in first_metrics_sample:
733
+ primary_metric = "spearman"
734
+ primary_label = "Spearman"
735
+ else:
736
+ # Use the first available retrieval metric
737
+ primary_metric = next(iter(first_metrics_sample), "mrr")
738
+ primary_label = primary_metric.upper()
739
 
740
  scores = [
741
  (r["name"], r.get("quality", {}).get(first_ds, {}).get(primary_metric, 0))
 
824
 
825
  if "spearman" in first_metrics:
826
  values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
827
+ fig = go.Figure(go.Bar(
828
+ x=models, y=values,
829
+ marker_color="#4C72B0",
830
+ text=[f"{v:.4f}" for v in values],
831
+ textposition="outside",
832
+ ))
833
+ fig.update_layout(
834
+ **_PLOTLY_LAYOUT,
835
+ title=f"Quality — {ds_key}",
836
+ yaxis_title="Spearman",
837
+ yaxis_range=[0, 1.08],
838
+ )
839
+ st.plotly_chart(fig, use_container_width=True)
840
  else:
841
+ metric_names = list(first_metrics.keys())
842
+ _palette = [
843
+ "#4C72B0", "#55A868", "#C44E52", "#8172B2",
844
+ "#E5AE38", "#DD8452", "#64B5CD", "#8C8C8C",
845
+ "#D4A6C8", "#6ACC65", "#D65F5F",
846
+ ]
847
+ fig = go.Figure()
848
+ for i, metric in enumerate(metric_names):
849
+ color = _palette[i % len(_palette)]
850
  values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
851
+ fig.add_trace(go.Bar(
852
+ name=metric, x=models, y=values,
853
+ marker_color=color,
854
+ text=[f"{v:.2f}" for v in values],
855
+ textposition="outside",
856
+ ))
857
+ fig.update_layout(
858
+ **_PLOTLY_LAYOUT,
859
+ title=f"Retrieval Quality — {ds_key}",
860
+ yaxis_title="Score",
861
+ yaxis_range=[0, 1.12],
862
+ barmode="group",
863
+ legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5),
864
+ )
865
+ st.plotly_chart(fig, use_container_width=True)
866
+
867
+ # LLM Judge charts
868
+ for ds_key in ds_keys:
869
+ has_judge = any(r.get("llm_judge", {}).get(ds_key) for r in results)
870
+ if not has_judge:
871
+ continue
872
+ judge_metrics = ["judge_avg@1", "judge_avg@5", "judge_ndcg@5"]
873
+ judge_labels = ["Avg@1", "Avg@5", "nDCG@5"]
874
+ colors = ["#E5AE38", "#DD8452", "#C44E52"]
875
+
876
+ fig = go.Figure()
877
+ for metric, label, color in zip(judge_metrics, judge_labels, colors):
878
+ values = [r.get("llm_judge", {}).get(ds_key, {}).get(metric, 0) for r in results]
879
+ fig.add_trace(go.Bar(
880
+ name=label, x=models, y=values,
881
+ marker_color=color,
882
+ text=[f"{v:.2f}" for v in values],
883
+ textposition="outside",
884
+ ))
885
+ fig.update_layout(
886
+ **_PLOTLY_LAYOUT,
887
+ title=f"LLM Judge — {ds_key}",
888
+ yaxis_title="Score",
889
+ yaxis_range=[0, 1.12],
890
+ barmode="group",
891
+ legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.5),
892
+ )
893
+ st.plotly_chart(fig, use_container_width=True)
894
 
895
  # Speed & Memory side by side
896
  speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
 
903
 
904
  if has_speed:
905
  with cols[0]:
906
+ fig = go.Figure(go.Bar(
907
+ x=models, y=speed_values,
908
+ marker_color="#55A868",
909
+ text=[str(v) if v > 0 else "" for v in speed_values],
910
+ textposition="outside",
911
+ ))
912
+ fig.update_layout(
913
+ **_PLOTLY_LAYOUT,
914
+ title="Encoding Speed",
915
+ yaxis_title="Sent / s",
916
+ )
917
+ st.plotly_chart(fig, use_container_width=True)
 
918
 
919
  if has_memory:
920
  col_idx = 1 if has_speed else 0
921
  with cols[col_idx]:
922
+ fig = go.Figure(go.Bar(
923
+ x=models, y=mem_values,
924
+ marker_color="#C44E52",
925
+ text=[str(v) if v > 0 else "" for v in mem_values],
926
+ textposition="outside",
927
+ ))
928
+ fig.update_layout(
929
+ **_PLOTLY_LAYOUT,
930
+ title="Memory Usage",
931
+ yaxis_title="MB",
932
+ )
933
+ st.plotly_chart(fig, use_container_width=True)
 
934
 
935
  # ---------------------------------------------------------------------------
936
  # Footer
corpus.py CHANGED
@@ -9,8 +9,12 @@ def build_corpus(size: int, ds_cfg: DatasetConfig | None = None) -> list[str]:
9
  """Build a corpus of real sentences from the configured dataset."""
10
  if ds_cfg is None:
11
  ds_cfg = DatasetConfig()
12
- dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
13
- sentences = list(dataset[ds_cfg.query_col]) + list(dataset[ds_cfg.passage_col])
 
 
 
 
14
  full: list[str] = []
15
  while len(full) < size:
16
  full.extend(sentences)
 
9
  """Build a corpus of real sentences from the configured dataset."""
10
  if ds_cfg is None:
11
  ds_cfg = DatasetConfig()
12
+ if ds_cfg.data is not None:
13
+ data = ds_cfg.data
14
+ else:
15
+ dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
16
+ data = {col: list(dataset[col]) for col in dataset.column_names}
17
+ sentences = list(data[ds_cfg.query_col]) + list(data[ds_cfg.passage_col])
18
  full: list[str] = []
19
  while len(full) < size:
20
  full.extend(sentences)
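
Note on the padding loop in `build_corpus`: it repeats the dataset's sentences until the requested size is reached. A minimal sketch of that idea in isolation follows. It assumes the result is then trimmed to exactly `size` (the hunk ends before showing what happens after the loop) and adds an empty-input guard that is not part of the diff. `cycle_to_size` is a hypothetical name.

```python
def cycle_to_size(sentences, size):
    """Repeat `sentences` in order until `size` items are available, then trim."""
    if not sentences:
        # Assumption: fail fast, because an empty list would never reach `size`.
        raise ValueError("need at least one sentence to build a corpus")
    full = []
    while len(full) < size:
        full.extend(sentences)
    # Assumption: trim the overshoot so the corpus has exactly `size` items.
    return full[:size]


corpus = cycle_to_size(["a", "b", "c"], 7)
```

With a small uploaded dataset the corpus is mostly repeated sentences, so speed and memory figures measured on it may not reflect a corpus of distinct texts.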
dataset_config.py CHANGED
@@ -1,6 +1,6 @@
1
  from __future__ import annotations
2
 
3
- from dataclasses import dataclass
4
 
5
 
6
  @dataclass
@@ -14,6 +14,8 @@ class DatasetConfig:
14
  passage_col: str = "sentence2"
15
  score_col: str | None = "score"
16
  score_scale: float = 5.0
 
 
17
 
18
 
19
  DATASET_PRESETS: dict[str, DatasetConfig] = {
 
1
  from __future__ import annotations
2
 
3
+ from dataclasses import dataclass, field
4
 
5
 
6
  @dataclass
 
14
  passage_col: str = "sentence2"
15
  score_col: str | None = "score"
16
  score_scale: float = 5.0
17
+ # Pre-loaded data (dict of column-name -> list). When set, skip HF download.
18
+ data: dict[str, list] | None = field(default=None, repr=False)
19
 
20
 
21
  DATASET_PRESETS: dict[str, DatasetConfig] = {
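
Note on the new `data` field: when it is set, the evaluators read columns from it instead of calling `load_dataset`. The sketch below shows that precedence rule on its own. `InMemoryDatasetConfig` and `resolve_columns` are hypothetical stand-ins, and the defaults are illustrative; only the field names and the `repr=False` choice follow the diff.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional


@dataclass
class InMemoryDatasetConfig:
    """Hypothetical stand-in for DatasetConfig (defaults are illustrative)."""
    name: str
    query_col: str
    passage_col: str
    score_col: Optional[str] = None
    score_scale: float = 1.0
    # Column name -> list of values. repr=False keeps large payloads out of logs.
    data: Optional[Dict[str, list]] = field(default=None, repr=False)


def resolve_columns(
    cfg: InMemoryDatasetConfig,
    loader: Callable[[InMemoryDatasetConfig], Dict[str, list]],
) -> Dict[str, list]:
    """Return the column dict, preferring pre-loaded data over the loader."""
    if cfg.data is not None:
        return cfg.data
    return loader(cfg)


uploaded = InMemoryDatasetConfig(
    name="user/demo", query_col="q", passage_col="p",
    data={"q": ["what is x?"], "p": ["x is ..."]},
)
# The loader is never called here because `data` is already populated.
columns = resolve_columns(uploaded, loader=lambda cfg: {})
```

Keeping the payload out of `repr` matters in practice because a config carrying thousands of rows would otherwise flood any log line or error message that prints it.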
evals/llm_judge.py ADDED
@@ -0,0 +1,194 @@
+ from __future__ import annotations
+
+ import json
+ import random
+ import urllib.request
+ import urllib.error
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ from dataset_config import DatasetConfig
+
+
+ @dataclass
+ class LLMJudgeConfig:
+     provider: str  # "openai" or "anthropic"
+     api_key: str
+     model: str
+     max_samples: int = 50
+
+
+ # ---------------------------------------------------------------------------
+ # Provider-specific API calls
+ # ---------------------------------------------------------------------------
+
+ _SYSTEM_PROMPT = (
+     "You are an impartial relevance judge. Given a query and a passage, "
+     "rate how relevant the passage is to the query on a scale of 1 to 5.\n\n"
+     "1 = Completely irrelevant\n"
+     "2 = Slightly relevant\n"
+     "3 = Moderately relevant\n"
+     "4 = Highly relevant\n"
+     "5 = Perfectly relevant\n\n"
+     "Respond with ONLY a single integer (1-5), nothing else."
+ )
+
+
+ def _build_user_prompt(query: str, passage: str) -> str:
+     return f"Query: {query}\n\nPassage: {passage}"
+
+
+ def _call_openai(api_key: str, model: str, query: str, passage: str) -> int:
+     body = json.dumps({
+         "model": model,
+         "messages": [
+             {"role": "system", "content": _SYSTEM_PROMPT},
+             {"role": "user", "content": _build_user_prompt(query, passage)},
+         ],
+         "max_tokens": 4,
+         "temperature": 0.0,
+     }).encode()
+     req = urllib.request.Request(
+         "https://api.openai.com/v1/chat/completions",
+         data=body,
+         headers={
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json",
+         },
+     )
+     with urllib.request.urlopen(req, timeout=30) as resp:
+         data = json.loads(resp.read())
+     text = data["choices"][0]["message"]["content"].strip()
+     return _parse_score(text)
+
+
+ def _call_anthropic(api_key: str, model: str, query: str, passage: str) -> int:
+     body = json.dumps({
+         "model": model,
+         "max_tokens": 4,
+         "system": _SYSTEM_PROMPT,
+         "messages": [
+             {"role": "user", "content": _build_user_prompt(query, passage)},
+         ],
+     }).encode()
+     req = urllib.request.Request(
+         "https://api.anthropic.com/v1/messages",
+         data=body,
+         headers={
+             "x-api-key": api_key,
+             "anthropic-version": "2023-06-01",
+             "Content-Type": "application/json",
+         },
+     )
+     with urllib.request.urlopen(req, timeout=30) as resp:
+         data = json.loads(resp.read())
+     text = data["content"][0]["text"].strip()
+     return _parse_score(text)
+
+
+ def _parse_score(text: str) -> int:
+     for ch in text:
+         if ch.isdigit() and ch in "12345":
+             return int(ch)
+     return 3  # fallback to neutral
+
+
+ _PROVIDERS = {
+     "openai": _call_openai,
+     "anthropic": _call_anthropic,
+ }
+
+
+ # ---------------------------------------------------------------------------
+ # Main evaluation entry point
+ # ---------------------------------------------------------------------------
+
+ def evaluate_llm_judge(
+     model,
+     ds_cfg: DatasetConfig,
+     judge_cfg: LLMJudgeConfig,
+     max_pairs: int | None = None,
+     progress_callback=None,
+ ) -> dict[str, float]:
+     """Use an LLM to judge retrieval relevance for top-k results.
+
+     For each sampled query, retrieves the top-5 passages by embedding
+     similarity and asks the LLM to rate each one. Returns average
+     relevance scores at different cut-offs.
+     """
+     from datasets import load_dataset
+
+     if ds_cfg.data is not None:
+         data = ds_cfg.data
+     else:
+         dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
+         data = {col: list(dataset[col]) for col in dataset.column_names}
+
+     queries = list(data[ds_cfg.query_col])
+     passages = list(data[ds_cfg.passage_col])
+
+     if max_pairs is not None and len(queries) > max_pairs:
+         queries = queries[:max_pairs]
+         passages = passages[:max_pairs]
+
+     # Encode
+     emb_q = model.encode(queries, is_query=True)
+     emb_p = model.encode(passages, is_query=False)
+
+     # Normalise
+     emb_q = emb_q / np.linalg.norm(emb_q, axis=1, keepdims=True)
+     emb_p = emb_p / np.linalg.norm(emb_p, axis=1, keepdims=True)
+
+     # Sample queries to judge
+     n = len(queries)
+     sample_size = min(judge_cfg.max_samples, n)
+     sample_indices = sorted(random.sample(range(n), sample_size))
+
+     call_fn = _PROVIDERS[judge_cfg.provider]
+     top_k = 5
+
+     # For each sampled query, get top-k passages and judge them
+     relevance_at_k: list[list[int]] = []  # shape: (sample_size, top_k)
+     total_calls = sample_size * top_k
+     calls_done = 0
+
+     for idx in sample_indices:
+         query_emb = emb_q[idx : idx + 1]
+         sims = (query_emb @ emb_p.T).flatten()
+         top_indices = np.argsort(-sims)[:top_k]
+
+         scores_for_query = []
+         for passage_idx in top_indices:
+             try:
+                 score = call_fn(
+                     judge_cfg.api_key, judge_cfg.model,
+                     queries[idx], passages[int(passage_idx)],
+                 )
+             except Exception:
+                 score = 0  # treat API errors as 0
+             scores_for_query.append(score)
+             calls_done += 1
+             if progress_callback:
+                 progress_callback(calls_done, total_calls)
+         relevance_at_k.append(scores_for_query)
+
+     arr = np.array(relevance_at_k, dtype=float)  # (sample_size, top_k)
+
+     # Normalise scores to 0-1 (from 1-5 scale). API errors are recorded as 0,
+     # so clip to keep them at 0 instead of letting them normalise to -0.25.
+     arr_norm = np.clip((arr - 1.0) / 4.0, 0.0, 1.0)
+
+     # nDCG@5
+     def _dcg(scores: np.ndarray) -> np.ndarray:
+         positions = np.arange(1, scores.shape[1] + 1)
+         return np.sum(scores / np.log2(positions + 1), axis=1)
+
+     dcg = _dcg(arr_norm)
+     ideal = _dcg(np.sort(arr_norm, axis=1)[:, ::-1])
+     ndcg = np.where(ideal > 0, dcg / ideal, 0.0)
+
+     return {
+         "judge_avg@1": round(float(np.mean(arr_norm[:, 0])), 4),
+         "judge_avg@5": round(float(np.mean(arr_norm)), 4),
+         "judge_ndcg@5": round(float(np.mean(ndcg)), 4),
+     }
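
Note on the judge metrics: the three returned values are all derived from a table of 1-5 ratings, one row per sampled query and one column per retrieved passage. The worked sketch below reproduces the arithmetic with the standard library only. It is an illustration under stated assumptions, not the module's code: `normalise`, `dcg` and `judge_metrics` are hypothetical names, and the sketch clamps normalised values to [0, 1] so that a 0 placeholder for a failed API call cannot go negative.

```python
import math


def normalise(score):
    """Map a 1-5 rating onto 0-1, clamping out-of-range values (e.g. a 0 placeholder)."""
    return min(max((score - 1.0) / 4.0, 0.0), 1.0)


def dcg(gains):
    """Discounted cumulative gain with a log2(position + 1) discount, positions from 1."""
    return sum(g / math.log2(pos + 1) for pos, g in enumerate(gains, start=1))


def judge_metrics(ratings):
    """ratings: one list of 1-5 scores per query, ordered by retrieval rank."""
    rows = [[normalise(s) for s in row] for row in ratings]
    ndcgs = []
    for row in rows:
        ideal = dcg(sorted(row, reverse=True))  # best possible ordering of the same gains
        ndcgs.append(dcg(row) / ideal if ideal > 0 else 0.0)
    n = len(rows)
    total_cells = sum(len(row) for row in rows)
    return {
        "judge_avg@1": sum(row[0] for row in rows) / n,
        "judge_avg@5": sum(sum(row) for row in rows) / total_cells,
        "judge_ndcg@5": sum(ndcgs) / n,
    }


# Query 1 is ranked ideally; query 2 has its only relevant passage in last place.
demo = judge_metrics([[5, 4, 3, 2, 1], [1, 1, 1, 1, 5]])
```

Because the ideal ordering is built from the same five judged passages, nDCG here measures how well the top-5 is ordered, not whether better passages exist elsewhere in the corpus.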
evals/quality.py CHANGED
@@ -13,8 +13,26 @@ def _normalize(emb: np.ndarray) -> np.ndarray:
13
  return emb / norms
14
 
15
 
16
- def _retrieval_metrics(emb_q: np.ndarray, emb_p: np.ndarray) -> dict[str, float]:
17
- """Compute MRR and Recall@k assuming query i matches passage i."""
18
  emb_q = _normalize(emb_q)
19
  emb_p = _normalize(emb_p)
20
 
@@ -22,40 +40,61 @@ def _retrieval_metrics(emb_q: np.ndarray, emb_p: np.ndarray) -> dict[str, float]
22
  sims = emb_q @ emb_p.T
23
 
24
  n = sims.shape[0]
25
- # For each query, rank passages by descending similarity
26
- # ranks[i] = rank of the correct passage (0-indexed)
27
  sorted_indices = np.argsort(-sims, axis=1)
28
  ranks = np.array([int(np.where(sorted_indices[i] == i)[0][0]) for i in range(n)])
29
 
30
- mrr = float(np.mean(1.0 / (ranks + 1)))
31
- recall_1 = float(np.mean(ranks < 1))
32
- recall_5 = float(np.mean(ranks < 5))
33
- recall_10 = float(np.mean(ranks < 10))
34
 
35
- return {
36
- "mrr": round(mrr, 4),
37
- "recall@1": round(recall_1, 4),
38
- "recall@5": round(recall_5, 4),
39
- "recall@10": round(recall_10, 4),
40
- }
41
 
42
 
43
  def evaluate_quality(
44
  model,
45
  ds_cfg: DatasetConfig | None = None,
46
  max_pairs: int | None = None,
 
47
  ) -> dict[str, float]:
48
  """Evaluate embedding quality on a dataset.
49
 
50
  Returns a dict with either {"spearman": float} for scored datasets
51
- or {"mrr", "recall@1", "recall@5", "recall@10"} for pair datasets.
52
  """
53
  if ds_cfg is None:
54
  ds_cfg = DatasetConfig()
55
 
56
- dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
57
- queries = list(dataset[ds_cfg.query_col])
58
- passages = list(dataset[ds_cfg.passage_col])
 
 
 
 
59
 
60
  if max_pairs is not None and len(queries) > max_pairs:
61
  queries = queries[:max_pairs]
@@ -66,7 +105,7 @@ def evaluate_quality(
66
 
67
  if ds_cfg.score_col is not None:
68
  # Scored mode: Spearman correlation
69
- scores = list(dataset[ds_cfg.score_col])
70
  if max_pairs is not None and len(scores) > max_pairs:
71
  scores = scores[:max_pairs]
72
  gold_scores = [s / ds_cfg.score_scale for s in scores]
@@ -79,4 +118,4 @@ def evaluate_quality(
79
  return {"spearman": round(float(correlation), 4)}
80
 
81
  # Pair mode: retrieval metrics
82
- return _retrieval_metrics(emb_q, emb_p)
 
13
  return emb / norms
14
 
15
 
16
+ ALL_RETRIEVAL_METRICS = [
17
+ "mrr",
18
+ "map@5", "map@10",
19
+ "ndcg@5", "ndcg@10",
20
+ "precision@1", "precision@5", "precision@10",
21
+ "recall@1", "recall@5", "recall@10",
22
+ ]
23
+
24
+ DEFAULT_RETRIEVAL_METRICS = ["mrr", "recall@1", "recall@5", "recall@10"]
25
+
26
+
27
+ def _retrieval_metrics(
28
+ emb_q: np.ndarray,
29
+ emb_p: np.ndarray,
30
+ metrics: list[str] | None = None,
31
+ ) -> dict[str, float]:
32
+ """Compute retrieval metrics assuming query i matches passage i."""
33
+ if metrics is None:
34
+ metrics = DEFAULT_RETRIEVAL_METRICS
35
+
36
  emb_q = _normalize(emb_q)
37
  emb_p = _normalize(emb_p)
38
 
 
40
  sims = emb_q @ emb_p.T
41
 
42
  n = sims.shape[0]
 
 
43
  sorted_indices = np.argsort(-sims, axis=1)
44
  ranks = np.array([int(np.where(sorted_indices[i] == i)[0][0]) for i in range(n)])
45
 
46
+ results: dict[str, float] = {}
47
+
48
+ for m in metrics:
49
+ if m == "mrr":
50
+ results["mrr"] = round(float(np.mean(1.0 / (ranks + 1))), 4)
51
+
52
+ elif m.startswith("recall@"):
53
+ k = int(m.split("@")[1])
54
+ results[m] = round(float(np.mean(ranks < k)), 4)
55
+
56
+ elif m.startswith("precision@"):
57
+ k = int(m.split("@")[1])
58
+ # Single relevant doc per query: precision@k = 1/k if hit, else 0
59
+ results[m] = round(float(np.mean((ranks < k) / k)), 4)
60
+
61
+ elif m.startswith("map@"):
62
+ k = int(m.split("@")[1])
63
+ # Single relevant doc: AP = 1/(rank+1) if rank < k, else 0
64
+ ap = np.where(ranks < k, 1.0 / (ranks + 1), 0.0)
65
+ results[m] = round(float(np.mean(ap)), 4)
66
+
67
+ elif m.startswith("ndcg@"):
68
+ k = int(m.split("@")[1])
69
+ # Single relevant doc: DCG = 1/log2(rank+2) if rank < k, else 0
70
+ # ideal DCG = 1/log2(2) = 1.0
71
+ dcg = np.where(ranks < k, 1.0 / np.log2(ranks + 2), 0.0)
72
+ results[m] = round(float(np.mean(dcg)), 4)
73
 
74
+ return results
 
 
 
 
 
75
 
76
 
77
  def evaluate_quality(
78
  model,
79
  ds_cfg: DatasetConfig | None = None,
80
  max_pairs: int | None = None,
81
+ metrics: list[str] | None = None,
82
  ) -> dict[str, float]:
83
  """Evaluate embedding quality on a dataset.
84
 
85
  Returns a dict with either {"spearman": float} for scored datasets
86
+ or selected retrieval metrics for pair datasets.
87
  """
88
  if ds_cfg is None:
89
  ds_cfg = DatasetConfig()
90
 
91
+ if ds_cfg.data is not None:
92
+ data = ds_cfg.data
93
+ else:
94
+ dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
95
+ data = {col: list(dataset[col]) for col in dataset.column_names}
96
+ queries = list(data[ds_cfg.query_col])
97
+ passages = list(data[ds_cfg.passage_col])
98
 
99
  if max_pairs is not None and len(queries) > max_pairs:
100
  queries = queries[:max_pairs]
 
105
 
106
  if ds_cfg.score_col is not None:
107
  # Scored mode: Spearman correlation
108
+ scores = list(data[ds_cfg.score_col])
109
  if max_pairs is not None and len(scores) > max_pairs:
110
  scores = scores[:max_pairs]
111
  gold_scores = [s / ds_cfg.score_scale for s in scores]
 
118
  return {"spearman": round(float(correlation), 4)}
119
 
120
  # Pair mode: retrieval metrics
121
+ return _retrieval_metrics(emb_q, emb_p, metrics=metrics)
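
Note on the pair-mode metrics: because each query has exactly one relevant passage (query i matches passage i), every metric reduces to a function of that passage's 0-indexed rank, as the inline comments in `_retrieval_metrics` state. The sketch below works through those closed forms with plain Python; `single_match_metrics` is a hypothetical helper, not part of `evals/quality.py`.

```python
import math


def single_match_metrics(ranks, k):
    """Metrics for the one-relevant-passage-per-query setting.

    `ranks` holds the 0-indexed rank of the correct passage for each query.
    """
    n = len(ranks)
    hits = [r < k for r in ranks]
    return {
        "mrr": sum(1.0 / (r + 1) for r in ranks) / n,
        f"recall@{k}": sum(hits) / n,
        # One relevant passage: precision@k is 1/k on a hit, 0 otherwise.
        f"precision@{k}": sum(h / k for h in hits) / n,
        # One relevant passage: AP is the reciprocal rank if it lands in the top k.
        f"map@{k}": sum(1.0 / (r + 1) if r < k else 0.0 for r in ranks) / n,
        # Ideal DCG is 1/log2(2) = 1, so nDCG equals DCG.
        f"ndcg@{k}": sum(1.0 / math.log2(r + 2) if r < k else 0.0 for r in ranks) / n,
    }


# Three queries whose correct passages are ranked 1st, 3rd and 8th.
example = single_match_metrics([0, 2, 7], k=5)
```

For this input, recall@5 is 2/3, precision@5 is 2/15, map@5 is 4/9 and ndcg@5 is 0.5. Note that precision@k can never exceed 1/k in this setting, so it carries the same information as recall@k on a different scale.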
requirements.txt CHANGED
@@ -7,5 +7,5 @@ fastembed
7
  libembedding
8
  numpy
9
  scipy
10
- matplotlib
11
  streamlit
 
7
  libembedding
8
  numpy
9
  scipy
10
+ plotly
11
  streamlit