Commit ·
0f31e57
1
Parent(s): 465f2c6
4096-release (#1)
Browse files- Release: S23DR 2026 learned baseline (HSS=0.382) (43975eb3f3e3ba3f9773ef8afda9a3d0e85b2f9c)
- .gitignore +5 -0
- REPRODUCE.md +194 -0
- checkpoint.pt +2 -2
- configs/base.json +39 -0
- repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_args.json +66 -0
- repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_final.pt +3 -0
- repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_args.json +66 -0
- repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_final.pt +3 -0
- repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_args.json +66 -0
- repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_final.pt +3 -0
- repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_args.json +66 -0
- repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_final.pt +3 -0
- repro_runs/deterministic_hss372/20260330_071030_8c95_3610_args.json +66 -0
- repro_runs/deterministic_hss372/20260330_071030_8c95_3610_final.pt +3 -0
- repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_args.json +66 -0
- repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_final.pt +3 -0
- repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_args.json +66 -0
- repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_final.pt +3 -0
- repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_args.json +66 -0
- repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_final.pt +3 -0
- repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_args.json +66 -0
- repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_final.pt +3 -0
- reproduce.sh +68 -0
- reproduce_deterministic.sh +71 -0
- s23dr_2026_example/attention.py +0 -85
- s23dr_2026_example/cache_scenes.py +0 -195
- s23dr_2026_example/color_mappings.py +0 -26
- s23dr_2026_example/data.py +3 -13
- s23dr_2026_example/losses.py +10 -106
- s23dr_2026_example/make_sampled_cache.py +0 -185
- s23dr_2026_example/model.py +4 -181
- s23dr_2026_example/sinkhorn.py +0 -55
- s23dr_2026_example/soft_hss_loss.py +0 -507
- s23dr_2026_example/train.py +530 -0
- s23dr_2026_example/varifold.py +9 -152
- s23dr_2026_example/wire_varifold_kernels.py +2 -295
- script.py +5 -5
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
runs/
|
| 4 |
+
*.png
|
| 5 |
+
*.log
|
REPRODUCE.md
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reproducing the Best Checkpoint (HSS=0.382)
|
| 2 |
+
|
| 3 |
+
## Quick Start
|
| 4 |
+
|
| 5 |
+
The `checkpoint.pt` in this repo is the final model. To run inference:
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
python script.py
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
To reproduce from scratch (~3hr on 1x RTX 4090):
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
bash reproduce.sh
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Exact Recipe
|
| 18 |
+
|
| 19 |
+
Architecture (unchanged across all 3 steps):
|
| 20 |
+
```
|
| 21 |
+
Perceiver: hidden=256, ff=1024, latent_tokens=256, latent_layers=7
|
| 22 |
+
encoder_layers=4, decoder_layers=3, cross_attn_interval=4
|
| 23 |
+
num_heads=4, kv_heads_cross=2, kv_heads_self=2
|
| 24 |
+
qk_norm=True (L2), rms_norm=True, dropout=0.1
|
| 25 |
+
segments=64, segment_param=midpoint_dir_len, segment_conf=True
|
| 26 |
+
behind_emb_dim=8, vote_features=True, activation=gelu
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
All shared config lives in `configs/base.json`.
|
| 30 |
+
|
| 31 |
+
### Step 1: 2048 Phase 1 (from scratch) — ~1.5hr
|
| 32 |
+
|
| 33 |
+
```
|
| 34 |
+
Data: hf://usm3d/s23dr-2026-sampled_2048_v2:train (16,508 samples)
|
| 35 |
+
Steps: 0 -> 125,000 (242 epochs)
|
| 36 |
+
LR: 3e-4, warmup=10,000
|
| 37 |
+
Batch size: 32
|
| 38 |
+
Optimizer: AdamW, betas=(0.9, 0.95), weight_decay=0.01
|
| 39 |
+
Sinkhorn: eps=0.1, iters=20, dustbin=0.3
|
| 40 |
+
Conf: weight=0.1, mode=sinkhorn, head_wd=0.1
|
| 41 |
+
Endpoint: OFF
|
| 42 |
+
Aug: rotate=True, flip=True
|
| 43 |
+
Seed: 353
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Trains the perceiver from random init on 2048-point samples. The sinkhorn
|
| 47 |
+
optimal transport loss learns to match predicted segments to ground truth.
|
| 48 |
+
|
| 49 |
+
**Why 2048 first:** Training directly on 4096 overfits (1.47x train/val ratio
|
| 50 |
+
vs 1.19x for 2048). The 2048 model learns better-generalized representations.
|
| 51 |
+
|
| 52 |
+
**Output:** HSS ~0.28.
|
| 53 |
+
|
| 54 |
+
### Step 2: 4096 finetune (constant LR) — ~15min
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
Resume: Step 1 -> step125000.pt
|
| 58 |
+
Data: hf://usm3d/s23dr-2026-sampled_4096_v2:train (15,892 samples)
|
| 59 |
+
Steps: 125,001 -> 135,000 (10k steps)
|
| 60 |
+
LR: 3e-5 (constant, no cooldown)
|
| 61 |
+
Batch size: 64
|
| 62 |
+
Endpoint: OFF
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Switches input from 2048 to 4096 points, increasing structural coverage from
|
| 66 |
+
66% to 74%. The gentle lr (3e-5) preserves learned representations while
|
| 67 |
+
adapting to the extra input. Higher LR (>1e-4) causes catastrophic forgetting.
|
| 68 |
+
|
| 69 |
+
HSS jumps from 0.28 to 0.35 in ~5k steps. Plateaus by 10k steps.
|
| 70 |
+
|
| 71 |
+
**Output:** HSS ~0.35.
|
| 72 |
+
|
| 73 |
+
### Step 3: Cooldown with endpoint loss — ~1hr
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
Resume: Step 2 -> step135000.pt
|
| 77 |
+
Data: hf://usm3d/s23dr-2026-sampled_4096_v2:train
|
| 78 |
+
Steps: 135,001 -> 170,000 (35k steps)
|
| 79 |
+
LR: 3e-5, cooldown_start=150,000, cooldown_steps=20,000
|
| 80 |
+
(constant 3e-5 for 15k steps, then linear decay to ~0 over 20k)
|
| 81 |
+
Batch size: 64
|
| 82 |
+
Endpoint: weight=0.1
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Adds symmetric endpoint L1 loss (using detached sinkhorn assignment) to
|
| 86 |
+
tighten vertex precision. The sinkhorn loss alone operates on segment
|
| 87 |
+
midpoint/direction/length and doesn't directly penalize endpoint position error.
|
| 88 |
+
|
| 89 |
+
**Output:** HSS=0.382, F1=0.414.
|
| 90 |
+
|
| 91 |
+
### Key Numbers
|
| 92 |
+
|
| 93 |
+
| Stage | Steps | HSS | F1 | What changed |
|
| 94 |
+
|-------|-------|-----|-----|-------------|
|
| 95 |
+
| After Step 1 | 125k | 0.281 | 0.156 | Learned geometry from 2048 pts |
|
| 96 |
+
| After Step 2 | 135k | 0.351 | 0.190 | Coverage up to 74% (from 66%) with 4096 pts |
|
| 97 |
+
| After Step 3 | 170k | **0.382** | **0.414** | Vertex precision from endpoint loss |
|
| 98 |
+
|
| 99 |
+
## Why This Works
|
| 100 |
+
|
| 101 |
+
1. **2048 training has low overfitting** (1.19x train/val ratio) — the model
|
| 102 |
+
learns good representations without memorizing training samples.
|
| 103 |
+
|
| 104 |
+
2. **4096 data has higher coverage ceiling** (74% vs 66% structural points) —
|
| 105 |
+
more of the building surface is observed, improving vertex recall.
|
| 106 |
+
|
| 107 |
+
3. **Gentle finetuning preserves representations** — at lr=3e-5, the model
|
| 108 |
+
keeps its learned geometry understanding while adapting to the extra input.
|
| 109 |
+
|
| 110 |
+
4. **Endpoint loss tightens vertices** — the symmetric endpoint distance
|
| 111 |
+
directly penalizes vertex position errors, which sinkhorn loss alone
|
| 112 |
+
doesn't do (it operates on midpoint/direction/length parametrization).
|
| 113 |
+
|
| 114 |
+
## What Doesn't Work
|
| 115 |
+
|
| 116 |
+
- **Training 4096 from scratch:** overfits (1.47x train/val gap), peaks at 0.346
|
| 117 |
+
- **BuildingWorld pretraining:** representations are orthogonal to S23DR (cosine sim = 0.05)
|
| 118 |
+
- **Mixed BW+S23DR training:** BW data hurts due to domain gap
|
| 119 |
+
- **High dropout / weight decay:** prevents overfitting but causes underfitting
|
| 120 |
+
- **High finetune LR (>1e-4):** catastrophic forgetting of 2048 representations
|
| 121 |
+
- **Steeper cooldown (1e-5, 20x drop):** slightly worse than 3e-5 for this checkpoint
|
| 122 |
+
|
| 123 |
+
## Reproduction Results
|
| 124 |
+
|
| 125 |
+
### End-to-end reproductions
|
| 126 |
+
|
| 127 |
+
| Model | HSS | F1 | IoU | Notes |
|
| 128 |
+
|-------|-----|-----|-----|-------|
|
| 129 |
+
| Original | 0.382 | 0.414 | 0.370 | Shipped checkpoint |
|
| 130 |
+
| E2E repro #4 | 0.379 | 0.409 | 0.369 | Closest E2E, `repro_runs/e2e_repro4_hss379/` |
|
| 131 |
+
| Compiled repro (from submission codebase) | 0.376 | — | — | Best compiled repro from this codebase, `repro_runs/compiled_repro_hss376/` |
|
| 132 |
+
| E2E repro #3 | 0.375 | 0.404 | 0.367 | |
|
| 133 |
+
| Deterministic E2E | 0.372 | 0.398 | 0.368 | Bit-reproducible, `repro_runs/deterministic_hss372/` |
|
| 134 |
+
| E2E repro #5 | 0.349 | 0.373 | — | Outlier (early compile divergence) |
|
| 135 |
+
|
| 136 |
+
### Partial reproductions (isolating pipeline stages)
|
| 137 |
+
|
| 138 |
+
| Test | Starting from | HSS | Gap to original |
|
| 139 |
+
|------|--------------|-----|-----------------|
|
| 140 |
+
| Step 3 from orig Step 2 (run A) | Original step135000.pt | 0.382 | 0.000 |
|
| 141 |
+
| Step 3 from orig Step 2 (run B) | Original step135000.pt | 0.384 | +0.002 |
|
| 142 |
+
| Step 2+3 from orig Step 1 | Original step125000.pt | 0.377 | -0.005 |
|
| 143 |
+
| Step 1 from orig step 100k | Original step100000.pt | 0.285 (Step 1 HSS) | +0.004 vs 0.281 |
|
| 144 |
+
|
| 145 |
+
Step 3 from the same checkpoint reproduces to within 0.002. The E2E variance
|
| 146 |
+
(0.349-0.379) is dominated by torch.compile nondeterminism in Step 1.
|
| 147 |
+
|
| 148 |
+
### All benchmarks
|
| 149 |
+
|
| 150 |
+
| Model | Input | HSS | F1 | IoU | Notes |
|
| 151 |
+
|-------|-------|-----|-----|-----|-------|
|
| 152 |
+
| Handcrafted baseline | raw views | 0.307 | 0.404 | 0.260 | |
|
| 153 |
+
| h256+qk+ep (submitted) | 2048 | 0.365 | 0.388 | 0.360 | HSS=0.427 on test |
|
| 154 |
+
| Original 3-step | 2048 | 0.373 | 0.404 | 0.363 | |
|
| 155 |
+
| Original 3-step | 4096 | 0.382 | 0.414 | 0.370 | Best ever |
|
| 156 |
+
| Step3 repro from orig S2 | 4096 | 0.384 | 0.414 | — | Near-exact repro |
|
| 157 |
+
| E2E repro #4 | 4096 | 0.379 | 0.409 | 0.369 | |
|
| 158 |
+
| Compiled repro (submission codebase) | 4096 | 0.376 | — | — | Best compiled from this exact codebase |
|
| 159 |
+
| E2E repro #3 | 4096 | 0.375 | 0.404 | 0.367 | |
|
| 160 |
+
| Deterministic E2E | 4096 | 0.372 | 0.398 | 0.368 | Bit-reproducible |
|
| 161 |
+
|
| 162 |
+
## Code Equivalence Verification
|
| 163 |
+
|
| 164 |
+
| Test | Result |
|
| 165 |
+
|------|--------|
|
| 166 |
+
| Forward pass (same checkpoint, same input) | Bit-identical (0.00 diff) |
|
| 167 |
+
| Loss computation | Bit-identical (0.00 diff) |
|
| 168 |
+
| Gradient computation | 5e-8 max diff |
|
| 169 |
+
| Training from same seed | Bit-identical steps 1-44 |
|
| 170 |
+
| Step 3 from same checkpoint (2 runs) | HSS=0.382, 0.384 |
|
| 171 |
+
| Deterministic mode (2 runs) | Bit-identical (0.00 diff) |
|
| 172 |
+
|
| 173 |
+
## Reproducibility Notes
|
| 174 |
+
|
| 175 |
+
**Default mode** (`reproduce.sh`): Uses torch.compile (~3x faster). Each run
|
| 176 |
+
gets different Triton kernels, causing ~1e-8 floating-point divergence at a
|
| 177 |
+
random step (31-45). This grows through chaotic SGD dynamics, giving HSS
|
| 178 |
+
variance of ~0.03 across runs. E2E reproductions land in the 0.349-0.379 range.
|
| 179 |
+
|
| 180 |
+
**Deterministic mode** (`--deterministic` flag): Disables torch.compile.
|
| 181 |
+
Bit-identical across runs with the same seed. HSS=0.372 (slightly lower than
|
| 182 |
+
compiled mode because eager-mode kernels follow a different numerical path).
|
| 183 |
+
|
| 184 |
+
**bad_samples.txt**: The shipped file has 156 entries to match original training.
|
| 185 |
+
(Note: `wc -l` reports 155 because the last line lacks a trailing newline.)
|
| 186 |
+
Two additional bad samples (`47b0e0ce19b`, `4b2d56eb3ef`) were discovered after
|
| 187 |
+
the original training run. They are legitimately bad (misaligned GT) but were
|
| 188 |
+
included in the original training data. Adding them changes the batch iteration
|
| 189 |
+
order and costs ~0.005 HSS in deterministic mode (0.372 -> 0.367) and ~0.04 in
|
| 190 |
+
compiled mode due to compounded torch.compile variance. Participants training
|
| 191 |
+
from scratch may wish to add these 2 entries for cleaner training data, but
|
| 192 |
+
should expect slightly different scores due to the changed iteration order.
|
| 193 |
+
|
| 194 |
+
The shipped `checkpoint.pt` is from the original training run (HSS=0.382).
|
checkpoint.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1296423a1a2e603ba55860d8ef8fa3a861764a7bbc3de96b776fca59cf5b11ab
|
| 3 |
+
size 106429791
|
configs/base.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"arch": "perceiver",
|
| 3 |
+
"segments": 64,
|
| 4 |
+
"hidden": 256,
|
| 5 |
+
"ff": 1024,
|
| 6 |
+
"num_heads": 4,
|
| 7 |
+
"kv_heads_cross": 2,
|
| 8 |
+
"kv_heads_self": 2,
|
| 9 |
+
"latent_tokens": 256,
|
| 10 |
+
"latent_layers": 7,
|
| 11 |
+
"decoder_layers": 3,
|
| 12 |
+
"cross_attn_interval": 4,
|
| 13 |
+
"encoder_layers": 4,
|
| 14 |
+
"behind_emb_dim": 8,
|
| 15 |
+
"dropout": 0.1,
|
| 16 |
+
"activation": "gelu",
|
| 17 |
+
"rms_norm": true,
|
| 18 |
+
"qk_norm": true,
|
| 19 |
+
"qk_norm_type": "l2",
|
| 20 |
+
"segment_param": "midpoint_dir_len",
|
| 21 |
+
"segment_conf": true,
|
| 22 |
+
"vote_features": true,
|
| 23 |
+
|
| 24 |
+
"adam_betas": "0.9,0.95",
|
| 25 |
+
"weight_decay": 0.01,
|
| 26 |
+
"warmup": 10000,
|
| 27 |
+
"varifold_weight": 0.0,
|
| 28 |
+
"sinkhorn_weight": 1.0,
|
| 29 |
+
"sinkhorn_eps": 0.1,
|
| 30 |
+
"sinkhorn_iters": 20,
|
| 31 |
+
"sinkhorn_dustbin": 0.3,
|
| 32 |
+
"conf_weight": 0.1,
|
| 33 |
+
"conf_mode": "sinkhorn",
|
| 34 |
+
"conf_head_wd": 0.1,
|
| 35 |
+
|
| 36 |
+
"aug_rotate": true,
|
| 37 |
+
"aug_flip": true,
|
| 38 |
+
"seed": 353
|
| 39 |
+
}
|
repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 32,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 0.0003,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "runs/validate_155_compiled",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 2048,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 125000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e27c8ae20c291676a2c0b7e080d6d00be86f251ae6bdfe3cc3ff6f27f6646b00
|
| 3 |
+
size 106427231
|
repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "runs/validate_155_compiled",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "runs/validate_155_compiled/20260408_173614_64c7_4670/checkpoints/step125000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 135000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89bae7e879900c128bbfe1a05e2f6d4b8430675ed95d47ea2d975b198c71cdad
|
| 3 |
+
size 106429599
|
repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 150000,
|
| 18 |
+
"cooldown_steps": 20000,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.1,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "runs/validate_155_compiled",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "runs/validate_155_compiled/20260408_194447_3061_6284/checkpoints/step135000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 170000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb5779baa160aab1f55b1c698cc39da6a72cae5c9bf5456a2b4063f2342b85a3
|
| 3 |
+
size 106429599
|
repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 32,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": true,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 0.0003,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 2048,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 125000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4bd05b2b323ced2ed94c4bd382f0e07ce5fae382f19ea35cb456ff631bcd1ac0
|
| 3 |
+
size 106423583
|
repro_runs/deterministic_hss372/20260330_071030_8c95_3610_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": true,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "/workspace/s23dr_2026_example/repro_deterministic/20260330_025738_f0c9_3400/checkpoints/step125000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 135000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/deterministic_hss372/20260330_071030_8c95_3610_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:992dbe9c7bfb4713b72d0f444f480f4a53e10844330d426d3fe0f367cdb96441
|
| 3 |
+
size 106425951
|
repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 150000,
|
| 18 |
+
"cooldown_steps": 20000,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": true,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.1,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "/workspace/s23dr_2026_example/repro_deterministic/20260330_071030_8c95_3610/checkpoints/step135000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 170000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd7f508eb05e42ae70efb64fd8b3ab17d036000d49589b9122d4a1a2429c35db
|
| 3 |
+
size 106425951
|
repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 32,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 0.0003,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 2048,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 125000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f8ba72f11560b8164c87cc839e3070a37e024a2a618b7820d4819b739902aa2b
|
| 3 |
+
size 106427231
|
repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 0,
|
| 18 |
+
"cooldown_steps": 0,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.0,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "/workspace/s23dr_2026_example/repro_e2e_run4/20260329_213417_ef91_6503/checkpoints/step125000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 135000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d2f2f135f8de8b676f35e5fe2693c56c2ed060649967ff877e63385d607992f
|
| 3 |
+
size 106429663
|
repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_args.json
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"adam_betas": "0.9,0.95",
|
| 4 |
+
"arch": "perceiver",
|
| 5 |
+
"args_from": "configs/base.json",
|
| 6 |
+
"aug_drop": 0.0,
|
| 7 |
+
"aug_flip": true,
|
| 8 |
+
"aug_jitter": 0.0,
|
| 9 |
+
"aug_rotate": true,
|
| 10 |
+
"batch_size": 64,
|
| 11 |
+
"behind_emb_dim": 8,
|
| 12 |
+
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
|
| 13 |
+
"conf_clamp_min": null,
|
| 14 |
+
"conf_head_wd": 0.1,
|
| 15 |
+
"conf_mode": "sinkhorn",
|
| 16 |
+
"conf_weight": 0.1,
|
| 17 |
+
"cooldown_start": 150000,
|
| 18 |
+
"cooldown_steps": 20000,
|
| 19 |
+
"cosine_decay": false,
|
| 20 |
+
"cpu": false,
|
| 21 |
+
"cross_attn_interval": 4,
|
| 22 |
+
"decoder_input_xattn": false,
|
| 23 |
+
"decoder_layers": 3,
|
| 24 |
+
"deterministic": false,
|
| 25 |
+
"dropout": 0.1,
|
| 26 |
+
"ema_decay": 0.0,
|
| 27 |
+
"encoder_layers": 4,
|
| 28 |
+
"endpoint_warmup": 0,
|
| 29 |
+
"endpoint_weight": 0.1,
|
| 30 |
+
"ff": 1024,
|
| 31 |
+
"git_dirty": true,
|
| 32 |
+
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
|
| 33 |
+
"hidden": 256,
|
| 34 |
+
"kv_heads_cross": 2,
|
| 35 |
+
"kv_heads_self": 2,
|
| 36 |
+
"latent_layers": 7,
|
| 37 |
+
"latent_tokens": 256,
|
| 38 |
+
"learnable_fourier": false,
|
| 39 |
+
"length_floor": 0.0,
|
| 40 |
+
"lr": 3e-05,
|
| 41 |
+
"num_heads": 4,
|
| 42 |
+
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
|
| 43 |
+
"pre_encoder_layers": 0,
|
| 44 |
+
"qk_norm": true,
|
| 45 |
+
"qk_norm_type": "l2",
|
| 46 |
+
"resume": "/workspace/s23dr_2026_example/repro_e2e_run4/20260330_002648_ca92_4553/checkpoints/step135000.pt",
|
| 47 |
+
"rms_norm": true,
|
| 48 |
+
"seed": 353,
|
| 49 |
+
"segment_conf": true,
|
| 50 |
+
"segment_param": "midpoint_dir_len",
|
| 51 |
+
"segments": 64,
|
| 52 |
+
"seq_len": 4096,
|
| 53 |
+
"sinkhorn_dustbin": 0.3,
|
| 54 |
+
"sinkhorn_eps": 0.1,
|
| 55 |
+
"sinkhorn_eps_schedule": "none",
|
| 56 |
+
"sinkhorn_eps_start": null,
|
| 57 |
+
"sinkhorn_iters": 20,
|
| 58 |
+
"sinkhorn_weight": 1.0,
|
| 59 |
+
"steps": 170000,
|
| 60 |
+
"val_cache_dir": "",
|
| 61 |
+
"varifold_cross_only": false,
|
| 62 |
+
"varifold_weight": 0.0,
|
| 63 |
+
"vote_features": true,
|
| 64 |
+
"warmup": 10000,
|
| 65 |
+
"weight_decay": 0.01
|
| 66 |
+
}
|
repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_final.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:161888591f9066bd2c3f42400839a46755bec5e72dc5f1245d7b3336c1a7ddc2
|
| 3 |
+
size 106429663
|
reproduce.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Reproduce the best checkpoint (HSS=0.382) from scratch.
|
| 3 |
+
#
|
| 4 |
+
# Three stages:
|
| 5 |
+
# 1. Train on 2048-point data (~1.5hr on 1x RTX 4090)
|
| 6 |
+
# 2. Finetune on 4096-point data (~15min)
|
| 7 |
+
# 3. Cooldown with endpoint loss (~1hr)
|
| 8 |
+
#
|
| 9 |
+
# Total: ~3hr on a single GPU (plus ~30min for compilation + data loading).
|
| 10 |
+
# All shared config lives in configs/base.json.
|
| 11 |
+
# Each step only specifies what changes.
|
| 12 |
+
set -e
|
| 13 |
+
|
| 14 |
+
OUT_DIR="${1:-runs}"
|
| 15 |
+
BASE="--args-from configs/base.json"
|
| 16 |
+
|
| 17 |
+
# ============================================================
|
| 18 |
+
# Step 1: Train on 2048-point data (Phase 1)
|
| 19 |
+
# ============================================================
|
| 20 |
+
echo "=== Step 1: Training on 2048 data ==="
|
| 21 |
+
python -m s23dr_2026_example.train $BASE \
|
| 22 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train \
|
| 23 |
+
--seq-len 2048 \
|
| 24 |
+
--lr 3e-4 \
|
| 25 |
+
--batch-size 32 \
|
| 26 |
+
--steps 125000 \
|
| 27 |
+
--out-dir "$OUT_DIR"
|
| 28 |
+
|
| 29 |
+
STEP1_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 30 |
+
echo "Step 1 complete: $STEP1_DIR"
|
| 31 |
+
|
| 32 |
+
# ============================================================
|
| 33 |
+
# Step 2: Finetune on 4096-point data
|
| 34 |
+
# ============================================================
|
| 35 |
+
echo "=== Step 2: Finetuning on 4096 data ==="
|
| 36 |
+
python -m s23dr_2026_example.train $BASE \
|
| 37 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
|
| 38 |
+
--resume "$STEP1_DIR/checkpoints/step125000.pt" \
|
| 39 |
+
--seq-len 4096 \
|
| 40 |
+
--lr 3e-5 \
|
| 41 |
+
--batch-size 64 \
|
| 42 |
+
--steps 135000 \
|
| 43 |
+
--out-dir "$OUT_DIR"
|
| 44 |
+
|
| 45 |
+
STEP2_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 46 |
+
echo "Step 2 complete: $STEP2_DIR"
|
| 47 |
+
|
| 48 |
+
# ============================================================
|
| 49 |
+
# Step 3: Cooldown with endpoint loss
|
| 50 |
+
# ============================================================
|
| 51 |
+
echo "=== Step 3: Cooldown with endpoint loss ==="
|
| 52 |
+
python -m s23dr_2026_example.train $BASE \
|
| 53 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
|
| 54 |
+
--resume "$STEP2_DIR/checkpoints/step135000.pt" \
|
| 55 |
+
--seq-len 4096 \
|
| 56 |
+
--lr 3e-5 \
|
| 57 |
+
--batch-size 64 \
|
| 58 |
+
--endpoint-weight 0.1 \
|
| 59 |
+
--cooldown-start 150000 \
|
| 60 |
+
--cooldown-steps 20000 \
|
| 61 |
+
--steps 170000 \
|
| 62 |
+
--out-dir "$OUT_DIR"
|
| 63 |
+
|
| 64 |
+
STEP3_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 65 |
+
echo "Step 3 complete: $STEP3_DIR"
|
| 66 |
+
echo ""
|
| 67 |
+
echo "Final checkpoint: $STEP3_DIR/checkpoints/final.pt"
|
| 68 |
+
echo "Copy to checkpoint.pt for submission."
|
reproduce_deterministic.sh
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Reproduce the best checkpoint in deterministic mode (bit-reproducible).
|
| 3 |
+
#
|
| 4 |
+
# Same three stages as reproduce.sh, but with --deterministic:
|
| 5 |
+
# 1. Train on 2048-point data (~3hr on 1x RTX 4090)
|
| 6 |
+
# 2. Finetune on 4096-point data (~30min)
|
| 7 |
+
# 3. Cooldown with endpoint loss (~2hr)
|
| 8 |
+
#
|
| 9 |
+
# Total: ~5.5hr on a single GPU (no torch.compile, ~2x slower than reproduce.sh).
|
| 10 |
+
# Deterministic mode disables torch.compile and forces CUDA deterministic ops.
|
| 11 |
+
# Results are bit-identical across runs with the same seed. Expected HSS ~0.372.
|
| 12 |
+
set -e
|
| 13 |
+
|
| 14 |
+
OUT_DIR="${1:-runs}"
|
| 15 |
+
BASE="--args-from configs/base.json"
|
| 16 |
+
|
| 17 |
+
# ============================================================
|
| 18 |
+
# Step 1: Train on 2048-point data (Phase 1)
|
| 19 |
+
# ============================================================
|
| 20 |
+
echo "=== Step 1: Training on 2048 data (deterministic) ==="
|
| 21 |
+
python -m s23dr_2026_example.train $BASE \
|
| 22 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train \
|
| 23 |
+
--seq-len 2048 \
|
| 24 |
+
--lr 3e-4 \
|
| 25 |
+
--batch-size 32 \
|
| 26 |
+
--steps 125000 \
|
| 27 |
+
--deterministic \
|
| 28 |
+
--out-dir "$OUT_DIR"
|
| 29 |
+
|
| 30 |
+
STEP1_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 31 |
+
echo "Step 1 complete: $STEP1_DIR"
|
| 32 |
+
|
| 33 |
+
# ============================================================
|
| 34 |
+
# Step 2: Finetune on 4096-point data
|
| 35 |
+
# ============================================================
|
| 36 |
+
echo "=== Step 2: Finetuning on 4096 data (deterministic) ==="
|
| 37 |
+
python -m s23dr_2026_example.train $BASE \
|
| 38 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
|
| 39 |
+
--resume "$STEP1_DIR/checkpoints/step125000.pt" \
|
| 40 |
+
--seq-len 4096 \
|
| 41 |
+
--lr 3e-5 \
|
| 42 |
+
--batch-size 64 \
|
| 43 |
+
--steps 135000 \
|
| 44 |
+
--deterministic \
|
| 45 |
+
--out-dir "$OUT_DIR"
|
| 46 |
+
|
| 47 |
+
STEP2_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 48 |
+
echo "Step 2 complete: $STEP2_DIR"
|
| 49 |
+
|
| 50 |
+
# ============================================================
|
| 51 |
+
# Step 3: Cooldown with endpoint loss
|
| 52 |
+
# ============================================================
|
| 53 |
+
echo "=== Step 3: Cooldown with endpoint loss (deterministic) ==="
|
| 54 |
+
python -m s23dr_2026_example.train $BASE \
|
| 55 |
+
--cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
|
| 56 |
+
--resume "$STEP2_DIR/checkpoints/step135000.pt" \
|
| 57 |
+
--seq-len 4096 \
|
| 58 |
+
--lr 3e-5 \
|
| 59 |
+
--batch-size 64 \
|
| 60 |
+
--endpoint-weight 0.1 \
|
| 61 |
+
--cooldown-start 150000 \
|
| 62 |
+
--cooldown-steps 20000 \
|
| 63 |
+
--steps 170000 \
|
| 64 |
+
--deterministic \
|
| 65 |
+
--out-dir "$OUT_DIR"
|
| 66 |
+
|
| 67 |
+
STEP3_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
|
| 68 |
+
echo "Step 3 complete: $STEP3_DIR"
|
| 69 |
+
echo ""
|
| 70 |
+
echo "Final checkpoint: $STEP3_DIR/checkpoints/final.pt"
|
| 71 |
+
echo "Copy to checkpoint.pt for submission."
|
s23dr_2026_example/attention.py
CHANGED
|
@@ -139,88 +139,3 @@ class FeedForward(nn.Module):
|
|
| 139 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 140 |
x = self.linear1(x)
|
| 141 |
return self.linear2(self.activation(x))
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
# =============================================================================
|
| 145 |
-
# Custom Transformer Block
|
| 146 |
-
# =============================================================================
|
| 147 |
-
|
| 148 |
-
class TransformerBlock(nn.Module):
|
| 149 |
-
"""
|
| 150 |
-
Single transformer block combining:
|
| 151 |
-
- multi-head SDPA (non-causal)
|
| 152 |
-
- layernorm + residual
|
| 153 |
-
- feed-forward MLP + residual
|
| 154 |
-
"""
|
| 155 |
-
def __init__(
|
| 156 |
-
self,
|
| 157 |
-
d_model: int,
|
| 158 |
-
num_heads: int,
|
| 159 |
-
dim_ff: int,
|
| 160 |
-
dropout: float = 0.0,
|
| 161 |
-
activation: str = "gelu",
|
| 162 |
-
kv_heads: int = None,
|
| 163 |
-
):
|
| 164 |
-
super().__init__()
|
| 165 |
-
self.norm1 = nn.LayerNorm(d_model)
|
| 166 |
-
self.norm2 = nn.LayerNorm(d_model)
|
| 167 |
-
|
| 168 |
-
self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads)
|
| 169 |
-
self.dropout1 = nn.Dropout(dropout)
|
| 170 |
-
self.ffn = FeedForward(d_model, dim_ff, activation=activation)
|
| 171 |
-
self.dropout2 = nn.Dropout(dropout)
|
| 172 |
-
|
| 173 |
-
def forward(
|
| 174 |
-
self,
|
| 175 |
-
x: torch.Tensor,
|
| 176 |
-
memory: torch.Tensor,
|
| 177 |
-
memory_key_padding_mask: torch.Tensor | None = None,
|
| 178 |
-
) -> torch.Tensor:
|
| 179 |
-
res = x
|
| 180 |
-
x = self.norm1(x)
|
| 181 |
-
x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
|
| 182 |
-
x = res + self.dropout1(x)
|
| 183 |
-
|
| 184 |
-
res = x
|
| 185 |
-
x = self.norm2(x)
|
| 186 |
-
x = self.ffn(x)
|
| 187 |
-
return res + self.dropout2(x)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
class TransformerDecoderSets(nn.Module):
|
| 191 |
-
"""
|
| 192 |
-
A stack of TransformerBlock layers for set-to-set
|
| 193 |
-
modeling without causal masks.
|
| 194 |
-
"""
|
| 195 |
-
def __init__(
|
| 196 |
-
self,
|
| 197 |
-
d_model: int,
|
| 198 |
-
num_heads: int,
|
| 199 |
-
dim_ff: int,
|
| 200 |
-
num_layers: int,
|
| 201 |
-
dropout: float = 0.0,
|
| 202 |
-
activation: str = "gelu",
|
| 203 |
-
kv_heads: int = None,
|
| 204 |
-
):
|
| 205 |
-
super().__init__()
|
| 206 |
-
self.layers = nn.ModuleList([
|
| 207 |
-
TransformerBlock(
|
| 208 |
-
d_model,
|
| 209 |
-
num_heads,
|
| 210 |
-
dim_ff,
|
| 211 |
-
dropout=dropout,
|
| 212 |
-
activation=activation,
|
| 213 |
-
kv_heads=kv_heads,
|
| 214 |
-
)
|
| 215 |
-
for _ in range(num_layers)
|
| 216 |
-
])
|
| 217 |
-
|
| 218 |
-
def forward(
|
| 219 |
-
self,
|
| 220 |
-
tgt: torch.Tensor,
|
| 221 |
-
memory: torch.Tensor,
|
| 222 |
-
memory_key_padding_mask: torch.Tensor | None = None,
|
| 223 |
-
) -> torch.Tensor:
|
| 224 |
-
for layer in self.layers:
|
| 225 |
-
tgt = layer(tgt, memory, memory_key_padding_mask=memory_key_padding_mask)
|
| 226 |
-
return tgt
|
|
|
|
| 139 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 140 |
x = self.linear1(x)
|
| 141 |
return self.linear2(self.activation(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/cache_scenes.py
CHANGED
|
@@ -23,24 +23,9 @@ Cache format per file (.pt):
|
|
| 23 |
"""
|
| 24 |
from __future__ import annotations
|
| 25 |
|
| 26 |
-
import sys
|
| 27 |
-
from pathlib import Path as _Path
|
| 28 |
-
if __package__ is None or __package__ == "":
|
| 29 |
-
_here = _Path(__file__).resolve().parent
|
| 30 |
-
if str(_here.parent) not in sys.path:
|
| 31 |
-
sys.path.insert(0, str(_here.parent))
|
| 32 |
-
__package__ = _here.name
|
| 33 |
-
|
| 34 |
-
import argparse
|
| 35 |
-
import time
|
| 36 |
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 37 |
-
from pathlib import Path
|
| 38 |
-
|
| 39 |
import numpy as np
|
| 40 |
-
import torch
|
| 41 |
|
| 42 |
from .point_fusion import (
|
| 43 |
-
FuserConfig, build_compact_scene,
|
| 44 |
GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
|
| 45 |
)
|
| 46 |
|
|
@@ -191,183 +176,3 @@ def _compute_smart_center_scale(xyz, source, mad_k=2.5, percentile=95.0,
|
|
| 191 |
return center.astype(np.float32), np.float32(scale)
|
| 192 |
|
| 193 |
|
| 194 |
-
def _process_one(sample, cfg):
|
| 195 |
-
"""Process a single HF sample into a cache dict. Returns (order_id, dict) or None."""
|
| 196 |
-
rng = np.random.RandomState() # worker-local rng
|
| 197 |
-
|
| 198 |
-
n_edges = len(sample.get("wf_edges", []))
|
| 199 |
-
if n_edges == 0 or n_edges > 64:
|
| 200 |
-
return None
|
| 201 |
-
|
| 202 |
-
scene = build_compact_scene(sample, cfg, rng=rng)
|
| 203 |
-
if scene is None:
|
| 204 |
-
return None
|
| 205 |
-
|
| 206 |
-
gt_v = scene.get("gt_vertices")
|
| 207 |
-
gt_e = scene.get("gt_edges")
|
| 208 |
-
if gt_v is None or gt_e is None or len(gt_e) == 0:
|
| 209 |
-
return None
|
| 210 |
-
|
| 211 |
-
xyz = scene["xyz"]
|
| 212 |
-
source = scene["source"]
|
| 213 |
-
visible_src = scene["visible_src"]
|
| 214 |
-
visible_id = scene["visible_id"]
|
| 215 |
-
behind_id = scene["behind_gest_id"]
|
| 216 |
-
|
| 217 |
-
group_id, class_id = _compute_group_and_class(
|
| 218 |
-
visible_src, visible_id, behind_id, source
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
center, scale = _compute_smart_center_scale(xyz, source)
|
| 222 |
-
|
| 223 |
-
order_id = sample.get("order_id", "unknown")
|
| 224 |
-
|
| 225 |
-
return order_id, {
|
| 226 |
-
"xyz": xyz.astype(np.float32),
|
| 227 |
-
"source": source.astype(np.uint8),
|
| 228 |
-
"group_id": group_id,
|
| 229 |
-
"class_id": class_id,
|
| 230 |
-
"behind_gest_id": behind_id.astype(np.int16),
|
| 231 |
-
"visible_src": visible_src.astype(np.uint8),
|
| 232 |
-
"visible_id": visible_id.astype(np.int16),
|
| 233 |
-
"n_views_voted": scene["n_views_voted"],
|
| 234 |
-
"vote_frac": scene["vote_frac"],
|
| 235 |
-
"center": center,
|
| 236 |
-
"scale": scale,
|
| 237 |
-
"gt_vertices": gt_v.astype(np.float32),
|
| 238 |
-
"gt_edges": gt_e.astype(np.int32),
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def main():
|
| 243 |
-
p = argparse.ArgumentParser(description="Cache compact scenes from HoHo22k")
|
| 244 |
-
g = p.add_mutually_exclusive_group(required=True)
|
| 245 |
-
g.add_argument("--data-dir", help="Local dir with shards")
|
| 246 |
-
g.add_argument("--streaming", action="store_true", help="Stream from HuggingFace")
|
| 247 |
-
p.add_argument("--out-dir", required=True, help="Output directory for .pt files")
|
| 248 |
-
p.add_argument("--limit", type=int, default=0)
|
| 249 |
-
p.add_argument("--depth-per-view", type=int, default=8000)
|
| 250 |
-
p.add_argument("--workers", type=int, default=0,
|
| 251 |
-
help="Parallel workers (0=sequential)")
|
| 252 |
-
p.add_argument("--skip-existing", action="store_true",
|
| 253 |
-
help="Skip samples whose .pt already exists in out-dir")
|
| 254 |
-
p.add_argument("--shard-start", type=int, default=0,
|
| 255 |
-
help="First shard index (for parallel launches)")
|
| 256 |
-
p.add_argument("--shard-stride", type=int, default=1,
|
| 257 |
-
help="Stride between shards (e.g. 8 means take every 8th shard)")
|
| 258 |
-
args = p.parse_args()
|
| 259 |
-
|
| 260 |
-
out_dir = Path(args.out_dir)
|
| 261 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 262 |
-
existing_ids = set(p.stem for p in out_dir.glob("*.pt")) if args.skip_existing else set()
|
| 263 |
-
|
| 264 |
-
# Load dataset
|
| 265 |
-
from datasets import load_dataset
|
| 266 |
-
if args.streaming:
|
| 267 |
-
ds = load_dataset(
|
| 268 |
-
"usm3d/hoho22k_2026_trainval",
|
| 269 |
-
streaming=True, trust_remote_code=True, split="train",
|
| 270 |
-
)
|
| 271 |
-
else:
|
| 272 |
-
data_root = Path(args.data_dir).resolve()
|
| 273 |
-
tars = []
|
| 274 |
-
for candidate in [data_root / "data" / "train", data_root / "train", data_root]:
|
| 275 |
-
if candidate.exists():
|
| 276 |
-
tars = sorted(str(p) for p in candidate.glob("*.tar"))
|
| 277 |
-
if tars:
|
| 278 |
-
break
|
| 279 |
-
loader = None
|
| 280 |
-
for c in [data_root / "hoho22k_2026_trainval.py"]:
|
| 281 |
-
if c.exists():
|
| 282 |
-
loader = c
|
| 283 |
-
break
|
| 284 |
-
if loader is None:
|
| 285 |
-
found = list(data_root.rglob("hoho22k_2026_trainval.py"))
|
| 286 |
-
loader = found[0] if found else None
|
| 287 |
-
if loader is None:
|
| 288 |
-
raise FileNotFoundError("Cannot find loader script")
|
| 289 |
-
# Shard-level parallelism: each process handles a slice of tars
|
| 290 |
-
if args.shard_stride > 1:
|
| 291 |
-
tars = tars[args.shard_start::args.shard_stride]
|
| 292 |
-
print(f"Shard slice: start={args.shard_start} stride={args.shard_stride} -> {len(tars)} shards")
|
| 293 |
-
ds = load_dataset(str(loader), data_files={"train": tars},
|
| 294 |
-
streaming=True, trust_remote_code=True, split="train")
|
| 295 |
-
|
| 296 |
-
cfg = FuserConfig(depth_points_per_view=args.depth_per_view)
|
| 297 |
-
|
| 298 |
-
saved = 0
|
| 299 |
-
skipped = 0
|
| 300 |
-
t_start = time.perf_counter()
|
| 301 |
-
|
| 302 |
-
if args.workers > 0:
|
| 303 |
-
# Parallel: collect samples into batches, process in worker pool
|
| 304 |
-
# Note: HF streaming datasets can't be shared across workers, so we
|
| 305 |
-
# iterate in the main thread and dispatch processing to workers.
|
| 306 |
-
with ProcessPoolExecutor(max_workers=args.workers) as pool:
|
| 307 |
-
futures = {}
|
| 308 |
-
for i, sample in enumerate(ds):
|
| 309 |
-
if args.limit > 0 and i >= args.limit:
|
| 310 |
-
break
|
| 311 |
-
oid = sample.get("order_id", "unknown")
|
| 312 |
-
if oid in existing_ids:
|
| 313 |
-
skipped += 1
|
| 314 |
-
continue
|
| 315 |
-
future = pool.submit(_process_one, sample, cfg)
|
| 316 |
-
futures[future] = i
|
| 317 |
-
|
| 318 |
-
# Drain completed futures to bound memory
|
| 319 |
-
if len(futures) >= args.workers * 4:
|
| 320 |
-
done = [f for f in futures if f.done()]
|
| 321 |
-
for f in done:
|
| 322 |
-
result = f.result()
|
| 323 |
-
del futures[f]
|
| 324 |
-
if result is None:
|
| 325 |
-
skipped += 1
|
| 326 |
-
continue
|
| 327 |
-
order_id, data = result
|
| 328 |
-
torch.save(data, out_dir / f"{order_id}.pt")
|
| 329 |
-
saved += 1
|
| 330 |
-
if saved % 50 == 0:
|
| 331 |
-
elapsed = time.perf_counter() - t_start
|
| 332 |
-
print(f"Saved {saved} (skipped {skipped}) "
|
| 333 |
-
f"[{saved / elapsed:.1f} samples/s]")
|
| 334 |
-
|
| 335 |
-
# Drain remaining
|
| 336 |
-
for f in as_completed(futures):
|
| 337 |
-
result = f.result()
|
| 338 |
-
if result is None:
|
| 339 |
-
skipped += 1
|
| 340 |
-
continue
|
| 341 |
-
order_id, data = result
|
| 342 |
-
torch.save(data, out_dir / f"{order_id}.pt")
|
| 343 |
-
saved += 1
|
| 344 |
-
else:
|
| 345 |
-
# Sequential
|
| 346 |
-
for i, sample in enumerate(ds):
|
| 347 |
-
if args.limit > 0 and i >= args.limit:
|
| 348 |
-
break
|
| 349 |
-
oid = sample.get("order_id", "unknown")
|
| 350 |
-
if oid in existing_ids:
|
| 351 |
-
skipped += 1
|
| 352 |
-
continue
|
| 353 |
-
|
| 354 |
-
result = _process_one(sample, cfg)
|
| 355 |
-
if result is None:
|
| 356 |
-
skipped += 1
|
| 357 |
-
continue
|
| 358 |
-
order_id, data = result
|
| 359 |
-
torch.save(data, out_dir / f"{order_id}.pt")
|
| 360 |
-
saved += 1
|
| 361 |
-
|
| 362 |
-
if saved % 50 == 0:
|
| 363 |
-
elapsed = time.perf_counter() - t_start
|
| 364 |
-
print(f"Saved {saved} (skipped {skipped}) "
|
| 365 |
-
f"[{saved / elapsed:.1f} samples/s]")
|
| 366 |
-
|
| 367 |
-
elapsed = time.perf_counter() - t_start
|
| 368 |
-
print(f"Done. Saved {saved}, skipped {skipped} in {elapsed:.0f}s "
|
| 369 |
-
f"({saved / elapsed:.1f} samples/s)")
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
if __name__ == "__main__":
|
| 373 |
-
main()
|
|
|
|
| 23 |
"""
|
| 24 |
from __future__ import annotations
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
import numpy as np
|
|
|
|
| 27 |
|
| 28 |
from .point_fusion import (
|
|
|
|
| 29 |
GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
|
| 30 |
)
|
| 31 |
|
|
|
|
| 176 |
return center.astype(np.float32), np.float32(scale)
|
| 177 |
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/color_mappings.py
CHANGED
|
@@ -181,29 +181,3 @@ ade20k_color_mapping = {
|
|
| 181 |
'clock': (102, 255, 0),
|
| 182 |
'flag': (92, 0, 255),
|
| 183 |
}
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
EDGE_CLASSES = {'cornice_return': 0,
|
| 187 |
-
'cornice_strip': 1,
|
| 188 |
-
'eave': 2,
|
| 189 |
-
'flashing': 3,
|
| 190 |
-
'hip': 4,
|
| 191 |
-
'rake': 5,
|
| 192 |
-
'ridge': 6,
|
| 193 |
-
'step_flashing': 7,
|
| 194 |
-
'transition_line': 8,
|
| 195 |
-
'valley': 9}
|
| 196 |
-
EDGE_CLASSES_BY_ID = {v: k for k, v in EDGE_CLASSES.items()}
|
| 197 |
-
|
| 198 |
-
edge_color_mapping = {
|
| 199 |
-
'cornice_return': (215, 62, 138),
|
| 200 |
-
'cornice_strip': (235, 88, 48),
|
| 201 |
-
'eave': (54, 243, 63),
|
| 202 |
-
"flashing": (162, 162, 32),
|
| 203 |
-
'hip': (8, 89, 52),
|
| 204 |
-
'rake': (13, 94, 47),
|
| 205 |
-
'ridge': (214, 251, 248),
|
| 206 |
-
"step_flashing": (169, 255, 219),
|
| 207 |
-
'transition_line': (200,0,50),
|
| 208 |
-
'valley': (85, 27, 65),
|
| 209 |
-
}
|
|
|
|
| 181 |
'clock': (102, 255, 0),
|
| 182 |
'flag': (92, 0, 255),
|
| 183 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/data.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""Data loading for pre-sampled HF datasets.
|
| 2 |
|
| 3 |
-
Expects pre-sampled npz blobs with xyz_norm
|
|
|
|
| 4 |
Use make_sampled_cache.py to produce these from full point clouds.
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
|
@@ -12,7 +13,7 @@ import torch
|
|
| 12 |
|
| 13 |
from .tokenizer import EdgeDepthSequenceConfig
|
| 14 |
|
| 15 |
-
# Default token budget (
|
| 16 |
SEQ_LEN = 2048
|
| 17 |
COLMAP_POINTS = 1536
|
| 18 |
DEPTH_POINTS = 512
|
|
@@ -130,9 +131,6 @@ def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False)
|
|
| 130 |
result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
|
| 131 |
if "vote_frac" in d:
|
| 132 |
result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
|
| 133 |
-
if "gt_edge_classes" in d:
|
| 134 |
-
result["gt_edge_classes"] = torch.as_tensor(
|
| 135 |
-
np.asarray(d["gt_edge_classes"], dtype=np.int64), dtype=torch.long)
|
| 136 |
return result
|
| 137 |
|
| 138 |
|
|
@@ -161,14 +159,6 @@ def collate(batch):
|
|
| 161 |
f"Field '{field}' present in some batch samples but missing in "
|
| 162 |
f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
|
| 163 |
out[field] = torch.stack([d[field] for d in batch])
|
| 164 |
-
# gt_edge_classes: variable length per sample (like gt_segments), keep as list
|
| 165 |
-
if any("gt_edge_classes" in d for d in batch):
|
| 166 |
-
missing = [i for i, d in enumerate(batch) if "gt_edge_classes" not in d]
|
| 167 |
-
if missing:
|
| 168 |
-
raise KeyError(
|
| 169 |
-
f"Field 'gt_edge_classes' present in some batch samples but missing in "
|
| 170 |
-
f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
|
| 171 |
-
out["gt_edge_classes"] = [d["gt_edge_classes"] for d in batch]
|
| 172 |
return out
|
| 173 |
|
| 174 |
|
|
|
|
| 1 |
"""Data loading for pre-sampled HF datasets.
|
| 2 |
|
| 3 |
+
Expects pre-sampled npz blobs with xyz_norm (not full PCD).
|
| 4 |
+
Supports both 2048-point and 4096-point datasets.
|
| 5 |
Use make_sampled_cache.py to produce these from full point clouds.
|
| 6 |
"""
|
| 7 |
from __future__ import annotations
|
|
|
|
| 13 |
|
| 14 |
from .tokenizer import EdgeDepthSequenceConfig
|
| 15 |
|
| 16 |
+
# Default token budget (for 2048-point datasets; 4096 uses 3072/1024)
|
| 17 |
SEQ_LEN = 2048
|
| 18 |
COLMAP_POINTS = 1536
|
| 19 |
DEPTH_POINTS = 512
|
|
|
|
| 131 |
result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
|
| 132 |
if "vote_frac" in d:
|
| 133 |
result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
| 134 |
return result
|
| 135 |
|
| 136 |
|
|
|
|
| 159 |
f"Field '{field}' present in some batch samples but missing in "
|
| 160 |
f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
|
| 161 |
out[field] = torch.stack([d[field] for d in batch])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
return out
|
| 163 |
|
| 164 |
|
s23dr_2026_example/losses.py
CHANGED
|
@@ -5,7 +5,6 @@ import torch
|
|
| 5 |
|
| 6 |
from .varifold import varifold_loss_batch
|
| 7 |
from .sinkhorn import batched_sinkhorn_loss
|
| 8 |
-
from .soft_hss_loss import batched_sinkhorn_vertex_f1, batched_soft_hss_v2
|
| 9 |
|
| 10 |
# Varifold config
|
| 11 |
VARIANT = "simpson3"
|
|
@@ -18,21 +17,12 @@ VARIFOLD_CROSS_ONLY = False # Set to True to drop self-energy (avoids O(S^2) bl
|
|
| 18 |
SINKHORN_EPS = 0.05
|
| 19 |
SINKHORN_ITERS = 10
|
| 20 |
|
| 21 |
-
# Distance thresholds in meters (divided by per-scene scale at runtime)
|
| 22 |
-
VERTEX_THRESH_M = 0.5 # vertex match threshold (mirrors real HSS)
|
| 23 |
-
TUBE_RADIUS_M = 0.5 # tube IoU radius (mirrors real HSS)
|
| 24 |
-
|
| 25 |
# Sinkhorn dustbin cost: controls the OT "not matching" penalty.
|
| 26 |
# Like tau, this is an OT behavior parameter, NOT a physical distance.
|
| 27 |
# Must be comparable to typical matching costs in normalized space (~0.1).
|
| 28 |
# Do NOT divide by scale.
|
| 29 |
SINKHORN_DUSTBIN = 0.1
|
| 30 |
|
| 31 |
-
# Sigmoid temperature: controls gradient smoothness, NOT a distance threshold.
|
| 32 |
-
# Must stay large enough in normalized space to provide useful gradients.
|
| 33 |
-
# Do NOT divide by scale (unlike the thresholds above).
|
| 34 |
-
SIGMOID_TAU = 0.05
|
| 35 |
-
|
| 36 |
MAX_GT = 64 # fixed pad size for compile-friendly shapes
|
| 37 |
|
| 38 |
# Precomputed constants (created once on first call)
|
|
@@ -65,7 +55,7 @@ def pad_gt_fixed(gt_list, device, dtype):
|
|
| 65 |
|
| 66 |
|
| 67 |
def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
|
| 68 |
-
sigmas, alphas, varifold_w
|
| 69 |
"""Pure tensor loss -- no Python control flow, no boolean indexing."""
|
| 70 |
has_gt = (gt_lengths > 0).float()
|
| 71 |
|
|
@@ -78,77 +68,36 @@ def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
|
|
| 78 |
v = loss_batch / gt_lengths.clamp(min=1.0)
|
| 79 |
v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
pred_segments, gt_pad, gt_mask, thresh=thresh, tau=SIGMOID_TAU)
|
| 84 |
-
f1 = (f1 * has_gt).sum() / has_gt.sum().clamp(min=1.0)
|
| 85 |
-
|
| 86 |
-
total = varifold_w * v + vertex_f1_w * f1
|
| 87 |
-
return total, v, f1
|
| 88 |
|
| 89 |
|
| 90 |
# Will be replaced with compiled version on CUDA
|
| 91 |
_loss_fn = _loss_inner
|
| 92 |
|
| 93 |
|
| 94 |
-
def _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales):
|
| 95 |
-
"""Auxiliary BCE loss: train conf to predict whether each segment matches GT.
|
| 96 |
-
|
| 97 |
-
Computes per-segment min-distance to GT, creates soft match target via
|
| 98 |
-
sigmoid thresholding, and returns BCE(sigmoid(conf), target).
|
| 99 |
-
"""
|
| 100 |
-
B, S = pred_segments.shape[:2]
|
| 101 |
-
# Decoupled cost: midpoint + direction + length (same as sinkhorn)
|
| 102 |
-
p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
|
| 103 |
-
g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
|
| 104 |
-
mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
|
| 105 |
-
mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
|
| 106 |
-
d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
|
| 107 |
-
len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
|
| 108 |
-
len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
|
| 109 |
-
dir_p = half_p / len_p
|
| 110 |
-
dir_g = half_g / len_g
|
| 111 |
-
cos_angle = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
|
| 112 |
-
d_dir = 1.0 - cos_angle.abs()
|
| 113 |
-
d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
|
| 114 |
-
cost = d_mid + d_dir + d_len # [B, S, M]
|
| 115 |
-
|
| 116 |
-
# Mask invalid GT with high cost
|
| 117 |
-
cost = torch.where(gt_mask.unsqueeze(1), cost, cost.new_tensor(1e6))
|
| 118 |
-
min_dist = cost.min(dim=2).values # [B, S]
|
| 119 |
-
|
| 120 |
-
# Soft target: sigmoid((thresh - dist) / tau), in normalized space
|
| 121 |
-
thresh = VERTEX_THRESH_M / scales # [B]
|
| 122 |
-
target = torch.sigmoid((thresh[:, None] - min_dist) / SIGMOID_TAU)
|
| 123 |
-
|
| 124 |
-
return torch.nn.functional.binary_cross_entropy_with_logits(
|
| 125 |
-
conf_logits, target.detach(), reduction="mean")
|
| 126 |
-
|
| 127 |
-
|
| 128 |
def compute_loss(pred_segments, gt_list, scales, device,
|
| 129 |
-
varifold_w, sinkhorn_w,
|
| 130 |
endpoint_w=0.0,
|
| 131 |
-
conf_logits=None, conf_weight=0.0, conf_mode="
|
| 132 |
sinkhorn_eps=None, sinkhorn_iters=None,
|
| 133 |
sinkhorn_dustbin=None, conf_clamp_min=None):
|
| 134 |
"""Combined loss with fixed-size GT padding.
|
| 135 |
|
| 136 |
-
conf_mode: "
|
| 137 |
"""
|
| 138 |
if conf_logits is not None and conf_clamp_min is not None:
|
| 139 |
conf_logits = conf_logits.clamp(min=conf_clamp_min)
|
| 140 |
gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
|
| 141 |
c = _get_loss_constants(device, pred_segments.dtype)
|
| 142 |
|
| 143 |
-
total, v
|
| 144 |
pred_segments, gt_pad, gt_mask, gt_lengths, scales,
|
| 145 |
-
c["sigmas"], c["alphas"], varifold_w
|
| 146 |
|
| 147 |
terms = {}
|
| 148 |
if varifold_w > 0:
|
| 149 |
terms["varifold"] = v.detach()
|
| 150 |
-
if vertex_f1_w > 0:
|
| 151 |
-
terms["vertex_f1"] = f1.detach()
|
| 152 |
|
| 153 |
if sinkhorn_w > 0:
|
| 154 |
has_gt = (gt_lengths > 0).float()
|
|
@@ -171,28 +120,8 @@ def compute_loss(pred_segments, gt_list, scales, device,
|
|
| 171 |
total = total + sinkhorn_w * s
|
| 172 |
terms["sinkhorn"] = s.detach()
|
| 173 |
|
| 174 |
-
if soft_hss_w > 0:
|
| 175 |
-
has_gt = (gt_lengths > 0).float()
|
| 176 |
-
vert_thresh = VERTEX_THRESH_M / scales
|
| 177 |
-
edge_thresh = TUBE_RADIUS_M / scales
|
| 178 |
-
hss_loss = batched_soft_hss_v2(
|
| 179 |
-
pred_segments, gt_pad, gt_mask,
|
| 180 |
-
vert_thresh=vert_thresh, edge_thresh=edge_thresh, tau=SIGMOID_TAU)
|
| 181 |
-
hs = (hss_loss * has_gt).sum() / has_gt.sum().clamp(min=1.0)
|
| 182 |
-
total = total + soft_hss_w * hs
|
| 183 |
-
terms["soft_hss"] = hs.detach()
|
| 184 |
-
|
| 185 |
if conf_logits is not None and conf_weight > 0:
|
| 186 |
-
if conf_mode
|
| 187 |
-
# Explicit BCE supervision from nearest-GT distances
|
| 188 |
-
cl = _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales)
|
| 189 |
-
total = total + conf_weight * cl
|
| 190 |
-
terms["conf"] = cl.detach()
|
| 191 |
-
elif conf_mode in ("sinkhorn", "sinkhorn_detach"):
|
| 192 |
-
# Conf trained through sinkhorn transport gradients (via pred_mass).
|
| 193 |
-
# sinkhorn_detach: pred_mass uses detached conf, so OT can't push conf negative.
|
| 194 |
-
# Add count regularizer to prevent all-zero conf collapse.
|
| 195 |
-
# Normalized by S so magnitude doesn't depend on segment count.
|
| 196 |
conf_w = torch.sigmoid(conf_logits)
|
| 197 |
S = conf_logits.shape[1]
|
| 198 |
gt_counts = gt_mask.sum(dim=1).float()
|
|
@@ -200,29 +129,6 @@ def compute_loss(pred_segments, gt_list, scales, device,
|
|
| 200 |
reg = (((conf_sum - gt_counts) / S) ** 2).mean()
|
| 201 |
total = total + conf_weight * reg
|
| 202 |
terms["conf_reg"] = reg.detach()
|
| 203 |
-
elif conf_mode == "varifold":
|
| 204 |
-
# Conf-weighted varifold: weight each pred segment's contribution
|
| 205 |
-
# by sigmoid(conf). Low-conf segments contribute less to the loss.
|
| 206 |
-
# Needs regularizer to prevent all-zero conf collapse.
|
| 207 |
-
has_gt = (gt_lengths > 0).float()
|
| 208 |
-
conf_w = torch.sigmoid(conf_logits) # [B, S]
|
| 209 |
-
sigmas_eff = c["sigmas"] / scales[:, None]
|
| 210 |
-
vf_conf = varifold_loss_batch(
|
| 211 |
-
pred_segments, gt_pad, gt_mask=gt_mask,
|
| 212 |
-
variant=VARIANT, sigmas=sigmas_eff, alpha=c["alphas"],
|
| 213 |
-
len_pow=LEN_POW, pred_weights=conf_w,
|
| 214 |
-
)
|
| 215 |
-
vc = (vf_conf / gt_lengths.clamp(min=1.0))
|
| 216 |
-
vc = (vc * has_gt).sum() / has_gt.sum().clamp(min=1.0)
|
| 217 |
-
# Regularizer: penalize total conf being far from n_gt
|
| 218 |
-
# Normalized by S so magnitude doesn't depend on segment count
|
| 219 |
-
S = conf_logits.shape[1]
|
| 220 |
-
gt_counts = gt_mask.sum(dim=1).float() # [B]
|
| 221 |
-
conf_sum = conf_w.sum(dim=1) # [B]
|
| 222 |
-
reg = (((conf_sum - gt_counts) / S) ** 2).mean()
|
| 223 |
-
total = total + conf_weight * vc + 0.01 * reg
|
| 224 |
-
terms["conf_vf"] = vc.detach()
|
| 225 |
-
terms["conf_reg"] = reg.detach()
|
| 226 |
else:
|
| 227 |
raise ValueError(f"Unknown conf_mode: {conf_mode}")
|
| 228 |
|
|
@@ -234,14 +140,12 @@ def compute_loss(pred_segments, gt_list, scales, device,
|
|
| 234 |
B, S = pred_segments.shape[:2]
|
| 235 |
M = gt_pad.shape[1]
|
| 236 |
|
| 237 |
-
# Compute hard assignment via sinkhorn (detached
|
| 238 |
with torch.no_grad():
|
| 239 |
pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
|
| 240 |
sink_loss_for_assign = batched_sinkhorn_loss(
|
| 241 |
pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
|
| 242 |
pred_mass=pred_mass_ep)
|
| 243 |
-
# Re-run sinkhorn to get transport matrix for assignment
|
| 244 |
-
# (reuse the cost computation from batched_sinkhorn_loss internals)
|
| 245 |
p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
|
| 246 |
g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
|
| 247 |
mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
|
|
|
|
| 5 |
|
| 6 |
from .varifold import varifold_loss_batch
|
| 7 |
from .sinkhorn import batched_sinkhorn_loss
|
|
|
|
| 8 |
|
| 9 |
# Varifold config
|
| 10 |
VARIANT = "simpson3"
|
|
|
|
| 17 |
SINKHORN_EPS = 0.05
|
| 18 |
SINKHORN_ITERS = 10
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Sinkhorn dustbin cost: controls the OT "not matching" penalty.
|
| 21 |
# Like tau, this is an OT behavior parameter, NOT a physical distance.
|
| 22 |
# Must be comparable to typical matching costs in normalized space (~0.1).
|
| 23 |
# Do NOT divide by scale.
|
| 24 |
SINKHORN_DUSTBIN = 0.1
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
MAX_GT = 64 # fixed pad size for compile-friendly shapes
|
| 27 |
|
| 28 |
# Precomputed constants (created once on first call)
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
|
| 58 |
+
sigmas, alphas, varifold_w):
|
| 59 |
"""Pure tensor loss -- no Python control flow, no boolean indexing."""
|
| 60 |
has_gt = (gt_lengths > 0).float()
|
| 61 |
|
|
|
|
| 68 |
v = loss_batch / gt_lengths.clamp(min=1.0)
|
| 69 |
v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)
|
| 70 |
|
| 71 |
+
total = varifold_w * v
|
| 72 |
+
return total, v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
# Will be replaced with compiled version on CUDA
|
| 76 |
_loss_fn = _loss_inner
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def compute_loss(pred_segments, gt_list, scales, device,
|
| 80 |
+
varifold_w, sinkhorn_w,
|
| 81 |
endpoint_w=0.0,
|
| 82 |
+
conf_logits=None, conf_weight=0.0, conf_mode="sinkhorn",
|
| 83 |
sinkhorn_eps=None, sinkhorn_iters=None,
|
| 84 |
sinkhorn_dustbin=None, conf_clamp_min=None):
|
| 85 |
"""Combined loss with fixed-size GT padding.
|
| 86 |
|
| 87 |
+
conf_mode: "sinkhorn" = conf-weighted sinkhorn, "sinkhorn_detach" = detached conf.
|
| 88 |
"""
|
| 89 |
if conf_logits is not None and conf_clamp_min is not None:
|
| 90 |
conf_logits = conf_logits.clamp(min=conf_clamp_min)
|
| 91 |
gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
|
| 92 |
c = _get_loss_constants(device, pred_segments.dtype)
|
| 93 |
|
| 94 |
+
total, v = _loss_fn(
|
| 95 |
pred_segments, gt_pad, gt_mask, gt_lengths, scales,
|
| 96 |
+
c["sigmas"], c["alphas"], varifold_w)
|
| 97 |
|
| 98 |
terms = {}
|
| 99 |
if varifold_w > 0:
|
| 100 |
terms["varifold"] = v.detach()
|
|
|
|
|
|
|
| 101 |
|
| 102 |
if sinkhorn_w > 0:
|
| 103 |
has_gt = (gt_lengths > 0).float()
|
|
|
|
| 120 |
total = total + sinkhorn_w * s
|
| 121 |
terms["sinkhorn"] = s.detach()
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
if conf_logits is not None and conf_weight > 0:
|
| 124 |
+
if conf_mode in ("sinkhorn", "sinkhorn_detach"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
conf_w = torch.sigmoid(conf_logits)
|
| 126 |
S = conf_logits.shape[1]
|
| 127 |
gt_counts = gt_mask.sum(dim=1).float()
|
|
|
|
| 129 |
reg = (((conf_sum - gt_counts) / S) ** 2).mean()
|
| 130 |
total = total + conf_weight * reg
|
| 131 |
terms["conf_reg"] = reg.detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
else:
|
| 133 |
raise ValueError(f"Unknown conf_mode: {conf_mode}")
|
| 134 |
|
|
|
|
| 140 |
B, S = pred_segments.shape[:2]
|
| 141 |
M = gt_pad.shape[1]
|
| 142 |
|
| 143 |
+
# Compute hard assignment via sinkhorn (detached -- matching is not trained)
|
| 144 |
with torch.no_grad():
|
| 145 |
pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
|
| 146 |
sink_loss_for_assign = batched_sinkhorn_loss(
|
| 147 |
pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
|
| 148 |
pred_mass=pred_mass_ep)
|
|
|
|
|
|
|
| 149 |
p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
|
| 150 |
g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
|
| 151 |
mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
|
s23dr_2026_example/make_sampled_cache.py
CHANGED
|
@@ -24,21 +24,7 @@ the same points. Fine for now; better augmentation can be added later.
|
|
| 24 |
"""
|
| 25 |
from __future__ import annotations
|
| 26 |
|
| 27 |
-
import sys
|
| 28 |
-
from pathlib import Path as _Path
|
| 29 |
-
if __package__ is None or __package__ == "":
|
| 30 |
-
_here = _Path(__file__).resolve().parent
|
| 31 |
-
if str(_here.parent) not in sys.path:
|
| 32 |
-
sys.path.insert(0, str(_here.parent))
|
| 33 |
-
__package__ = _here.name
|
| 34 |
-
|
| 35 |
-
import argparse
|
| 36 |
-
import io
|
| 37 |
-
import time
|
| 38 |
-
from pathlib import Path
|
| 39 |
-
|
| 40 |
import numpy as np
|
| 41 |
-
import torch
|
| 42 |
|
| 43 |
|
| 44 |
# Priority sampling (same logic as train.py)
|
|
@@ -87,174 +73,3 @@ def _priority_sample(source, group_id, seq_len, colmap_quota, depth_quota):
|
|
| 87 |
return indices[:seq_len], mask
|
| 88 |
|
| 89 |
|
| 90 |
-
def process_sample(xyz, source, group_id, class_id, vis_src, vis_id,
|
| 91 |
-
center, scale, gt_v, gt_e, behind=None,
|
| 92 |
-
n_views_voted=None, vote_frac=None,
|
| 93 |
-
gt_edge_classes=None,
|
| 94 |
-
seq_len=2048, colmap_q=1536, depth_q=512):
|
| 95 |
-
"""Sample and normalize one scene. Returns dict of numpy arrays."""
|
| 96 |
-
indices, mask = _priority_sample(source, group_id, seq_len, colmap_q, depth_q)
|
| 97 |
-
xyz_norm = ((xyz[indices] - center) / scale).astype(np.float32)
|
| 98 |
-
gt_seg = np.stack([gt_v[gt_e[:, 0]], gt_v[gt_e[:, 1]]], axis=1)
|
| 99 |
-
gt_seg_norm = ((gt_seg - center) / scale).astype(np.float32)
|
| 100 |
-
|
| 101 |
-
result = {
|
| 102 |
-
"xyz_norm": xyz_norm,
|
| 103 |
-
"class_id": class_id[indices].astype(np.uint8),
|
| 104 |
-
"source": source[indices].astype(np.uint8),
|
| 105 |
-
"mask": mask,
|
| 106 |
-
"gt_segments": gt_seg_norm,
|
| 107 |
-
"scale": np.float32(scale),
|
| 108 |
-
"center": center.astype(np.float32),
|
| 109 |
-
"gt_vertices": gt_v.astype(np.float32),
|
| 110 |
-
"gt_edges": gt_e.astype(np.int32),
|
| 111 |
-
"visible_src": vis_src[indices].astype(np.uint8),
|
| 112 |
-
"visible_id": vis_id[indices].astype(np.int16),
|
| 113 |
-
}
|
| 114 |
-
if behind is not None:
|
| 115 |
-
result["behind"] = behind[indices].astype(np.int16)
|
| 116 |
-
if n_views_voted is not None:
|
| 117 |
-
result["n_views_voted"] = n_views_voted[indices].astype(np.uint8)
|
| 118 |
-
if vote_frac is not None:
|
| 119 |
-
result["vote_frac"] = vote_frac[indices].astype(np.float32)
|
| 120 |
-
if gt_edge_classes is not None:
|
| 121 |
-
if len(gt_edge_classes) != len(gt_e):
|
| 122 |
-
raise ValueError(
|
| 123 |
-
f"gt_edge_classes length {len(gt_edge_classes)} != "
|
| 124 |
-
f"gt_edges length {len(gt_e)}")
|
| 125 |
-
result["gt_edge_classes"] = gt_edge_classes.astype(np.int64)
|
| 126 |
-
return result
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def _load_edge_classes(path):
|
| 130 |
-
"""Load edge classifications lookup from npz file."""
|
| 131 |
-
if path is None:
|
| 132 |
-
return None
|
| 133 |
-
path = Path(path)
|
| 134 |
-
if not path.exists():
|
| 135 |
-
raise FileNotFoundError(f"Edge classifications file not found: {path}")
|
| 136 |
-
data = np.load(str(path), allow_pickle=False)
|
| 137 |
-
lookup = {k: data[k] for k in data.files}
|
| 138 |
-
print(f"Loaded edge classifications for {len(lookup)} orders from {path}")
|
| 139 |
-
return lookup
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def main():
|
| 143 |
-
p = argparse.ArgumentParser()
|
| 144 |
-
g = p.add_mutually_exclusive_group(required=True)
|
| 145 |
-
g.add_argument("--in-dir", help="Local directory of .pt files")
|
| 146 |
-
g.add_argument("--hf-repo", help="HuggingFace dataset repo (e.g. usm3d/s23dr-2026-cached_full_pcd)")
|
| 147 |
-
p.add_argument("--split", default="train", help="HF dataset split")
|
| 148 |
-
p.add_argument("--out-dir", required=True)
|
| 149 |
-
p.add_argument("--edge-classes", default=None,
|
| 150 |
-
help="Path to edge_classifications.npz from extract_edge_classes.py")
|
| 151 |
-
p.add_argument("--seq-len", type=int, default=2048)
|
| 152 |
-
p.add_argument("--colmap-quota", type=int, default=1536)
|
| 153 |
-
p.add_argument("--depth-quota", type=int, default=512)
|
| 154 |
-
p.add_argument("--seed", type=int, default=7)
|
| 155 |
-
args = p.parse_args()
|
| 156 |
-
|
| 157 |
-
out_dir = Path(args.out_dir)
|
| 158 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 159 |
-
np.random.seed(args.seed)
|
| 160 |
-
|
| 161 |
-
edge_cls_lookup = _load_edge_classes(args.edge_classes)
|
| 162 |
-
n_edge_matched, n_edge_missing = 0, 0
|
| 163 |
-
|
| 164 |
-
t_start = time.perf_counter()
|
| 165 |
-
done = 0
|
| 166 |
-
|
| 167 |
-
if args.in_dir:
|
| 168 |
-
# Local .pt files
|
| 169 |
-
files = sorted(Path(args.in_dir).glob("*.pt"))
|
| 170 |
-
print(f"Converting {len(files)} local .pt files...")
|
| 171 |
-
for f in files:
|
| 172 |
-
out_f = out_dir / (f.stem + ".npz")
|
| 173 |
-
if out_f.exists():
|
| 174 |
-
done += 1
|
| 175 |
-
continue
|
| 176 |
-
d = torch.load(f, weights_only=False)
|
| 177 |
-
behind = np.asarray(d["behind_gest_id"], np.int16) if "behind_gest_id" in d else None
|
| 178 |
-
n_vv = np.asarray(d["n_views_voted"], np.uint8) if "n_views_voted" in d else None
|
| 179 |
-
vf = np.asarray(d["vote_frac"], np.float32) if "vote_frac" in d else None
|
| 180 |
-
gt_ec = None
|
| 181 |
-
if edge_cls_lookup is not None:
|
| 182 |
-
order_id = f.stem
|
| 183 |
-
if order_id in edge_cls_lookup:
|
| 184 |
-
gt_ec = edge_cls_lookup[order_id]
|
| 185 |
-
n_edge_matched += 1
|
| 186 |
-
else:
|
| 187 |
-
n_edge_missing += 1
|
| 188 |
-
result = process_sample(
|
| 189 |
-
np.asarray(d["xyz"], np.float32),
|
| 190 |
-
np.asarray(d["source"], np.uint8),
|
| 191 |
-
np.asarray(d["group_id"], np.int8),
|
| 192 |
-
np.asarray(d["class_id"], np.uint8),
|
| 193 |
-
np.asarray(d["visible_src"], np.uint8),
|
| 194 |
-
np.asarray(d["visible_id"], np.int16),
|
| 195 |
-
np.asarray(d["center"], np.float32),
|
| 196 |
-
float(d["scale"]),
|
| 197 |
-
np.asarray(d["gt_vertices"], np.float32),
|
| 198 |
-
np.asarray(d["gt_edges"], np.int32),
|
| 199 |
-
behind=behind, n_views_voted=n_vv, vote_frac=vf,
|
| 200 |
-
gt_edge_classes=gt_ec,
|
| 201 |
-
seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
|
| 202 |
-
)
|
| 203 |
-
np.savez(out_f, **result)
|
| 204 |
-
done += 1
|
| 205 |
-
if done % 2000 == 0:
|
| 206 |
-
print(f" {done}/{len(files)} [{done/(time.perf_counter()-t_start):.0f}/s]")
|
| 207 |
-
else:
|
| 208 |
-
# HF dataset
|
| 209 |
-
from datasets import load_dataset
|
| 210 |
-
print(f"Loading {args.hf_repo} split={args.split}...")
|
| 211 |
-
ds = load_dataset(args.hf_repo, split=args.split)
|
| 212 |
-
print(f"Converting {len(ds)} samples...")
|
| 213 |
-
for i, sample in enumerate(ds):
|
| 214 |
-
order_id = sample["order_id"]
|
| 215 |
-
out_f = out_dir / f"{order_id}.npz"
|
| 216 |
-
if out_f.exists():
|
| 217 |
-
done += 1
|
| 218 |
-
continue
|
| 219 |
-
arrays = np.load(io.BytesIO(sample["data"]))
|
| 220 |
-
behind = arrays["behind_gest_id"] if "behind_gest_id" in arrays else None
|
| 221 |
-
n_vv = arrays["n_views_voted"] if "n_views_voted" in arrays else None
|
| 222 |
-
vf = arrays["vote_frac"] if "vote_frac" in arrays else None
|
| 223 |
-
gt_ec = None
|
| 224 |
-
if edge_cls_lookup is not None:
|
| 225 |
-
if order_id in edge_cls_lookup:
|
| 226 |
-
gt_ec = edge_cls_lookup[order_id]
|
| 227 |
-
n_edge_matched += 1
|
| 228 |
-
else:
|
| 229 |
-
n_edge_missing += 1
|
| 230 |
-
result = process_sample(
|
| 231 |
-
arrays["xyz"], arrays["source"], arrays["group_id"],
|
| 232 |
-
arrays["class_id"], arrays["visible_src"], arrays["visible_id"],
|
| 233 |
-
arrays["center"], float(arrays["scale"]),
|
| 234 |
-
arrays["gt_vertices"], arrays["gt_edges"],
|
| 235 |
-
behind=behind, n_views_voted=n_vv, vote_frac=vf,
|
| 236 |
-
gt_edge_classes=gt_ec,
|
| 237 |
-
seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
|
| 238 |
-
)
|
| 239 |
-
np.savez(out_f, **result)
|
| 240 |
-
done += 1
|
| 241 |
-
if done % 2000 == 0:
|
| 242 |
-
print(f" {done}/{len(ds)} [{done/(time.perf_counter()-t_start):.0f}/s]")
|
| 243 |
-
|
| 244 |
-
elapsed = time.perf_counter() - t_start
|
| 245 |
-
print(f"Done: {done} files in {elapsed:.0f}s ({done/max(1,elapsed):.0f}/s)")
|
| 246 |
-
|
| 247 |
-
if edge_cls_lookup is not None:
|
| 248 |
-
print(f"Edge classifications: {n_edge_matched} matched, {n_edge_missing} missing")
|
| 249 |
-
|
| 250 |
-
# Report sizes
|
| 251 |
-
import os
|
| 252 |
-
npz_files = list(out_dir.glob("*.npz"))
|
| 253 |
-
if npz_files:
|
| 254 |
-
sizes = [os.path.getsize(f) for f in npz_files[:100]]
|
| 255 |
-
print(f"Avg file size: {np.mean(sizes)/1024:.0f}KB")
|
| 256 |
-
print(f"Est total: {np.mean(sizes)*len(npz_files)/1e9:.1f}GB")
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
if __name__ == "__main__":
|
| 260 |
-
main()
|
|
|
|
| 24 |
"""
|
| 25 |
from __future__ import annotations
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
import numpy as np
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
# Priority sampling (same logic as train.py)
|
|
|
|
| 73 |
return indices[:seq_len], mask
|
| 74 |
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/model.py
CHANGED
|
@@ -239,7 +239,6 @@ class TokenTransformerSegments(nn.Module):
|
|
| 239 |
self.segments = segments
|
| 240 |
self.out_vertices = segments * 2
|
| 241 |
self.segment_param = segment_param
|
| 242 |
-
self.length_floor = length_floor
|
| 243 |
self.decoder_input_xattn = decoder_input_xattn
|
| 244 |
norm_class = norm_class or nn.LayerNorm
|
| 245 |
|
|
@@ -428,167 +427,7 @@ class SelfAttentionEncoderLayer(nn.Module):
|
|
| 428 |
|
| 429 |
|
| 430 |
# ---------------------------------------------------------------------------
|
| 431 |
-
#
|
| 432 |
-
# ---------------------------------------------------------------------------
|
| 433 |
-
|
| 434 |
-
class TransformerSegments(nn.Module):
|
| 435 |
-
"""Standard transformer encoder + cross-attention segment decoder.
|
| 436 |
-
|
| 437 |
-
Architecture:
|
| 438 |
-
Input tokens [B, T, D]
|
| 439 |
-
|
|
| 440 |
-
v
|
| 441 |
-
input_proj: Linear -> GELU -> Linear -> Norm => [B, T, hidden]
|
| 442 |
-
|
|
| 443 |
-
v
|
| 444 |
-
N SelfAttentionEncoderLayers (self-attn over all T tokens)
|
| 445 |
-
|
|
| 446 |
-
v
|
| 447 |
-
Segment decoder (same as Perceiver version):
|
| 448 |
-
M SegmentDecoderLayers (queries cross-attend to encoded tokens)
|
| 449 |
-
|
|
| 450 |
-
v
|
| 451 |
-
segment_head -> endpoints [B, S, 2, 3] (midpoint_halfvec or midpoint_dir_len)
|
| 452 |
-
"""
|
| 453 |
-
|
| 454 |
-
def __init__(
|
| 455 |
-
self,
|
| 456 |
-
segments: int = 32,
|
| 457 |
-
in_dim: int = 128,
|
| 458 |
-
hidden: int = 128,
|
| 459 |
-
num_heads: int = 4,
|
| 460 |
-
kv_heads_cross: int | None = 2,
|
| 461 |
-
kv_heads_self: int | None = 0,
|
| 462 |
-
dim_feedforward: int = 256,
|
| 463 |
-
dropout: float = 0.01,
|
| 464 |
-
encoder_layers: int = 4,
|
| 465 |
-
decoder_layers: int = 2,
|
| 466 |
-
norm_class=None,
|
| 467 |
-
activation: str = "gelu",
|
| 468 |
-
segment_conf: bool = False,
|
| 469 |
-
segment_param: str = "midpoint_halfvec",
|
| 470 |
-
length_floor: float = 0.0,
|
| 471 |
-
decoder_input_xattn: bool = False,
|
| 472 |
-
qk_norm: bool = False,
|
| 473 |
-
qk_norm_type: str = "l2",
|
| 474 |
-
):
|
| 475 |
-
super().__init__()
|
| 476 |
-
self.segments = segments
|
| 477 |
-
self.out_vertices = segments * 2
|
| 478 |
-
self.segment_param = segment_param
|
| 479 |
-
self.length_floor = length_floor
|
| 480 |
-
norm_class = norm_class or nn.LayerNorm
|
| 481 |
-
|
| 482 |
-
if kv_heads_cross is not None and kv_heads_cross <= 0:
|
| 483 |
-
kv_heads_cross = None
|
| 484 |
-
if kv_heads_self is not None and kv_heads_self <= 0:
|
| 485 |
-
kv_heads_self = None
|
| 486 |
-
|
| 487 |
-
# -- Input projection --
|
| 488 |
-
self.input_proj = nn.Sequential(
|
| 489 |
-
nn.Linear(in_dim, dim_feedforward),
|
| 490 |
-
nn.GELU(),
|
| 491 |
-
nn.Linear(dim_feedforward, hidden),
|
| 492 |
-
norm_class(hidden),
|
| 493 |
-
)
|
| 494 |
-
|
| 495 |
-
# -- Self-attention encoder --
|
| 496 |
-
self.encoder_layers = nn.ModuleList([
|
| 497 |
-
SelfAttentionEncoderLayer(
|
| 498 |
-
d_model=hidden,
|
| 499 |
-
num_heads=num_heads,
|
| 500 |
-
dim_ff=dim_feedforward,
|
| 501 |
-
dropout=dropout,
|
| 502 |
-
activation=activation,
|
| 503 |
-
kv_heads=kv_heads_self,
|
| 504 |
-
norm_class=norm_class,
|
| 505 |
-
qk_norm=qk_norm, qk_norm_type=qk_norm_type,
|
| 506 |
-
)
|
| 507 |
-
for _ in range(encoder_layers)
|
| 508 |
-
])
|
| 509 |
-
|
| 510 |
-
# -- Segment decoder (same structure as Perceiver version) --
|
| 511 |
-
# Note: for transformer arch, decoder_input_xattn is ignored because
|
| 512 |
-
# the decoder already cross-attends to the full encoded token sequence.
|
| 513 |
-
self.query_embed = nn.Embedding(segments, hidden)
|
| 514 |
-
self.decoder_layers = nn.ModuleList([
|
| 515 |
-
SegmentDecoderLayer(
|
| 516 |
-
d_model=hidden,
|
| 517 |
-
num_heads=num_heads,
|
| 518 |
-
dim_ff=dim_feedforward,
|
| 519 |
-
dropout=dropout,
|
| 520 |
-
activation=activation,
|
| 521 |
-
kv_heads_cross=kv_heads_cross,
|
| 522 |
-
kv_heads_self=kv_heads_self,
|
| 523 |
-
norm_class=norm_class,
|
| 524 |
-
qk_norm=qk_norm, qk_norm_type=qk_norm_type,
|
| 525 |
-
)
|
| 526 |
-
for _ in range(decoder_layers)
|
| 527 |
-
])
|
| 528 |
-
|
| 529 |
-
# -- Output head (shared logic with Perceiver version) --
|
| 530 |
-
if segment_param == "midpoint_dir_len":
|
| 531 |
-
self.segment_head = nn.Linear(hidden, 7) # mid(3) + dir(3) + len(1)
|
| 532 |
-
else:
|
| 533 |
-
self.segment_head = nn.Linear(hidden, 6) # mid(3) + half(3)
|
| 534 |
-
self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))
|
| 535 |
-
|
| 536 |
-
nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
|
| 537 |
-
if self.segment_head.bias is not None:
|
| 538 |
-
nn.init.zeros_(self.segment_head.bias)
|
| 539 |
-
if segment_param == "midpoint_dir_len":
|
| 540 |
-
# sigmoid(-2.2) ~ 0.1 default length in normalized space (~3m)
|
| 541 |
-
self.segment_head.bias.data[6] = -2.2
|
| 542 |
-
nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)
|
| 543 |
-
|
| 544 |
-
self.segment_conf = segment_conf
|
| 545 |
-
if segment_conf:
|
| 546 |
-
self.conf_head = nn.Linear(hidden, 1)
|
| 547 |
-
nn.init.zeros_(self.conf_head.bias)
|
| 548 |
-
|
| 549 |
-
def forward(
|
| 550 |
-
self,
|
| 551 |
-
tokens: torch.Tensor,
|
| 552 |
-
mask: torch.Tensor | None = None,
|
| 553 |
-
) -> dict[str, torch.Tensor | list]:
|
| 554 |
-
B = tokens.shape[0]
|
| 555 |
-
|
| 556 |
-
src = self.input_proj(tokens)
|
| 557 |
-
pad_mask = ~mask.bool() if mask is not None else None
|
| 558 |
-
|
| 559 |
-
# Encode: self-attention over all tokens
|
| 560 |
-
for layer in self.encoder_layers:
|
| 561 |
-
src = layer(src, key_padding_mask=pad_mask)
|
| 562 |
-
|
| 563 |
-
# Decode: segment queries cross-attend to encoded tokens
|
| 564 |
-
queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
|
| 565 |
-
for layer in self.decoder_layers:
|
| 566 |
-
queries = layer(queries, src)
|
| 567 |
-
|
| 568 |
-
# Predict segments -> endpoints
|
| 569 |
-
if self.segment_param == "midpoint_dir_len":
|
| 570 |
-
raw = self.segment_head(queries) # [B, S, 7]
|
| 571 |
-
mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
|
| 572 |
-
direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
|
| 573 |
-
length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
|
| 574 |
-
half = direction * length * 0.5
|
| 575 |
-
else:
|
| 576 |
-
raw = self.segment_head(queries).view(B, self.segments, 2, 3)
|
| 577 |
-
raw = raw + self.query_offsets.unsqueeze(0)
|
| 578 |
-
mid, half = raw[:, :, 0], raw[:, :, 1]
|
| 579 |
-
seg_params = torch.stack([mid - half, mid + half], dim=2)
|
| 580 |
-
|
| 581 |
-
vertices = seg_params.reshape(B, self.out_vertices, 3)
|
| 582 |
-
edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]
|
| 583 |
-
|
| 584 |
-
out = {"vertices": vertices, "segments": seg_params, "edges": edges}
|
| 585 |
-
if self.segment_conf:
|
| 586 |
-
out["conf"] = self.conf_head(queries).squeeze(-1)
|
| 587 |
-
return out
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
# ---------------------------------------------------------------------------
|
| 591 |
-
# End-to-end model: tokenizer embeddings + transformer/perceiver
|
| 592 |
# ---------------------------------------------------------------------------
|
| 593 |
|
| 594 |
class EdgeDepthSegmentsModel(nn.Module):
|
|
@@ -648,25 +487,9 @@ class EdgeDepthSegmentsModel(nn.Module):
|
|
| 648 |
)
|
| 649 |
|
| 650 |
if arch == "transformer":
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
hidden=hidden,
|
| 655 |
-
num_heads=num_heads,
|
| 656 |
-
kv_heads_cross=kv_heads_cross,
|
| 657 |
-
kv_heads_self=kv_heads_self,
|
| 658 |
-
dim_feedforward=dim_feedforward,
|
| 659 |
-
dropout=dropout,
|
| 660 |
-
encoder_layers=encoder_layers,
|
| 661 |
-
decoder_layers=decoder_layers,
|
| 662 |
-
norm_class=norm_class,
|
| 663 |
-
activation=activation,
|
| 664 |
-
segment_conf=segment_conf,
|
| 665 |
-
segment_param=segment_param,
|
| 666 |
-
length_floor=length_floor,
|
| 667 |
-
decoder_input_xattn=decoder_input_xattn,
|
| 668 |
-
qk_norm=qk_norm, qk_norm_type=qk_norm_type,
|
| 669 |
-
)
|
| 670 |
else:
|
| 671 |
self.segmenter = TokenTransformerSegments(
|
| 672 |
segments=segments,
|
|
|
|
| 239 |
self.segments = segments
|
| 240 |
self.out_vertices = segments * 2
|
| 241 |
self.segment_param = segment_param
|
|
|
|
| 242 |
self.decoder_input_xattn = decoder_input_xattn
|
| 243 |
norm_class = norm_class or nn.LayerNorm
|
| 244 |
|
|
|
|
| 427 |
|
| 428 |
|
| 429 |
# ---------------------------------------------------------------------------
|
| 430 |
+
# End-to-end model: tokenizer embeddings + perceiver
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
# ---------------------------------------------------------------------------
|
| 432 |
|
| 433 |
class EdgeDepthSegmentsModel(nn.Module):
|
|
|
|
| 487 |
)
|
| 488 |
|
| 489 |
if arch == "transformer":
|
| 490 |
+
raise ValueError(
|
| 491 |
+
"arch='transformer' is no longer supported. "
|
| 492 |
+
"TransformerSegments has been removed; use arch='perceiver'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
else:
|
| 494 |
self.segmenter = TokenTransformerSegments(
|
| 495 |
segments=segments,
|
s23dr_2026_example/sinkhorn.py
CHANGED
|
@@ -10,26 +10,6 @@ to get useful gradients early and precise matching late.
|
|
| 10 |
import torch
|
| 11 |
|
| 12 |
|
| 13 |
-
def segment_pair_cost(pred_segments: torch.Tensor, gt_segments: torch.Tensor) -> torch.Tensor:
|
| 14 |
-
"""Cost between pred and GT segments: midpoint + direction + length (decoupled).
|
| 15 |
-
pred_segments: [N, 2, 3], gt_segments: [M, 2, 3] -> [N, M]
|
| 16 |
-
"""
|
| 17 |
-
p0, p1 = pred_segments[:, 0], pred_segments[:, 1]
|
| 18 |
-
g0, g1 = gt_segments[:, 0], gt_segments[:, 1]
|
| 19 |
-
mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
|
| 20 |
-
mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
|
| 21 |
-
d_mid = torch.cdist(mid_p, mid_g)
|
| 22 |
-
len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
|
| 23 |
-
len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
|
| 24 |
-
dir_p = half_p / len_p
|
| 25 |
-
dir_g = half_g / len_g
|
| 26 |
-
cos_angle = (dir_p[:, None, :] * dir_g[None, :, :]).sum(dim=-1)
|
| 27 |
-
d_dir = 1.0 - cos_angle.abs()
|
| 28 |
-
d_len = (len_p[:, None, :] - len_g[None, :, :]).squeeze(-1).abs()
|
| 29 |
-
return d_mid + d_dir + d_len
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
def batched_sinkhorn_loss(
|
| 34 |
pred_segments: torch.Tensor,
|
| 35 |
gt_pad: torch.Tensor,
|
|
@@ -144,38 +124,3 @@ def batched_sinkhorn_loss(
|
|
| 144 |
|
| 145 |
transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
|
| 146 |
return (transport * cost_pad).sum(dim=(1, 2)) # [B]
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
# Keep the per-sample version for compatibility
|
| 150 |
-
def sinkhorn_segment_loss(
|
| 151 |
-
pred_segments: torch.Tensor,
|
| 152 |
-
gt_segments: torch.Tensor,
|
| 153 |
-
eps: float,
|
| 154 |
-
iters: int,
|
| 155 |
-
dustbin_cost: float,
|
| 156 |
-
pred_mass: torch.Tensor | None = None,
|
| 157 |
-
) -> torch.Tensor:
|
| 158 |
-
if pred_segments.numel() == 0 or gt_segments.numel() == 0:
|
| 159 |
-
return pred_segments.new_tensor(dustbin_cost)
|
| 160 |
-
cost = segment_pair_cost(pred_segments, gt_segments)
|
| 161 |
-
n, m = cost.shape
|
| 162 |
-
if n == 0 or m == 0:
|
| 163 |
-
return cost.new_tensor(dustbin_cost)
|
| 164 |
-
cost_pad = torch.full((n + 1, m + 1), dustbin_cost, device=cost.device, dtype=cost.dtype)
|
| 165 |
-
cost_pad[:n, :m] = cost
|
| 166 |
-
cost_pad[-1, -1] = 0.0
|
| 167 |
-
denom = float(n + m)
|
| 168 |
-
a = torch.full((n + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
|
| 169 |
-
b = torch.full((m + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
|
| 170 |
-
a[-1] = m / denom
|
| 171 |
-
b[-1] = n / denom
|
| 172 |
-
log_a = torch.log(a + 1e-9)
|
| 173 |
-
log_b = torch.log(b + 1e-9)
|
| 174 |
-
log_k = -cost_pad / eps
|
| 175 |
-
log_u = torch.zeros_like(a)
|
| 176 |
-
log_v = torch.zeros_like(b)
|
| 177 |
-
for _ in range(iters):
|
| 178 |
-
log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
|
| 179 |
-
log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
|
| 180 |
-
transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
|
| 181 |
-
return torch.sum(transport * cost_pad)
|
|
|
|
| 10 |
import torch
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def batched_sinkhorn_loss(
|
| 14 |
pred_segments: torch.Tensor,
|
| 15 |
gt_pad: torch.Tensor,
|
|
|
|
| 124 |
|
| 125 |
transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
|
| 126 |
return (transport * cost_pad).sum(dim=(1, 2)) # [B]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/soft_hss_loss.py
DELETED
|
@@ -1,507 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def _softmin(values: torch.Tensor, dim: int, tau: float) -> torch.Tensor:
|
| 5 |
-
tau_t = torch.as_tensor(tau, device=values.device, dtype=values.dtype).clamp_min(1e-8)
|
| 6 |
-
return -tau_t * torch.logsumexp(-values / tau_t, dim=dim)
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def point_segment_distance_squared(
|
| 10 |
-
points: torch.Tensor,
|
| 11 |
-
seg_a: torch.Tensor,
|
| 12 |
-
seg_b: torch.Tensor,
|
| 13 |
-
eps: float = 1e-9,
|
| 14 |
-
) -> torch.Tensor:
|
| 15 |
-
"""
|
| 16 |
-
points: (P,3)
|
| 17 |
-
seg_a/seg_b: (S,3)
|
| 18 |
-
returns dist2: (P,S)
|
| 19 |
-
"""
|
| 20 |
-
ab = seg_b - seg_a # (S,3)
|
| 21 |
-
ab2 = (ab * ab).sum(dim=-1).clamp_min(eps) # (S,)
|
| 22 |
-
ap = points[:, None, :] - seg_a[None, :, :] # (P,S,3)
|
| 23 |
-
t = (ap * ab[None, :, :]).sum(dim=-1) / ab2[None, :] # (P,S)
|
| 24 |
-
t = t.clamp(0.0, 1.0)
|
| 25 |
-
closest = seg_a[None, :, :] + t[:, :, None] * ab[None, :, :]
|
| 26 |
-
diff = points[:, None, :] - closest
|
| 27 |
-
return (diff * diff).sum(dim=-1)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def distance_to_segments(
|
| 31 |
-
points: torch.Tensor,
|
| 32 |
-
segments: torch.Tensor,
|
| 33 |
-
eps: float = 1e-9,
|
| 34 |
-
) -> torch.Tensor:
|
| 35 |
-
"""
|
| 36 |
-
points: (P,3)
|
| 37 |
-
segments: (S,2,3)
|
| 38 |
-
returns min distance: (P,)
|
| 39 |
-
"""
|
| 40 |
-
a = segments[:, 0]
|
| 41 |
-
b = segments[:, 1]
|
| 42 |
-
dist2 = point_segment_distance_squared(points, a, b, eps=eps)
|
| 43 |
-
return torch.sqrt(dist2.min(dim=1).values + eps)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def soft_vertex_f1(
|
| 47 |
-
pred_vertices: torch.Tensor,
|
| 48 |
-
gt_vertices: torch.Tensor,
|
| 49 |
-
thresh: float,
|
| 50 |
-
tau: float = 0.05,
|
| 51 |
-
softmin_tau: float = 0.05,
|
| 52 |
-
eps: float = 1e-8,
|
| 53 |
-
) -> torch.Tensor:
|
| 54 |
-
"""
|
| 55 |
-
Soft surrogate for the Hungarian-thresholded corner F1 used by HSS.
|
| 56 |
-
|
| 57 |
-
Uses (soft) nearest-neighbor distances and a sigmoid threshold.
|
| 58 |
-
"""
|
| 59 |
-
if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
|
| 60 |
-
return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
|
| 61 |
-
|
| 62 |
-
pred = pred_vertices
|
| 63 |
-
gt = gt_vertices
|
| 64 |
-
|
| 65 |
-
diff = pred[:, None, :] - gt[None, :, :]
|
| 66 |
-
dist = torch.sqrt((diff * diff).sum(dim=-1) + eps) # (P,G)
|
| 67 |
-
|
| 68 |
-
d_pred = _softmin(dist, dim=1, tau=softmin_tau) # (P,)
|
| 69 |
-
d_gt = _softmin(dist, dim=0, tau=softmin_tau) # (G,)
|
| 70 |
-
|
| 71 |
-
tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
|
| 72 |
-
thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
|
| 73 |
-
p_match = torch.sigmoid((thresh_t - d_pred) / tau_t).mean()
|
| 74 |
-
r_match = torch.sigmoid((thresh_t - d_gt) / tau_t).mean()
|
| 75 |
-
return 2.0 * p_match * r_match / (p_match + r_match + eps)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def soft_tube_iou_mc(
|
| 79 |
-
pred_segments: torch.Tensor,
|
| 80 |
-
gt_segments: torch.Tensor,
|
| 81 |
-
radius: float,
|
| 82 |
-
n_samples: int = 4096,
|
| 83 |
-
tau: float = 0.05,
|
| 84 |
-
seed: int = 0,
|
| 85 |
-
eps: float = 1e-8,
|
| 86 |
-
) -> torch.Tensor:
|
| 87 |
-
"""
|
| 88 |
-
Soft surrogate for volumetric tube IoU (edge_thresh in HSS).
|
| 89 |
-
|
| 90 |
-
Samples points uniformly in a padded bbox around {pred,gt} endpoints.
|
| 91 |
-
Occupancy is sigmoid((radius - d(x, segments))/tau).
|
| 92 |
-
IoU is approximated by mean(min(occ_p, occ_g)) / mean(max(occ_p, occ_g)).
|
| 93 |
-
"""
|
| 94 |
-
if pred_segments.numel() == 0 or gt_segments.numel() == 0:
|
| 95 |
-
return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
|
| 96 |
-
|
| 97 |
-
pts_all = torch.cat([pred_segments.reshape(-1, 3), gt_segments.reshape(-1, 3)], dim=0)
|
| 98 |
-
pad = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
|
| 99 |
-
lo = pts_all.min(dim=0).values - pad
|
| 100 |
-
hi = pts_all.max(dim=0).values + pad
|
| 101 |
-
|
| 102 |
-
gen = torch.Generator(device=pts_all.device)
|
| 103 |
-
gen.manual_seed(int(seed))
|
| 104 |
-
u = torch.rand((int(n_samples), 3), generator=gen, device=pts_all.device, dtype=pts_all.dtype)
|
| 105 |
-
x = lo[None, :] + u * (hi - lo)[None, :]
|
| 106 |
-
|
| 107 |
-
d_p = distance_to_segments(x, pred_segments, eps=eps)
|
| 108 |
-
d_g = distance_to_segments(x, gt_segments, eps=eps)
|
| 109 |
-
|
| 110 |
-
tau_t = torch.as_tensor(tau, device=pts_all.device, dtype=pts_all.dtype).clamp_min(1e-8)
|
| 111 |
-
rad_t = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
|
| 112 |
-
occ_p = torch.sigmoid((rad_t - d_p) / tau_t)
|
| 113 |
-
occ_g = torch.sigmoid((rad_t - d_g) / tau_t)
|
| 114 |
-
|
| 115 |
-
inter = torch.minimum(occ_p, occ_g).mean()
|
| 116 |
-
union = torch.maximum(occ_p, occ_g).mean().clamp_min(eps)
|
| 117 |
-
return inter / union
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def soft_hss(
|
| 121 |
-
pred_segments: torch.Tensor,
|
| 122 |
-
gt_segments: torch.Tensor,
|
| 123 |
-
gt_vertices: torch.Tensor,
|
| 124 |
-
vert_thresh: float = 0.5,
|
| 125 |
-
edge_thresh: float = 0.5,
|
| 126 |
-
tau: float = 0.05,
|
| 127 |
-
softmin_tau: float = 0.05,
|
| 128 |
-
n_samples: int = 4096,
|
| 129 |
-
seed: int = 0,
|
| 130 |
-
eps: float = 1e-8,
|
| 131 |
-
):
|
| 132 |
-
"""
|
| 133 |
-
Returns (soft_hss, soft_f1, soft_iou), all scalars in [0,1] (approximately).
|
| 134 |
-
"""
|
| 135 |
-
pred_vertices = pred_segments.reshape(-1, 3)
|
| 136 |
-
f1 = soft_vertex_f1(pred_vertices, gt_vertices, thresh=vert_thresh, tau=tau, softmin_tau=softmin_tau, eps=eps)
|
| 137 |
-
iou = soft_tube_iou_mc(
|
| 138 |
-
pred_segments,
|
| 139 |
-
gt_segments,
|
| 140 |
-
radius=edge_thresh,
|
| 141 |
-
n_samples=n_samples,
|
| 142 |
-
tau=tau,
|
| 143 |
-
seed=seed,
|
| 144 |
-
eps=eps,
|
| 145 |
-
)
|
| 146 |
-
denom = (f1 + iou).clamp_min(eps)
|
| 147 |
-
hss = 2.0 * f1 * iou / denom
|
| 148 |
-
return hss, f1, iou
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# ---------------------------------------------------------------------------
|
| 152 |
-
# Improved: Sinkhorn-matched vertex F1
|
| 153 |
-
# ---------------------------------------------------------------------------
|
| 154 |
-
#
|
| 155 |
-
# The original soft_vertex_f1 uses independent softmin nearest-neighbor
|
| 156 |
-
# distances, which allows multiple predicted vertices to claim the same GT
|
| 157 |
-
# vertex. This inflates precision and fails to penalize duplicate vertices --
|
| 158 |
-
# the exact failure mode that requires merge_vertices post-processing.
|
| 159 |
-
#
|
| 160 |
-
# This version uses Sinkhorn optimal transport to find a soft one-to-one
|
| 161 |
-
# assignment between predicted and GT vertices, then computes precision and
|
| 162 |
-
# recall from the matched distances. This is a better surrogate for the
|
| 163 |
-
# Hungarian matching used by the real HSS metric.
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def sinkhorn_vertex_f1(
|
| 167 |
-
pred_vertices: torch.Tensor,
|
| 168 |
-
gt_vertices: torch.Tensor,
|
| 169 |
-
thresh: float = 0.5,
|
| 170 |
-
tau: float = 0.05,
|
| 171 |
-
eps_sinkhorn: float = 0.05,
|
| 172 |
-
iters: int = 20,
|
| 173 |
-
eps: float = 1e-8,
|
| 174 |
-
) -> torch.Tensor:
|
| 175 |
-
"""Soft vertex F1 using Sinkhorn matching (better aligned with real HSS).
|
| 176 |
-
|
| 177 |
-
Instead of independent nearest-neighbor distances (which allow double-
|
| 178 |
-
claiming), this uses optimal transport to find a soft one-to-one assignment
|
| 179 |
-
between predicted and GT vertices.
|
| 180 |
-
|
| 181 |
-
Returns a differentiable scalar in [0, 1].
|
| 182 |
-
"""
|
| 183 |
-
if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
|
| 184 |
-
return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
|
| 185 |
-
|
| 186 |
-
P = pred_vertices.shape[0]
|
| 187 |
-
G = gt_vertices.shape[0]
|
| 188 |
-
|
| 189 |
-
# Pairwise distance matrix (P, G)
|
| 190 |
-
dist = torch.cdist(pred_vertices, gt_vertices)
|
| 191 |
-
|
| 192 |
-
# Sinkhorn with dustbin: (P+1) x (G+1)
|
| 193 |
-
# Dustbin cost = thresh (unmatched vertices are "at threshold distance")
|
| 194 |
-
dustbin = thresh
|
| 195 |
-
cost_pad = torch.full((P + 1, G + 1), dustbin, device=dist.device, dtype=dist.dtype)
|
| 196 |
-
cost_pad[:P, :G] = dist
|
| 197 |
-
cost_pad[-1, -1] = 0.0
|
| 198 |
-
|
| 199 |
-
# Uniform masses with dustbin slack
|
| 200 |
-
denom = float(P + G)
|
| 201 |
-
a = torch.full((P + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
|
| 202 |
-
b = torch.full((G + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
|
| 203 |
-
a[-1] = G / denom # pred dustbin absorbs unmatched GT
|
| 204 |
-
b[-1] = P / denom # GT dustbin absorbs unmatched pred
|
| 205 |
-
|
| 206 |
-
# Log-domain Sinkhorn
|
| 207 |
-
log_a = torch.log(a + 1e-9)
|
| 208 |
-
log_b = torch.log(b + 1e-9)
|
| 209 |
-
log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
|
| 210 |
-
log_u = torch.zeros_like(a)
|
| 211 |
-
log_v = torch.zeros_like(b)
|
| 212 |
-
for _ in range(iters):
|
| 213 |
-
log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
|
| 214 |
-
log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
|
| 215 |
-
|
| 216 |
-
# Transport plan (P+1, G+1)
|
| 217 |
-
transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
|
| 218 |
-
|
| 219 |
-
# Extract the non-dustbin transport (P, G) -- these are the soft assignments
|
| 220 |
-
T = transport[:P, :G]
|
| 221 |
-
|
| 222 |
-
# For each predicted vertex, its matched distance is the transport-weighted
|
| 223 |
-
# average distance to GT vertices
|
| 224 |
-
# Normalize rows to sum to 1 (how much of this pred is matched vs dustbin)
|
| 225 |
-
row_sums = T.sum(dim=1).clamp_min(eps)
|
| 226 |
-
matched_dist_pred = (T * dist).sum(dim=1) / row_sums # (P,)
|
| 227 |
-
match_weight_pred = row_sums * denom # how much of this pred is matched (0-1 ish)
|
| 228 |
-
|
| 229 |
-
# Same for GT vertices (column perspective)
|
| 230 |
-
col_sums = T.sum(dim=0).clamp_min(eps)
|
| 231 |
-
matched_dist_gt = (T * dist).sum(dim=0) / col_sums # (G,)
|
| 232 |
-
match_weight_gt = col_sums * denom
|
| 233 |
-
|
| 234 |
-
# Soft precision: fraction of pred vertices that are matched AND within threshold
|
| 235 |
-
tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
|
| 236 |
-
thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
|
| 237 |
-
|
| 238 |
-
prec_per = match_weight_pred * torch.sigmoid((thresh_t - matched_dist_pred) / tau_t)
|
| 239 |
-
precision = prec_per.mean()
|
| 240 |
-
|
| 241 |
-
# Soft recall: fraction of GT vertices that are matched AND within threshold
|
| 242 |
-
rec_per = match_weight_gt * torch.sigmoid((thresh_t - matched_dist_gt) / tau_t)
|
| 243 |
-
recall = rec_per.mean()
|
| 244 |
-
|
| 245 |
-
return 2.0 * precision * recall / (precision + recall + eps)
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
# ---------------------------------------------------------------------------
|
| 249 |
-
# Improved: Segment-sampled tube IoU
|
| 250 |
-
# ---------------------------------------------------------------------------
|
| 251 |
-
#
|
| 252 |
-
# The original soft_tube_iou_mc samples random points in the bounding box,
|
| 253 |
-
# wasting most samples in empty space. This version samples along the segments
|
| 254 |
-
# themselves, concentrating gradient signal where it matters.
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def _sample_along_segments(segments: torch.Tensor, n_per_seg: int = 64) -> torch.Tensor:
|
| 258 |
-
"""Sample n_per_seg points uniformly along each segment.
|
| 259 |
-
|
| 260 |
-
segments: (S, 2, 3)
|
| 261 |
-
returns: (S * n_per_seg, 3)
|
| 262 |
-
"""
|
| 263 |
-
t = torch.linspace(0, 1, n_per_seg, device=segments.device, dtype=segments.dtype)
|
| 264 |
-
# (S, 1, 3) + (1, N, 1) * (S, 1, 3) -> (S, N, 3)
|
| 265 |
-
a = segments[:, 0:1, :]
|
| 266 |
-
b = segments[:, 1:2, :]
|
| 267 |
-
pts = a + t[None, :, None] * (b - a)
|
| 268 |
-
return pts.reshape(-1, 3)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
def segment_sampled_tube_iou(
|
| 272 |
-
pred_segments: torch.Tensor,
|
| 273 |
-
gt_segments: torch.Tensor,
|
| 274 |
-
radius: float = 0.5,
|
| 275 |
-
n_per_seg: int = 64,
|
| 276 |
-
tau: float = 0.05,
|
| 277 |
-
eps: float = 1e-8,
|
| 278 |
-
) -> torch.Tensor:
|
| 279 |
-
"""Soft tube IoU by sampling along segments instead of in the bounding box.
|
| 280 |
-
|
| 281 |
-
Samples points along predicted and GT segments, then checks what fraction
|
| 282 |
-
of each set falls within radius of the other. More sample-efficient than
|
| 283 |
-
bbox Monte Carlo and gives better gradients.
|
| 284 |
-
|
| 285 |
-
Returns a differentiable scalar in [0, 1].
|
| 286 |
-
"""
|
| 287 |
-
if pred_segments.numel() == 0 or gt_segments.numel() == 0:
|
| 288 |
-
return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
|
| 289 |
-
|
| 290 |
-
pred_pts = _sample_along_segments(pred_segments, n_per_seg)
|
| 291 |
-
gt_pts = _sample_along_segments(gt_segments, n_per_seg)
|
| 292 |
-
|
| 293 |
-
tau_t = torch.as_tensor(tau, device=pred_pts.device, dtype=pred_pts.dtype).clamp_min(1e-8)
|
| 294 |
-
rad_t = torch.as_tensor(radius, device=pred_pts.device, dtype=pred_pts.dtype)
|
| 295 |
-
|
| 296 |
-
# Precision: fraction of pred points within radius of any GT segment
|
| 297 |
-
d_pred = distance_to_segments(pred_pts, gt_segments, eps=eps)
|
| 298 |
-
prec = torch.sigmoid((rad_t - d_pred) / tau_t).mean()
|
| 299 |
-
|
| 300 |
-
# Recall: fraction of GT points within radius of any pred segment
|
| 301 |
-
d_gt = distance_to_segments(gt_pts, pred_segments, eps=eps)
|
| 302 |
-
rec = torch.sigmoid((rad_t - d_gt) / tau_t).mean()
|
| 303 |
-
|
| 304 |
-
# Soft IoU from precision and recall:
|
| 305 |
-
# IoU = intersection/union = (P*R) / (P + R - P*R) for occupancy overlap
|
| 306 |
-
return prec * rec / (prec + rec - prec * rec + eps)
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def soft_hss_v2(
|
| 310 |
-
pred_segments: torch.Tensor,
|
| 311 |
-
gt_segments: torch.Tensor,
|
| 312 |
-
gt_vertices: torch.Tensor,
|
| 313 |
-
vert_thresh: float = 0.5,
|
| 314 |
-
edge_thresh: float = 0.5,
|
| 315 |
-
tau: float = 0.05,
|
| 316 |
-
sinkhorn_eps: float = 0.05,
|
| 317 |
-
sinkhorn_iters: int = 20,
|
| 318 |
-
n_per_seg: int = 64,
|
| 319 |
-
eps: float = 1e-8,
|
| 320 |
-
):
|
| 321 |
-
"""Improved soft HSS using Sinkhorn vertex matching + segment-sampled IoU.
|
| 322 |
-
|
| 323 |
-
Returns (soft_hss, soft_f1, soft_iou).
|
| 324 |
-
"""
|
| 325 |
-
pred_vertices = pred_segments.reshape(-1, 3)
|
| 326 |
-
f1 = sinkhorn_vertex_f1(
|
| 327 |
-
pred_vertices, gt_vertices,
|
| 328 |
-
thresh=vert_thresh, tau=tau,
|
| 329 |
-
eps_sinkhorn=sinkhorn_eps, iters=sinkhorn_iters, eps=eps,
|
| 330 |
-
)
|
| 331 |
-
iou = segment_sampled_tube_iou(
|
| 332 |
-
pred_segments, gt_segments,
|
| 333 |
-
radius=edge_thresh, n_per_seg=n_per_seg, tau=tau, eps=eps,
|
| 334 |
-
)
|
| 335 |
-
denom = (f1 + iou).clamp_min(eps)
|
| 336 |
-
hss = 2.0 * f1 * iou / denom
|
| 337 |
-
return hss, f1, iou
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
# ---------------------------------------------------------------------------
|
| 342 |
-
# Batched versions for training speed
|
| 343 |
-
# ---------------------------------------------------------------------------
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
def batched_sinkhorn_vertex_f1(
|
| 347 |
-
pred_segments: torch.Tensor,
|
| 348 |
-
gt_pad: torch.Tensor,
|
| 349 |
-
gt_mask: torch.Tensor,
|
| 350 |
-
thresh: float | torch.Tensor = 0.5,
|
| 351 |
-
tau: float | torch.Tensor = 0.05,
|
| 352 |
-
eps_sinkhorn: float = 0.05,
|
| 353 |
-
iters: int = 10,
|
| 354 |
-
eps: float = 1e-8,
|
| 355 |
-
) -> torch.Tensor:
|
| 356 |
-
"""Batched Sinkhorn vertex F1 loss.
|
| 357 |
-
|
| 358 |
-
Args:
|
| 359 |
-
pred_segments: [B, S, 2, 3] predicted segments
|
| 360 |
-
gt_pad: [B, M, 2, 3] padded GT segments
|
| 361 |
-
gt_mask: [B, M] bool mask (True = valid GT segment)
|
| 362 |
-
thresh: distance threshold for a vertex match (scalar or [B])
|
| 363 |
-
tau: sigmoid temperature (scalar or [B])
|
| 364 |
-
Returns:
|
| 365 |
-
[B] per-sample (1 - F1) loss
|
| 366 |
-
"""
|
| 367 |
-
B, S = pred_segments.shape[:2]
|
| 368 |
-
M = gt_pad.shape[1]
|
| 369 |
-
P = S * 2 # pred vertices (both endpoints)
|
| 370 |
-
|
| 371 |
-
# Allow per-sample thresh and tau ([B] tensors or scalars)
|
| 372 |
-
thresh_t = torch.as_tensor(thresh, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 373 |
-
if thresh_t.dim() == 0:
|
| 374 |
-
thresh_t = thresh_t.expand(B)
|
| 375 |
-
tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 376 |
-
if tau_t.dim() == 0:
|
| 377 |
-
tau_t = tau_t.expand(B)
|
| 378 |
-
tau_t = tau_t.clamp_min(1e-8)
|
| 379 |
-
|
| 380 |
-
pred_verts = pred_segments.reshape(B, P, 3)
|
| 381 |
-
gt_verts = gt_pad.reshape(B, M * 2, 3) # will mask invalid ones
|
| 382 |
-
|
| 383 |
-
# Build GT vertex mask: each valid segment contributes 2 vertices
|
| 384 |
-
gt_vert_mask = gt_mask.unsqueeze(2).expand(B, M, 2).reshape(B, M * 2)
|
| 385 |
-
G = M * 2
|
| 386 |
-
|
| 387 |
-
# Pairwise distances [B, P, G]
|
| 388 |
-
dist = torch.linalg.norm(
|
| 389 |
-
pred_verts.unsqueeze(2) - gt_verts.unsqueeze(1), dim=-1)
|
| 390 |
-
|
| 391 |
-
# Mask invalid GT with high distance
|
| 392 |
-
dist = torch.where(gt_vert_mask.unsqueeze(1), dist, thresh_t[:, None, None] * 10.0)
|
| 393 |
-
|
| 394 |
-
# Sinkhorn matching: [B, P+1, G+1]
|
| 395 |
-
cost_pad = thresh_t[:, None, None].expand(B, P + 1, G + 1).clone()
|
| 396 |
-
cost_pad[:, :P, :G] = dist
|
| 397 |
-
cost_pad[:, -1, -1] = 0.0
|
| 398 |
-
|
| 399 |
-
gt_counts = gt_vert_mask.sum(dim=1).float() # [B]
|
| 400 |
-
n = float(P)
|
| 401 |
-
denom = n + gt_counts # [B]
|
| 402 |
-
|
| 403 |
-
a = (1.0 / denom).unsqueeze(1).expand(B, P + 1).clone()
|
| 404 |
-
a[:, -1] = gt_counts / denom
|
| 405 |
-
b = (1.0 / denom).unsqueeze(1).expand(B, G + 1).clone()
|
| 406 |
-
b[:, -1] = n / denom
|
| 407 |
-
b[:, :G] = b[:, :G] * gt_vert_mask.float()
|
| 408 |
-
|
| 409 |
-
log_a = torch.log(a + 1e-9)
|
| 410 |
-
log_b = torch.log(b + 1e-9)
|
| 411 |
-
log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
|
| 412 |
-
log_u = torch.zeros_like(a)
|
| 413 |
-
log_v = torch.zeros_like(b)
|
| 414 |
-
|
| 415 |
-
for _ in range(iters):
|
| 416 |
-
log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
|
| 417 |
-
log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
|
| 418 |
-
|
| 419 |
-
transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
|
| 420 |
-
T = transport[:, :P, :G] # [B, P, G]
|
| 421 |
-
|
| 422 |
-
# Matched distances
|
| 423 |
-
row_sums = T.sum(dim=2).clamp_min(eps)
|
| 424 |
-
matched_d_pred = (T * dist).sum(dim=2) / row_sums # [B, P]
|
| 425 |
-
w_pred = row_sums * denom.unsqueeze(1)
|
| 426 |
-
|
| 427 |
-
col_sums = T.sum(dim=1).clamp_min(eps)
|
| 428 |
-
matched_d_gt = (T * dist).sum(dim=1) / col_sums # [B, G]
|
| 429 |
-
w_gt = col_sums * denom.unsqueeze(1)
|
| 430 |
-
|
| 431 |
-
precision = (w_pred * torch.sigmoid((thresh_t[:, None] - matched_d_pred) / tau_t[:, None])).mean(dim=1)
|
| 432 |
-
recall_raw = w_gt * torch.sigmoid((thresh_t[:, None] - matched_d_gt) / tau_t[:, None])
|
| 433 |
-
# Mask invalid GT vertices in recall
|
| 434 |
-
recall = (recall_raw * gt_vert_mask.float()).sum(dim=1) / gt_counts.clamp_min(1.0)
|
| 435 |
-
|
| 436 |
-
f1 = 2.0 * precision * recall / (precision + recall + eps)
|
| 437 |
-
return 1.0 - f1 # return loss (1 - F1)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
def batched_segment_sampled_iou(
|
| 442 |
-
pred_segments: torch.Tensor,
|
| 443 |
-
gt_pad: torch.Tensor,
|
| 444 |
-
gt_mask: torch.Tensor,
|
| 445 |
-
radius: float | torch.Tensor = 0.5,
|
| 446 |
-
n_per_seg: int = 32,
|
| 447 |
-
tau: float | torch.Tensor = 0.05,
|
| 448 |
-
eps: float = 1e-8,
|
| 449 |
-
) -> torch.Tensor:
|
| 450 |
-
"""Batched segment-sampled tube IoU loss.
|
| 451 |
-
|
| 452 |
-
Returns [B] per-sample (1 - IoU) loss.
|
| 453 |
-
"""
|
| 454 |
-
B, S = pred_segments.shape[:2]
|
| 455 |
-
M = gt_pad.shape[1]
|
| 456 |
-
|
| 457 |
-
# Allow per-sample radius and tau ([B] tensors or scalars)
|
| 458 |
-
rad_t = torch.as_tensor(radius, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 459 |
-
if rad_t.dim() == 0:
|
| 460 |
-
rad_t = rad_t.expand(B)
|
| 461 |
-
tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 462 |
-
if tau_t.dim() == 0:
|
| 463 |
-
tau_t = tau_t.expand(B)
|
| 464 |
-
tau_t = tau_t.clamp_min(1e-8)
|
| 465 |
-
|
| 466 |
-
# Sample points along segments
|
| 467 |
-
t = torch.linspace(0, 1, n_per_seg, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 468 |
-
|
| 469 |
-
# Pred points: [B, S*n_per_seg, 3]
|
| 470 |
-
pa = pred_segments[:, :, 0:1, :] # [B, S, 1, 3]
|
| 471 |
-
pb = pred_segments[:, :, 1:2, :]
|
| 472 |
-
pred_pts = (pa + t[None, None, :, None] * (pb - pa)).reshape(B, S * n_per_seg, 3)
|
| 473 |
-
|
| 474 |
-
# GT points: [B, M*n_per_seg, 3]
|
| 475 |
-
ga = gt_pad[:, :, 0:1, :]
|
| 476 |
-
gb = gt_pad[:, :, 1:2, :]
|
| 477 |
-
gt_pts = (ga + t[None, None, :, None] * (gb - ga)).reshape(B, M * n_per_seg, 3)
|
| 478 |
-
|
| 479 |
-
# For each pred point, min distance to any GT segment endpoint samples
|
| 480 |
-
d_pred_to_gt = torch.cdist(pred_pts, gt_pts) # [B, S*n, M*n]
|
| 481 |
-
d_pred = d_pred_to_gt.min(dim=2).values # [B, S*n]
|
| 482 |
-
prec = torch.sigmoid((rad_t[:, None] - d_pred) / tau_t[:, None]).mean(dim=1) # [B]
|
| 483 |
-
|
| 484 |
-
d_gt_to_pred = d_pred_to_gt.min(dim=1).values # [B, M*n]
|
| 485 |
-
# Mask invalid GT points
|
| 486 |
-
gt_pt_mask = gt_mask.unsqueeze(2).expand(B, M, n_per_seg).reshape(B, M * n_per_seg)
|
| 487 |
-
rec_raw = torch.sigmoid((rad_t[:, None] - d_gt_to_pred) / tau_t[:, None])
|
| 488 |
-
rec = (rec_raw * gt_pt_mask.float()).sum(dim=1) / gt_pt_mask.float().sum(dim=1).clamp_min(1.0)
|
| 489 |
-
|
| 490 |
-
iou = prec * rec / (prec + rec - prec * rec + eps)
|
| 491 |
-
return 1.0 - iou # return loss
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
def batched_soft_hss_v2(pred_segments, gt_pad, gt_mask,
                        vert_thresh=0.5, edge_thresh=0.5, tau=0.05,
                        sinkhorn_iters=10, n_per_seg=32):
    """Differentiable per-sample HSS surrogate.

    Combines a soft vertex-F1 term (Sinkhorn-matched) with a soft sampled
    segment-IoU term via their harmonic mean, and returns the complement
    so the result can be minimized directly.

    Returns:
        Tensor of shape [B] holding (1 - soft HSS) per sample.
    """
    # Both helpers return losses (1 - score); recover the scores first.
    vertex_f1 = 1.0 - batched_sinkhorn_vertex_f1(
        pred_segments, gt_pad, gt_mask,
        thresh=vert_thresh, tau=tau, iters=sinkhorn_iters)
    edge_iou = 1.0 - batched_segment_sampled_iou(
        pred_segments, gt_pad, gt_mask,
        radius=edge_thresh, n_per_seg=n_per_seg, tau=tau)
    # Harmonic mean of the two scores; the epsilon guards the 0/0 case.
    harmonic = 2.0 * vertex_f1 * edge_iou / (vertex_f1 + edge_iou + 1e-8)
    return 1.0 - harmonic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/train.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Training script for S23DR 2026.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python -m s23dr_2026_example.train --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train --steps 80000 --aug-rotate
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path as _Path
|
| 12 |
+
if __package__ is None or __package__ == "":
|
| 13 |
+
_here = _Path(__file__).resolve().parent
|
| 14 |
+
if str(_here.parent) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(_here.parent))
|
| 16 |
+
__package__ = _here.name
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import gc
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import subprocess
|
| 23 |
+
import time
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from .tokenizer import EdgeDepthSequenceConfig
|
| 30 |
+
from .model import EdgeDepthSegmentsModel
|
| 31 |
+
from .data import build_loader, build_tokens
|
| 32 |
+
from .losses import compute_loss, _loss_inner
|
| 33 |
+
|
| 34 |
+
# Re-export for eval scripts
|
| 35 |
+
from .data import HFCachedDataset, collate as _collate # noqa: F401
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Main
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def main():
    """Train the S23DR 2026 segment model end-to-end.

    Parses CLI args (optionally seeded from a prior run's args.json),
    builds the model/optimizer/data loaders, runs the step-based training
    loop with loss-spike skipping, optional EMA, LR scheduling, periodic
    validation and checkpointing, then writes a final checkpoint.
    """
    p = argparse.ArgumentParser(description="S23DR 2026 training")
    p.add_argument("--cache-dir", default=None, help="HF dataset path (hf://repo:split)")
    p.add_argument("--val-cache-dir", default="", help="Separate cache for validation")
    p.add_argument("--seq-len", type=int, default=2048,
                   help="Input sequence length (2048 or 4096, must match dataset)")
    p.add_argument("--arch", choices=["perceiver", "transformer"], default="perceiver",
                   help="perceiver=latent bottleneck, transformer=full self-attention encoder")
    p.add_argument("--segments", type=int, default=32)
    p.add_argument("--hidden", type=int, default=128)
    p.add_argument("--ff", type=int, default=512)
    p.add_argument("--latent-tokens", type=int, default=128)
    p.add_argument("--latent-layers", type=int, default=7)
    p.add_argument("--encoder-layers", type=int, default=4,
                   help="Encoder layers (transformer arch only)")
    p.add_argument("--pre-encoder-layers", type=int, default=0,
                   help="Self-attn layers on full token sequence before perceiver bottleneck")
    p.add_argument("--decoder-layers", type=int, default=3)
    p.add_argument("--decoder-input-xattn", action="store_true",
                   help="Add cross-attention from segment queries to input tokens in each decoder layer")
    p.add_argument("--qk-norm", action="store_true",
                   help="Normalize Q and K per-head with learned temperature (stabilizes wide models)")
    p.add_argument("--qk-norm-type", choices=["l2", "rms"], default="l2",
                   help="QK-norm type: l2 (unit sphere) or rms (RMSNorm, preserves magnitudes)")
    p.add_argument("--learnable-fourier", action="store_true",
                   help="Make Fourier positional encoding learnable (vs fixed random)")
    p.add_argument("--num-heads", type=int, default=4, help="Attention heads")
    p.add_argument("--kv-heads-cross", type=int, default=2,
                   help="KV heads for cross-attention (GQA; 0 = standard MHA)")
    p.add_argument("--kv-heads-self", type=int, default=2,
                   help="KV heads for self-attention (GQA; 0 = standard MHA)")
    p.add_argument("--cross-attn-interval", type=int, default=4,
                   help="Perceiver cross-attention frequency (every N latent layers)")
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--weight-decay", type=float, default=0.01, help="AdamW weight decay")
    p.add_argument("--steps", type=int, default=5000)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--adam-betas", default="0.9,0.95", help="AdamW beta1,beta2")
    p.add_argument("--warmup", type=int, default=200, help="LR warmup steps")
    p.add_argument("--cosine-decay", action="store_true",
                   help="Cosine decay LR after warmup (to lr*0.01 at end)")
    p.add_argument("--cooldown-start", type=int, default=0,
                   help="Step to begin linear cooldown to lr*0.01 (0=disabled, constant LR after warmup)")
    p.add_argument("--cooldown-steps", type=int, default=0,
                   help="Number of steps for linear cooldown (0=no cooldown)")
    p.add_argument("--seed", type=int, default=7)
    p.add_argument("--deterministic", action="store_true",
                   help="Force deterministic mode (disables torch.compile, slower but bit-reproducible)")
    p.add_argument("--varifold-weight", type=float, default=0.0)
    p.add_argument("--varifold-cross-only", action="store_true",
                   help="Drop varifold self-energy (avoids O(S^2) spike, sinkhorn handles repulsion)")
    p.add_argument("--sinkhorn-weight", type=float, default=1.0)
    p.add_argument("--sinkhorn-eps", type=float, default=0.1,
                   help="Sinkhorn regularization (larger = softer matching, stronger gradients)")
    p.add_argument("--sinkhorn-eps-start", type=float, default=None,
                   help="Starting eps for epsilon annealing (anneals to --sinkhorn-eps). None=no annealing.")
    p.add_argument("--sinkhorn-eps-schedule", choices=["linear", "sqrt", "none"], default="none",
                   help="Eps annealing schedule: linear, sqrt, or none (default: no annealing)")
    p.add_argument("--sinkhorn-iters", type=int, default=20,
                   help="Sinkhorn iterations")
    p.add_argument("--sinkhorn-dustbin", type=float, default=0.3,
                   help="Sinkhorn dustbin cost in normalized space")
    p.add_argument("--endpoint-weight", type=float, default=0.0,
                   help="Weight for endpoint distance loss (sinkhorn-matched, symmetric)")
    p.add_argument("--endpoint-warmup", type=int, default=0,
                   help="Steps to linearly warm up endpoint weight from 0 (0=instant)")
    p.add_argument("--aug-rotate", action="store_true")
    p.add_argument("--aug-jitter", type=float, default=0.0,
                   help="Point position jitter std in normalized space (0=disabled, try 0.005)")
    p.add_argument("--aug-drop", type=float, default=0.0,
                   help="Fraction of points to randomly drop (0=disabled, try 0.1)")
    p.add_argument("--aug-flip", action="store_true",
                   help="Random mirror along X axis (50%% chance)")
    p.add_argument("--rms-norm", action="store_true", default=True,
                   help="Use RMSNorm (default). Use --no-rms-norm for LayerNorm")
    p.add_argument("--no-rms-norm", dest="rms_norm", action="store_false")
    p.add_argument("--activation", default="gelu", help="FFN activation: gelu, relu, relu_sq")
    p.add_argument("--behind-emb-dim", type=int, default=8,
                   help="Behind-gestalt embedding dim (0 to disable)")
    p.add_argument("--vote-features", action="store_true",
                   help="Add n_views_voted + vote_frac as raw token features (requires v2 data)")
    p.add_argument("--segment-param", choices=["midpoint_halfvec", "midpoint_dir_len"],
                   default="midpoint_halfvec",
                   help="Output parameterization: halfvec (default) or decoupled direction+length")
    p.add_argument("--length-floor", type=float, default=0.0,
                   help="Minimum segment length for midpoint_dir_len (0=no floor)")
    p.add_argument("--segment-conf", action="store_true",
                   help="Add per-segment confidence head (use with --conf-thresh at eval)")
    p.add_argument("--conf-weight", type=float, default=0.0,
                   help="Weight for confidence loss (requires --segment-conf)")
    p.add_argument("--conf-mode", choices=["sinkhorn", "sinkhorn_detach"], default="sinkhorn",
                   help="Confidence training: 'match'=BCE, 'sinkhorn'=OT mass, 'sinkhorn_detach'=OT mass (detached)")
    p.add_argument("--conf-clamp-min", type=float, default=None,
                   help="Clamp conf logits to this minimum before sigmoid (e.g., -5)")
    p.add_argument("--conf-head-wd", type=float, default=None,
                   help="Separate weight decay for conf head (default: same as other params)")
    p.add_argument("--ema-decay", type=float, default=0.0,
                   help="EMA decay rate (0=disabled, try 0.9999). Saves EMA weights in checkpoints.")
    p.add_argument("--out-dir", default=str(_Path(__file__).resolve().parent / "runs"))
    p.add_argument("--resume", default="")
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--args-from", default=None,
                   help="Load defaults from a run's args.json (CLI flags override)")

    # If --args-from is specified, load defaults from that JSON file first,
    # then let CLI flags override.
    raw_args = p.parse_args()
    if raw_args.args_from is not None:
        import json as _json
        args_path = _Path(raw_args.args_from)
        if not args_path.exists():
            raise FileNotFoundError(f"--args-from file not found: {args_path}")
        saved = _json.loads(args_path.read_text())
        # NOTE(review): relies on the private argparse attribute `p._actions`
        # to enumerate valid destinations — works on CPython but is not a
        # stable public API.
        valid_dests = {a.dest for a in p._actions}
        defaults = {}
        for k, v in saved.items():
            if k in valid_dests and k != "args_from":
                defaults[k] = v
        p.set_defaults(**defaults)
        # Re-parse so explicit CLI flags still win over the loaded defaults.
        args = p.parse_args()
        print(f"Loaded defaults from {args_path} (CLI flags override)")
    else:
        args = raw_args

    # Validate required args
    if not args.cache_dir:
        p.error("--cache-dir is required (either directly or via --args-from)")

    # Validate arg compatibility: reject flags that would be silently ignored.
    if args.arch == "transformer":
        perceiver_only = []
        if args.latent_tokens != 128:
            perceiver_only.append(f"--latent-tokens={args.latent_tokens}")
        if args.latent_layers != 7:
            perceiver_only.append(f"--latent-layers={args.latent_layers}")
        if args.pre_encoder_layers != 0:
            perceiver_only.append(f"--pre-encoder-layers={args.pre_encoder_layers}")
        if args.cross_attn_interval != 4:
            perceiver_only.append(f"--cross-attn-interval={args.cross_attn_interval}")
        if perceiver_only:
            raise ValueError(
                f"Args {', '.join(perceiver_only)} have no effect with --arch transformer. "
                f"Use --arch perceiver or remove them.")
    if args.conf_weight > 0 and not args.segment_conf:
        raise ValueError("--conf-weight requires --segment-conf")
    if args.conf_mode in ("sinkhorn", "sinkhorn_detach") and args.sinkhorn_weight == 0:
        raise ValueError("--conf-mode sinkhorn requires --sinkhorn-weight > 0")
    if args.cosine_decay and args.cooldown_start > 0:
        raise ValueError("--cosine-decay and --cooldown-start are mutually exclusive")

    device = torch.device("cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Device: {device}")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Output: unique run tag from timestamp + args hash + pid suffix.
    import hashlib, os
    args_hash = hashlib.md5(json.dumps(vars(args), sort_keys=True).encode()).hexdigest()[:4]
    run_tag = time.strftime("%Y%m%d_%H%M%S") + f"_{args_hash}_{os.getpid() % 10000:04d}"
    out_dir = Path(args.out_dir) / run_tag
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "checkpoints").mkdir(exist_ok=True)

    # Tee stdout/stderr to run dir so console output is also persisted.
    import sys as _sys
    _log_path = out_dir / "train.log"
    class _Tee:
        # Minimal write/flush file-like wrapper; the log file is flushed on
        # every write so output survives a crash.
        def __init__(self, path, stream):
            self._file = open(path, "a")
            self._stream = stream
        def write(self, data):
            self._stream.write(data)
            self._file.write(data)
            self._file.flush()
        def flush(self):
            self._stream.flush()
            self._file.flush()
    _sys.stdout = _Tee(_log_path, _sys.stdout)
    _sys.stderr = _Tee(_log_path, _sys.stderr)

    # Record git provenance alongside the args for reproducibility.
    git_sha = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True,
                             cwd=str(_Path(__file__).parent)).stdout.strip()
    git_dirty = subprocess.run(["git", "diff", "--quiet"], capture_output=True,
                               cwd=str(_Path(__file__).parent)).returncode != 0
    run_info = {**vars(args), "git_sha": git_sha, "git_dirty": git_dirty}
    (out_dir / "args.json").write_text(json.dumps(run_info, indent=2, sort_keys=True) + "\n")

    # Set varifold cross-only mode before compile (module-level flag read by
    # the loss code).
    if args.varifold_cross_only:
        from . import losses as L
        L.VARIFOLD_CROSS_ONLY = True
        print("Varifold: cross-only mode (no self-energy)")

    # Model
    seq_len = args.seq_len
    norm_class = torch.nn.RMSNorm if args.rms_norm else None
    seq_cfg = EdgeDepthSequenceConfig(seq_len=seq_len)
    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg, segments=args.segments, hidden=args.hidden,
        num_heads=args.num_heads, kv_heads_cross=args.kv_heads_cross,
        kv_heads_self=args.kv_heads_self,
        dim_feedforward=args.ff, dropout=args.dropout,
        latent_tokens=args.latent_tokens, latent_layers=args.latent_layers,
        decoder_layers=args.decoder_layers, cross_attn_interval=args.cross_attn_interval,
        norm_class=norm_class, activation=args.activation,
        segment_conf=args.segment_conf,
        segment_param=args.segment_param,
        length_floor=args.length_floor,
        arch=args.arch, encoder_layers=args.encoder_layers,
        pre_encoder_layers=args.pre_encoder_layers,
        behind_emb_dim=args.behind_emb_dim,
        use_vote_features=args.vote_features,
        decoder_input_xattn=args.decoder_input_xattn,
        qk_norm=args.qk_norm,
        qk_norm_type=args.qk_norm_type,
        learnable_fourier=args.learnable_fourier,
    ).to(device)

    # Optional model summary; torchinfo is a soft dependency.
    try:
        from torchinfo import summary
        summary(model.segmenter,
                input_data=[torch.zeros(1, seq_len, model.tokenizer.out_dim, device=device),
                            torch.ones(1, seq_len, device=device, dtype=torch.bool)],
                col_names=("input_size", "output_size", "num_params"), verbose=1)
    except ImportError:
        pass
    print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")

    # Compile (skip in deterministic mode for bit-reproducibility)
    torch.set_float32_matmul_precision("high")
    if args.deterministic:
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        import os
        # NOTE(review): CUBLAS_WORKSPACE_CONFIG is set after torch/CUDA may
        # already be initialized — presumably cuBLAS reads it lazily here,
        # but confirm; the official guidance is to set it before launch.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
        print("Deterministic mode: no torch.compile, bit-reproducible but ~3x slower")
    elif device.type == "cuda":
        model.segmenter = torch.compile(model.segmenter, mode="reduce-overhead", fullgraph=True)
        from . import losses as L
        L._loss_fn = torch.compile(_loss_inner, mode="reduce-overhead", fullgraph=True)
        print("Compiled model + loss (reduce-overhead, fullgraph)")

    # EMA: shadow copy of the model updated after every optimizer step.
    ema_model = None
    if args.ema_decay > 0:
        from copy import deepcopy
        ema_model = deepcopy(model).eval()
        for p_ema in ema_model.parameters():
            p_ema.requires_grad_(False)
        print(f"EMA enabled (decay={args.ema_decay})")

    # Resume
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        try:
            model.load_state_dict(ckpt["model"])
        except RuntimeError:
            # Checkpoint was saved from a torch.compile'd model; strip the
            # _orig_mod prefix and retry.
            state = {k.replace("segmenter._orig_mod.", "segmenter."): v
                     for k, v in ckpt["model"].items()}
            model.load_state_dict(state)
        start_step = ckpt.get("step", 0)
        print(f"Resumed from {args.resume} at step {start_step}")

    betas = tuple(float(x) for x in args.adam_betas.split(","))

    # Optimizer: AdamW with optional separate conf_head weight decay
    conf_wd = args.conf_head_wd if args.conf_head_wd is not None else args.weight_decay
    if args.conf_head_wd is not None:
        conf_decay_params = []
        other_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'conf_head' in name:
                conf_decay_params.append(param)
            else:
                other_params.append(param)
        param_groups = [
            {"params": other_params, "weight_decay": args.weight_decay},
            {"params": conf_decay_params, "weight_decay": conf_wd},
        ]
        print(f"Conf head WD: {conf_wd} ({len(conf_decay_params)} params)")
    else:
        param_groups = model.parameters()

    opt = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay,
                            betas=betas)
    if args.resume and "optimizer" in ckpt:
        opt.load_state_dict(ckpt["optimizer"])

    # Data (re-seed so data order is independent of model-init RNG draws)
    torch.manual_seed(args.seed + 7919)
    np.random.seed(args.seed + 7919)
    train_loader = build_loader(args.cache_dir, args.batch_size, aug_rotate=args.aug_rotate,
                                aug_jitter=args.aug_jitter, aug_drop=args.aug_drop,
                                aug_flip=args.aug_flip)
    val_loader = build_loader(args.val_cache_dir, args.batch_size) if args.val_cache_dir else None
    data_iter = iter(train_loader)

    # Intervals
    log_int = max(1, min(50, args.steps // 20))
    ckpt_int = 5000
    val_int = ckpt_int if val_loader else 0

    # Training loop
    global_step = start_step
    loss_ema, loss_sq_ema = 0.0, 0.0
    t_start = time.perf_counter()

    print(f"Training for {args.steps} steps | {args.segments}seg "
          f"{args.hidden}h {args.latent_tokens}x{args.latent_layers}L "
          f"{args.decoder_layers}D")

    # Pre-fetch first batch
    try:
        next_batch = next(data_iter)
    except StopIteration:
        data_iter = iter(train_loader)
        next_batch = next(data_iter)

    # Freeze GC after setup to eliminate stalls during training
    gc.collect()
    gc.freeze()
    gc.disable()

    # bfloat16 autocast on CUDA only; a no-op context on CPU.
    amp_ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16,
                             enabled=(device.type == 'cuda'))

    while global_step < args.steps:
        tokens, masks, gt_list, scales, meta = build_tokens(next_batch, model, device)

        # Epsilon annealing
        if args.sinkhorn_eps_start is not None and args.sinkhorn_eps_start != args.sinkhorn_eps:
            if args.sinkhorn_eps_schedule == "sqrt":
                # 1/sqrt(1 + step/t0) decay, with t0 chosen so the target
                # eps is reached around 80% of training.
                ratio_sq = (args.sinkhorn_eps_start / args.sinkhorn_eps) ** 2
                t0 = max(args.steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0)
                current_eps = args.sinkhorn_eps_start / math.sqrt(1 + global_step / t0)
                current_eps = max(current_eps, args.sinkhorn_eps)
            else:
                # Linear interpolation over the first 80% of training.
                frac = min(global_step / max(args.steps * 0.8, 1), 1.0)
                current_eps = args.sinkhorn_eps_start + frac * (args.sinkhorn_eps - args.sinkhorn_eps_start)
        else:
            current_eps = args.sinkhorn_eps

        with amp_ctx:
            out = model.forward_tokens(tokens, masks)
            pred = out["segments"]
            conf = out.get("conf")

        # Endpoint weight warmup
        if args.endpoint_warmup > 0 and global_step < args.endpoint_warmup:
            current_ep_w = args.endpoint_weight * global_step / args.endpoint_warmup
        else:
            current_ep_w = args.endpoint_weight

        loss, terms = compute_loss(pred, gt_list, scales.to(device), device,
                                   args.varifold_weight, args.sinkhorn_weight,
                                   endpoint_w=current_ep_w,
                                   conf_logits=conf, conf_weight=args.conf_weight,
                                   conf_mode=args.conf_mode,
                                   sinkhorn_eps=current_eps,
                                   sinkhorn_iters=args.sinkhorn_iters,
                                   sinkhorn_dustbin=args.sinkhorn_dustbin,
                                   conf_clamp_min=args.conf_clamp_min)

        loss_val = loss.item()
        # Adaptive loss spike detection: track running mean/second-moment of
        # the loss (faster EMA for the first 100 steps while stats settle).
        if global_step < 100:
            loss_ema = loss_val if global_step == start_step else 0.9 * loss_ema + 0.1 * loss_val
            loss_sq_ema = loss_val**2 if global_step == start_step else 0.9 * loss_sq_ema + 0.1 * loss_val**2
        else:
            loss_ema = 0.99 * loss_ema + 0.01 * loss_val
            loss_sq_ema = 0.99 * loss_sq_ema + 0.01 * loss_val**2
        loss_std = max(math.sqrt(max(loss_sq_ema - loss_ema**2, 0)), 1e-6)
        spike_thresh = loss_ema + 5 * loss_std

        # Skip on total loss spike or NaN: log the offending samples and
        # move to the next batch without stepping the optimizer.
        if not math.isfinite(loss_val) or loss_val > max(spike_thresh, 0.5):
            sample_ids = [m.get("sample_id", "?") for m in meta]
            skip_reason = f"loss={loss_val:.2f} > thresh={spike_thresh:.2f}"
            print(f"Step {global_step}: {skip_reason}, skipping (samples: {sample_ids[:3]})")
            with open(out_dir / "skipped_samples.jsonl", "a") as f:
                f.write(json.dumps({"step": global_step, "reason": skip_reason,
                                    "samples": sample_ids}) + "\n")
            try:
                next_batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                next_batch = next(data_iter)
            continue

        opt.zero_grad()
        loss.backward()

        # Fetch next batch while GPU finishes backward
        try:
            next_batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            next_batch = next(data_iter)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        # LR schedule: warmup -> constant -> optional cooldown or cosine
        if global_step < args.warmup:
            lr = args.lr * (global_step + 1) / max(1, args.warmup)
        elif args.cosine_decay:
            progress = (global_step - args.warmup) / max(1, args.steps - args.warmup)
            lr = args.lr * (0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress)))
        elif args.cooldown_start > 0 and global_step >= args.cooldown_start:
            progress = (global_step - args.cooldown_start) / max(1, args.cooldown_steps)
            lr = args.lr * max(0.01, 1.0 - 0.99 * min(1.0, progress))
        else:
            lr = args.lr
        for pg in opt.param_groups:
            pg["lr"] = lr
        opt.step()
        global_step += 1

        # EMA update: lerp shadow weights toward the live model.
        if ema_model is not None:
            decay = args.ema_decay
            with torch.no_grad():
                for p_ema, p_model in zip(ema_model.parameters(), model.parameters()):
                    p_ema.lerp_(p_model, 1.0 - decay)

        # Log
        entry = {"step": global_step, "ts": time.time(), "loss": loss.item(), "lr": lr}
        entry.update({k: v.item() for k, v in terms.items()})
        if global_step % log_int == 0:
            # Global grad norm (post-clip); computed only on log steps since
            # it forces a GPU sync per parameter.
            grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
                            if p.grad is not None) ** 0.5
            entry["grad_norm"] = grad_norm

        if global_step % log_int == 0:
            ms = (time.perf_counter() - t_start) / log_int * 1000
            t_start = time.perf_counter()
            t_str = " ".join(f"{k}={v:.4f}" for k, v in terms.items())
            print(f"[{global_step}/{args.steps}] loss={loss.item():.4f} {t_str} "
                  f"lr={lr:.2e} gnorm={entry.get('grad_norm', 0):.3f} [{ms:.0f}ms/step]")

        # Periodic validation: mean loss over the whole val loader;
        # best-effort — failures are printed but never abort training.
        if val_int > 0 and global_step % val_int == 0:
            try:
                vl_list = []
                with torch.no_grad(), amp_ctx:
                    for vb in val_loader:
                        vt, vm, vg, vs, _ = build_tokens(vb, model, device)
                        vo = model.forward_tokens(vt, vm)
                        vl, _ = compute_loss(vo["segments"], vg, vs.to(device), device,
                                             args.varifold_weight, args.sinkhorn_weight)
                        if math.isfinite(vl.item()):
                            vl_list.append(vl.item())
                if vl_list:
                    val_loss = float(np.mean(vl_list))
                    print(f" val_loss={val_loss:.4f}")
                    entry["val_loss"] = val_loss
            except Exception as e:
                print(f" val eval failed: {e}")

        # Write log entry
        with open(out_dir / "history.jsonl", "a") as f:
            f.write(json.dumps(entry) + "\n")

        # Periodic checkpoint; best-effort. GC is briefly re-enabled to
        # reclaim garbage accumulated while it was disabled.
        if global_step % ckpt_int == 0:
            try:
                gc.enable(); gc.collect(); gc.freeze(); gc.disable()
                # NOTE(review): empty_cache is called unconditionally — on a
                # CPU-only run any failure is swallowed by this except block.
                torch.cuda.empty_cache()
                save_dict = {"step": global_step, "model": model.state_dict(),
                             "optimizer": opt.state_dict(), "args": vars(args)}
                if ema_model is not None:
                    save_dict["ema_model"] = ema_model.state_dict()
                torch.save(save_dict, out_dir / "checkpoints" / f"step{global_step:06d}.pt")
            except Exception as e:
                print(f" checkpoint save failed: {e}")

    # Final save
    save_dict = {"step": global_step, "model": model.state_dict(),
                 "optimizer": opt.state_dict(), "args": vars(args)}
    if ema_model is not None:
        save_dict["ema_model"] = ema_model.state_dict()
    torch.save(save_dict, out_dir / "checkpoints" / "final.pt")
    print(f"Done. {global_step} steps. Output: {out_dir}")


if __name__ == "__main__":
    main()
|
s23dr_2026_example/varifold.py
CHANGED
|
@@ -1,25 +1,11 @@
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
from .wire_varifold_kernels import (
|
| 4 |
-
loss_semi_lobatto3,
|
| 5 |
-
loss_semi_lobatto3_mix,
|
| 6 |
-
loss_semi_lobatto3_mix_simple,
|
| 7 |
-
loss_simpson3,
|
| 8 |
loss_simpson3_batch,
|
| 9 |
-
loss_simpson3_mix,
|
| 10 |
loss_simpson3_mix_batch,
|
| 11 |
-
loss_simpson3_lenpow,
|
| 12 |
-
loss_simpson3_lenpow_mix,
|
| 13 |
-
loss_semi_legendre,
|
| 14 |
)
|
| 15 |
|
| 16 |
|
| 17 |
-
def edges_to_segments(vertices, edges) -> torch.Tensor:
    """Gather endpoint pairs for each edge into an (E, 2, 3) segment tensor.

    vertices: array-like of vertex coordinates, coerced to float32.
    edges: array-like of (E, 2) vertex-index pairs, coerced to long.
    """
    pts = torch.as_tensor(vertices, dtype=torch.float32)
    pairs = torch.as_tensor(edges, dtype=torch.long)
    start = pts[pairs[:, 0]]
    end = pts[pairs[:, 1]]
    return torch.stack((start, end), dim=1)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
def segments_to_vertices_edges(segments: torch.Tensor):
|
| 24 |
segs = torch.as_tensor(segments, dtype=torch.float32)
|
| 25 |
vertices = segs.reshape(-1, 3)
|
|
@@ -27,52 +13,6 @@ def segments_to_vertices_edges(segments: torch.Tensor):
|
|
| 27 |
return vertices, edges
|
| 28 |
|
| 29 |
|
| 30 |
-
def varifold_loss(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    sigma: float = 0.1,
    variant: str = "semi_lobatto3",
    t_nodes01: torch.Tensor | None = None,
    t_w: torch.Tensor | None = None,
    sigmas: torch.Tensor | None = None,
    alpha: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    len_pow: float | None = None,
) -> torch.Tensor:
    """Dispatch to one of the varifold loss kernels by `variant` name.

    Both segment tensors are (N, 2, 3): endpoint pairs in 3D. The mixture
    variants need `sigmas` and `alpha` together; `simpson3`/`simpson3_lenpow`
    switch to their mixture form when either of those is supplied.
    Raises ValueError for an unknown variant or an incomplete mixture spec.
    """
    a_pr, b_pr = pred_segments[:, 0], pred_segments[:, 1]
    a_gt, b_gt = gt_segments[:, 0], gt_segments[:, 1]
    mix_requested = (sigmas is not None) or (alpha is not None)
    mix_complete = (sigmas is not None) and (alpha is not None)

    if variant == "semi_lobatto3":
        return loss_semi_lobatto3(a_pr, b_pr, a_gt, b_gt, sigma)

    if variant == "semi_lobatto3_mix":
        if not mix_complete:
            raise ValueError("sigmas and alpha are required for semi_lobatto3_mix")
        return loss_semi_lobatto3_mix(a_pr, b_pr, a_gt, b_gt, sigmas, alpha, normalize_alpha)

    if variant == "semi_lobatto3_mix_simple":
        if not mix_complete:
            raise ValueError("sigmas and alpha are required for semi_lobatto3_mix_simple")
        return loss_semi_lobatto3_mix_simple(a_pr, b_pr, a_gt, b_gt, sigmas, alpha, normalize_alpha)

    if variant == "simpson3":
        if mix_requested:
            if not mix_complete:
                raise ValueError("sigmas and alpha are required for simpson3 mix")
            return loss_simpson3_mix(a_pr, b_pr, a_gt, b_gt, sigmas, alpha, normalize_alpha)
        return loss_simpson3(a_pr, b_pr, a_gt, b_gt, sigma)

    if variant == "simpson3_lenpow":
        # A missing exponent defaults to plain length weighting.
        power = 1.0 if len_pow is None else len_pow
        if mix_requested:
            if not mix_complete:
                raise ValueError("sigmas and alpha are required for simpson3_lenpow mix")
            return loss_simpson3_lenpow_mix(a_pr, b_pr, a_gt, b_gt, sigmas, alpha, power, normalize_alpha)
        return loss_simpson3_lenpow(a_pr, b_pr, a_gt, b_gt, sigma, power)

    if variant == "semi_legendre":
        return loss_semi_legendre(a_pr, b_pr, a_gt, b_gt, sigma, t_nodes01, t_w)

    if variant in ("centers", "segments_varifold", "semi_lobatto1"):
        # Cheapest approximation: one Gaussian per segment midpoint.
        return varifold_loss_centers(pred_segments, gt_segments, sigma)

    raise ValueError(f"Unknown varifold variant: {variant}")
|
| 74 |
-
|
| 75 |
-
|
| 76 |
def varifold_loss_batch(
|
| 77 |
pred_segments: torch.Tensor,
|
| 78 |
gt_segments: torch.Tensor,
|
|
@@ -102,95 +42,12 @@ def varifold_loss_batch(
|
|
| 102 |
if pred_weights is not None:
|
| 103 |
w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
|
| 104 |
|
| 105 |
-
if variant =
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
sigmas_t = None
|
| 115 |
-
if sigmas is not None:
|
| 116 |
-
sigmas_t = torch.as_tensor(sigmas, device=pred_segments.device, dtype=pred_segments.dtype)
|
| 117 |
-
for idx in range(pred_segments.shape[0]):
|
| 118 |
-
gt_b = gt_segments[idx]
|
| 119 |
-
if gt_mask is not None:
|
| 120 |
-
gt_b = gt_b[gt_mask[idx]]
|
| 121 |
-
sigmas_i = sigmas
|
| 122 |
-
if sigmas_t is not None and sigmas_t.ndim == 2:
|
| 123 |
-
sigmas_i = sigmas_t[idx]
|
| 124 |
-
losses.append(
|
| 125 |
-
varifold_loss(
|
| 126 |
-
pred_segments[idx],
|
| 127 |
-
gt_b,
|
| 128 |
-
sigma=sigma,
|
| 129 |
-
variant=variant,
|
| 130 |
-
t_nodes01=t_nodes01,
|
| 131 |
-
t_w=t_w,
|
| 132 |
-
sigmas=sigmas_i,
|
| 133 |
-
alpha=alpha,
|
| 134 |
-
normalize_alpha=normalize_alpha,
|
| 135 |
-
len_pow=len_pow,
|
| 136 |
-
)
|
| 137 |
-
)
|
| 138 |
-
return torch.stack(losses, dim=0)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def varifold_loss_centers(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    sigma: float = 0.1,
    normalize_weights: bool = True,
) -> torch.Tensor:
    """Squared kernel (varifold) distance between two segment clouds.

    Each (N, 2, 3) segment is summarized by its midpoint, unit direction and
    length; the loss is ||pred||^2 + ||gt||^2 - 2<pred, gt> under a Gaussian
    kernel on midpoints times a squared-cosine kernel on directions. Lengths
    act as per-segment weights (optionally normalized to sum to 1).
    """
    tiny = 1e-8
    scale = 1.0 / (2.0 * sigma * sigma)

    def _features(segs):
        # Midpoint, unit direction, and (optionally normalized) length weight.
        start, end = segs[:, 0], segs[:, 1]
        vec = end - start
        length = torch.linalg.norm(vec, dim=-1)
        center = 0.5 * (start + end)
        direction = vec / (length[:, None] + tiny)
        weight = length
        if normalize_weights:
            weight = weight / (weight.sum() + tiny)
        return center, direction, weight

    def _gram(c1, u1, w1, c2, u2, w2):
        # Weighted sum of the pairwise kernel matrix between two clouds.
        delta = c1[:, None, :] - c2[None, :, :]
        sq_dist = (delta * delta).sum(dim=-1)
        k = torch.exp(-sq_dist * scale)
        align = (u1[:, None, :] * u2[None, :, :]).sum(dim=-1)
        k = k * (align * align)
        return (w1[:, None] * w2[None, :] * k).sum(dim=-1).sum(dim=-1)

    cp, up, wp = _features(pred_segments)
    cg, ug, wg = _features(gt_segments)

    self_pred = _gram(cp, up, wp, cp, up, wp)
    self_gt = _gram(cg, ug, wg, cg, ug, wg)
    cross = _gram(cp, up, wp, cg, ug, wg)
    return self_pred + self_gt - 2.0 * cross
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
from .wire_varifold_kernels import (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
loss_simpson3_batch,
|
|
|
|
| 5 |
loss_simpson3_mix_batch,
|
|
|
|
|
|
|
|
|
|
| 6 |
)
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def segments_to_vertices_edges(segments: torch.Tensor):
|
| 10 |
segs = torch.as_tensor(segments, dtype=torch.float32)
|
| 11 |
vertices = segs.reshape(-1, 3)
|
|
|
|
| 13 |
return vertices, edges
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def varifold_loss_batch(
|
| 17 |
pred_segments: torch.Tensor,
|
| 18 |
gt_segments: torch.Tensor,
|
|
|
|
| 42 |
if pred_weights is not None:
|
| 43 |
w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
|
| 44 |
|
| 45 |
+
if variant != "simpson3":
|
| 46 |
+
raise ValueError(
|
| 47 |
+
f"Unsupported varifold variant: {variant!r}. "
|
| 48 |
+
f"Only 'simpson3' is supported in batch mode.")
|
| 49 |
+
if sigmas is not None or alpha is not None:
|
| 50 |
+
if sigmas is None or alpha is None:
|
| 51 |
+
raise ValueError("sigmas and alpha are required for simpson3 mix")
|
| 52 |
+
return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
|
| 53 |
+
return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s23dr_2026_example/wire_varifold_kernels.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import math
|
| 2 |
import torch
|
| 3 |
|
| 4 |
# -----------------------------
|
|
@@ -46,7 +45,7 @@ def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
|
|
| 46 |
return sigmas_t, alpha_t
|
| 47 |
|
| 48 |
# -----------------------------
|
| 49 |
-
#
|
| 50 |
# -----------------------------
|
| 51 |
def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
|
| 52 |
if w is None:
|
|
@@ -121,227 +120,9 @@ def cross_simpson3(
|
|
| 121 |
return out[0] if not batched else out
|
| 122 |
|
| 123 |
|
| 124 |
-
def cross_simpson3_lenpow(
    pA,
    qA,
    pB,
    qB,
    sigma: float | torch.Tensor,
    len_pow: float,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
):
    """Cross term <A, B> of the varifold kernel with the length factor raised
    to `len_pow`, using the 3x3 Lobatto node product rule for the spatial part.

    pA/qA and pB/qB are segment endpoints, either (N, 3) or batched (B, N, 3);
    unbatched inputs are promoted and the scalar result is returned.
    `sigma` may be a scalar or a per-batch (B,) tensor. `wA`/`wB` are optional
    per-segment weights multiplied into the length factor.
    """
    device, dtype = pA.device, pA.dtype
    batched = pA.dim() == 3
    if not batched:
        # Promote to a batch of one so the same einsum-style paths apply.
        pA = pA.unsqueeze(0)
        qA = qA.unsqueeze(0)
        pB = pB.unsqueeze(0)
        qB = qB.unsqueeze(0)
    nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    w2 = LOBATTO3_W2.to(device=device, dtype=dtype)

    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)

    # segment_geom presumably returns (delta, |delta|^2-like scalar, length, unit dir) — TODO confirm
    _, _, ellA, uA = segment_geom(pA, qA)
    _, _, ellB, uB = segment_geom(pB, qB)

    XA = sample_points(pA, qA, nodes)  # (B,N,3,3): 3 quadrature points per A segment
    YB = sample_points(pB, qB, nodes)  # (B,M,3,3): 3 quadrature points per B segment

    # Squared-cosine angular kernel and length^len_pow pair factor.
    ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
    lenfac = (ellA[:, :, None] * ellB[:, None, :]).pow(len_pow)
    if wA is not None or wB is not None:
        # Fill the missing side with ones so one weighted side suffices.
        if wA is None:
            wA = torch.ones((bsz, nA), device=device, dtype=dtype)
        if wB is None:
            wB = torch.ones((bsz, nB), device=device, dtype=dtype)
        lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])

    diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
    r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3) squared node-pair distances
    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim == 0:
        inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
    else:
        # Per-batch sigma must match the batch dimension.
        if sigma_t.shape[0] != bsz:
            raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
        inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
    K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3) Gaussian spatial kernel

    # Contract the two quadrature axes with the product-rule weights w2.
    spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
    out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
    return out[0] if not batched else out
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
# -----------------------------
|
| 181 |
-
# 2/3) Semi-analytic in s, quadrature in t
|
| 182 |
-
# - Lobatto-3 (endpoints+midpoint)
|
| 183 |
-
# - Gauss-Legendre Q (nodes/weights passed in)
|
| 184 |
-
# -----------------------------
|
| 185 |
-
def cross_semi_analytic(pA, qA, pB, qB, sigma: float, t_nodes01: torch.Tensor, t_w: torch.Tensor):
    """
    Gaussian k_x. Integrate s exactly along A, integrate t numerically along B.
    t_nodes01, t_w: (Q,) nodes/weights on [0,1] (constants you pass in)

    Inputs are unbatched (N, 3)/(M, 3) endpoint tensors; returns a scalar
    cross term <A, B>. The s-integral over each A segment is evaluated in
    closed form (Gaussian times erf difference); only the t-integral over B
    uses the supplied quadrature rule.
    """
    device, dtype = pA.device, pA.dtype
    t = t_nodes01.to(device=device, dtype=dtype)  # (Q,)
    w = t_w.to(device=device, dtype=dtype)  # (Q,)

    # segment_geom presumably returns (delta, |delta|^2, length, unit dir) — TODO confirm
    dA, aA, ellA, uA = segment_geom(pA, qA)
    dB, _, ellB, uB = segment_geom(pB, qB)

    # (N,M) factors: squared-cosine angular kernel and length product.
    ang = (uA @ uB.t()).pow(2)
    lenfac = ellA[:, None] * ellB[None, :]

    # r0: (N,M,3) offsets between A starts and B starts.
    r0 = pA[:, None, :] - pB[None, :, :]

    # r(t): (N,M,Q,3) offsets from A starts to the B quadrature points.
    r = r0[:, :, None, :] - t[None, None, :, None] * dB[None, :, None, :]

    # beta, r2: (N,M,Q) projection onto dA and squared distance.
    beta = (r * dA[:, None, None, :]).sum(dim=-1)
    r2 = (r * r).sum(dim=-1)

    # semi-analytic constants per A segment: shapes broadcast to (N,1,1)
    a = aA.clamp_min(1e-12)  # guard against degenerate zero-length segments
    inv_a = (1.0 / a).view(-1, 1, 1)
    denom = (torch.sqrt(2.0 * a) * sigma).view(-1, 1, 1)
    pref = (math.sqrt(math.pi) * sigma / torch.sqrt(2.0 * a)).view(-1, 1, 1)

    # J(t): (N,M,Q) closed-form Gaussian integral over s in [0,1].
    exp_term = torch.exp(-(r2 - (beta * beta) * inv_a) / (2.0 * sigma * sigma))
    erf1 = torch.special.erf((a.view(-1, 1, 1) + beta) / denom)
    erf0 = torch.special.erf(beta / denom)
    J = pref * (erf1 - erf0) * exp_term

    # integrate over t: (N,M), then contract all pair factors to a scalar.
    spatial = (J * w.view(1, 1, -1)).sum(dim=-1)
    return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def cross_semi_lobatto3(pA, qA, pB, qB, sigma: float):
    """Semi-analytic cross term with the fixed 3-point Lobatto rule in t."""
    dev, dt = pA.device, pA.dtype
    nodes = LOBATTO3_NODES.to(device=dev, dtype=dt)
    weights = LOBATTO3_W.to(device=dev, dtype=dt)
    return cross_semi_analytic(pA, qA, pB, qB, sigma, nodes, weights)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def cross_semi_lobatto3_mix(
    pA,
    qA,
    pB,
    qB,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    """
    Semi-analytic in s (along A), Lobatto-3 in t (along B), with a sigma mixture.

    `sigmas`/`alpha` define an alpha-weighted mixture of Gaussian widths; the
    mixture is summed inside the quadrature loop so all widths share one pass
    over the pairwise geometry. Inputs are unbatched (N, 3)/(M, 3) endpoints;
    returns a scalar cross term.
    """
    device, dtype = pA.device, pA.dtype
    t_nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    t_w = LOBATTO3_W.to(device=device, dtype=dtype)

    # Coerce mixture params to tensors; optionally normalize alpha to sum 1.
    sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)

    # segment_geom presumably returns (delta, |delta|^2, length, unit dir) — TODO confirm
    dA, aA, ellA, uA = segment_geom(pA, qA)
    dB, _, ellB, uB = segment_geom(pB, qB)

    # (N,M) angular and length pair factors.
    ang = (uA @ uB.t()).pow(2)
    lenfac = ellA[:, None] * ellB[None, :]

    r0 = pA[:, None, :] - pB[None, :, :]

    # Guards against zero-length A segments / zero sigma products.
    a = aA.clamp_min(1e-12)
    inv_a = (1.0 / a).view(-1, 1)
    sqrt_a = torch.sqrt(2.0 * a).clamp_min(1e-12)

    # Per-(segment, sigma) constants of the closed-form s-integral.
    denom = (sqrt_a[:, None] * sigmas_t[None, :]).clamp_min(1e-12)
    pref = math.sqrt(math.pi) * sigmas_t[None, :] / sqrt_a[:, None]
    inv2s2 = (1.0 / (2.0 * sigmas_t * sigmas_t)).view(1, 1, -1)

    # Pre-broadcast to (N, M, S) shapes used inside the loop.
    denom_nmS = denom[:, None, :]
    pref_nmS = pref[:, None, :]
    alpha_nmS = alpha_t.view(1, 1, -1)
    a_nm1 = a[:, None, None]

    spatial = torch.zeros((pA.shape[0], pB.shape[0]), device=device, dtype=dtype)
    for tk, wk in zip(t_nodes, t_w):
        # Offsets from A starts to the B point at parameter tk.
        r = r0 - tk * dB[None, :, :]
        beta = (r * dA[:, None, :]).sum(dim=-1)
        r2 = (r * r).sum(dim=-1)
        core = r2 - (beta * beta) * inv_a

        # Closed-form s-integral per sigma, then mixture-weighted sum over S.
        exp_term = torch.exp(-core[:, :, None] * inv2s2)
        erf1 = torch.special.erf((a_nm1 + beta[:, :, None]) / denom_nmS)
        erf0 = torch.special.erf(beta[:, :, None] / denom_nmS)
        J = pref_nmS * (erf1 - erf0) * exp_term
        spatial = spatial + wk * (J * alpha_nmS).sum(dim=-1)

    return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
|
| 288 |
-
|
| 289 |
-
|
| 290 |
# -----------------------------
|
| 291 |
-
#
|
| 292 |
# -----------------------------
|
| 293 |
-
# def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
|
| 294 |
-
# s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
|
| 295 |
-
# # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma)
|
| 296 |
-
# cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
|
| 297 |
-
# # return s_pred + s_gt - 2.0 * cross
|
| 298 |
-
# return s_pred - 2.0 * cross
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
    """One-sided Simpson-3 varifold loss: ||pred||^2 - 2<pred, gt>.

    The gt self-term is constant w.r.t. the prediction, so it is omitted.
    """
    self_term = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
    cross_term = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
    return self_term - 2.0 * cross_term
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma: float, len_pow: float):
    """One-sided Simpson-3 varifold loss with length factors raised to len_pow.

    The gt self-term is constant w.r.t. the prediction, so it is omitted.
    """
    self_term = cross_simpson3_lenpow(p_pred, q_pred, p_pred, q_pred, sigma, len_pow)
    cross_term = cross_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
    return self_term - 2.0 * cross_term
|
| 315 |
-
|
| 316 |
-
def loss_simpson3_mix(
    p_pred,
    q_pred,
    p_gt,
    q_gt,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    """Alpha-weighted mixture of loss_simpson3 over several kernel widths."""
    sigmas_t, alpha_t = _prepare_mix_weights(
        sigmas, alpha, p_pred.device, p_pred.dtype, normalize_alpha
    )
    per_sigma = torch.stack(
        [loss_simpson3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
    )
    return (per_sigma * alpha_t).sum()
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
# def loss_simpson3_batch(
|
| 332 |
-
# p_pred: torch.Tensor,
|
| 333 |
-
# q_pred: torch.Tensor,
|
| 334 |
-
# p_gt: torch.Tensor,
|
| 335 |
-
# q_gt: torch.Tensor,
|
| 336 |
-
# sigma: float | torch.Tensor,
|
| 337 |
-
# w_gt: torch.Tensor | None = None,
|
| 338 |
-
# ) -> torch.Tensor:
|
| 339 |
-
# s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
|
| 340 |
-
# # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma, wA=w_gt, wB=w_gt)
|
| 341 |
-
# cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wB=w_gt)
|
| 342 |
-
# # return s_pred + s_gt - 2.0 * cross
|
| 343 |
-
# return s_pred - 2.0 * cross
|
| 344 |
-
|
| 345 |
|
| 346 |
def loss_simpson3_batch(
|
| 347 |
p_pred: torch.Tensor,
|
|
@@ -385,77 +166,3 @@ def loss_simpson3_mix_batch(
|
|
| 385 |
losses = [loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i], w_gt=w_gt, w_pred=w_pred, cross_only=cross_only) for i in range(sigmas_t.shape[1])]
|
| 386 |
return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
|
| 387 |
raise ValueError("sigmas must be 1D or 2D for batch loss")
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
def loss_simpson3_lenpow_mix(
    p_pred,
    q_pred,
    p_gt,
    q_gt,
    sigmas,
    alpha,
    len_pow: float,
    normalize_alpha: bool = True,
):
    """Alpha-weighted mixture of loss_simpson3_lenpow over several widths."""
    sigmas_t, alpha_t = _prepare_mix_weights(
        sigmas, alpha, p_pred.device, p_pred.dtype, normalize_alpha
    )
    per_sigma = torch.stack(
        [loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, s, len_pow) for s in sigmas_t]
    )
    return (per_sigma * alpha_t).sum()
|
| 404 |
-
|
| 405 |
-
def loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma: float):
    """One-sided semi-analytic Lobatto-3 varifold loss.

    The gt self-term is constant w.r.t. the prediction, so it is omitted.
    """
    self_term = cross_semi_lobatto3(p_pred, q_pred, p_pred, q_pred, sigma)
    cross_term = cross_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
    return self_term - 2.0 * cross_term
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
def loss_semi_lobatto3_mix(
    p_pred,
    q_pred,
    p_gt,
    q_gt,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    """One-sided sigma-mixture variant of the semi-analytic Lobatto-3 loss.

    The gt self-term is constant w.r.t. the prediction, so it is omitted.
    """
    self_term = cross_semi_lobatto3_mix(p_pred, q_pred, p_pred, q_pred, sigmas, alpha, normalize_alpha)
    cross_term = cross_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
    return self_term - 2.0 * cross_term
|
| 427 |
-
|
| 428 |
-
def loss_semi_lobatto3_mix_simple(
    p_pred,
    q_pred,
    p_gt,
    q_gt,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    """Simple mixture: alpha-weighted sum of loss_semi_lobatto3 per sigma."""
    sigmas_t, alpha_t = _prepare_mix_weights(
        sigmas, alpha, p_pred.device, p_pred.dtype, normalize_alpha
    )
    per_sigma = torch.stack(
        [loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
    )
    return (per_sigma * alpha_t).sum()
|
| 441 |
-
|
| 442 |
-
def loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma: float, t_nodes01, t_w):
    """Full squared varifold distance with user-supplied quadrature in t.

    Unlike the one-sided losses, this keeps the gt self-term, so the result
    is the complete ||pred - gt||^2 in the kernel-induced norm.
    """
    self_pred = cross_semi_analytic(p_pred, q_pred, p_pred, q_pred, sigma, t_nodes01, t_w)
    self_gt = cross_semi_analytic(p_gt, q_gt, p_gt, q_gt, sigma, t_nodes01, t_w)
    cross = cross_semi_analytic(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
    return self_pred + self_gt - 2.0 * cross
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
# -----------------------------
|
| 450 |
-
# torch.compile usage
|
| 451 |
-
# -----------------------------
|
| 452 |
-
# For Legendre: generate nodes/weights ONCE outside compile and pass them in.
|
| 453 |
-
# Example:
|
| 454 |
-
# import numpy as np
|
| 455 |
-
# x,w = np.polynomial.legendre.leggauss(Q)
|
| 456 |
-
# t_nodes = torch.tensor(0.5*(x+1.0), device=device, dtype=dtype)
|
| 457 |
-
# t_w = torch.tensor(0.5*w, device=device, dtype=dtype)
|
| 458 |
-
#
|
| 459 |
-
# compiled_loss = torch.compile(loss_semi_lobatto3, fullgraph=True)
|
| 460 |
-
# compiled_loss_leg = torch.compile(lambda pp,qp,pg,qg,s: loss_semi_legendre(pp,qp,pg,qg,s,t_nodes,t_w),
|
| 461 |
-
# fullgraph=True)
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
# -----------------------------
|
|
|
|
| 45 |
return sigmas_t, alpha_t
|
| 46 |
|
| 47 |
# -----------------------------
|
| 48 |
+
# Simpson-3 on both segments (3x3 product rule)
|
| 49 |
# -----------------------------
|
| 50 |
def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
|
| 51 |
if w is None:
|
|
|
|
| 120 |
return out[0] if not batched else out
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
# -----------------------------
|
| 124 |
+
# Batch losses
|
| 125 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def loss_simpson3_batch(
|
| 128 |
p_pred: torch.Tensor,
|
|
|
|
| 166 |
losses = [loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i], w_gt=w_gt, w_pred=w_pred, cross_only=cross_only) for i in range(sigmas_t.shape[1])]
|
| 167 |
return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
|
| 168 |
raise ValueError("sigmas must be 1D or 2D for batch loss")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
script.py
CHANGED
|
@@ -34,14 +34,14 @@ from s23dr_2026_example.make_sampled_cache import _priority_sample
|
|
| 34 |
# Tokenizer / model imports
|
| 35 |
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
|
| 36 |
from s23dr_2026_example.model import EdgeDepthSegmentsModel
|
| 37 |
-
from s23dr_2026_example.segment_postprocess import
|
| 38 |
from s23dr_2026_example.varifold import segments_to_vertices_edges
|
| 39 |
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
|
| 40 |
|
| 41 |
-
SEQ_LEN =
|
| 42 |
-
COLMAP_QUOTA =
|
| 43 |
-
DEPTH_QUOTA =
|
| 44 |
-
CONF_THRESH = 0.
|
| 45 |
MERGE_THRESH = 0.4
|
| 46 |
SNAP_RADIUS = 0.5
|
| 47 |
|
|
|
|
| 34 |
# Tokenizer / model imports
|
| 35 |
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
|
| 36 |
from s23dr_2026_example.model import EdgeDepthSegmentsModel
|
| 37 |
+
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
|
| 38 |
from s23dr_2026_example.varifold import segments_to_vertices_edges
|
| 39 |
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
|
| 40 |
|
| 41 |
+
# Inference-time tuning constants for script.py.
SEQ_LEN = 4096       # total sampled point-token budget per scene; equals COLMAP_QUOTA + DEPTH_QUOTA — TODO confirm split is intentional
COLMAP_QUOTA = 3072  # presumably the share of SEQ_LEN drawn from the COLMAP sparse cloud — verify against the sampling call site
DEPTH_QUOTA = 1024   # presumably the share of SEQ_LEN drawn from depth-map points — verify against the sampling call site
CONF_THRESH = 0.5    # confidence cutoff for keeping predicted segments — NOTE(review): confirm against usage below
MERGE_THRESH = 0.4   # distance threshold passed to vertex merging — verify units (likely meters)
SNAP_RADIUS = 0.5    # search radius for snapping vertices to the point cloud — verify units (likely meters)
|
| 47 |
|