Tags: PyTorch · 3d-reconstruction · wireframe · building · point-cloud · s23dr · cvpr-2026
jacklangerman committed
Commit 0f31e57 · 1 Parent(s): 465f2c6

4096-release (#1)


- Release: S23DR 2026 learned baseline (HSS=0.382) (43975eb3f3e3ba3f9773ef8afda9a3d0e85b2f9c)

Files changed (37)
  1. .gitignore +5 -0
  2. REPRODUCE.md +194 -0
  3. checkpoint.pt +2 -2
  4. configs/base.json +39 -0
  5. repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_args.json +66 -0
  6. repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_final.pt +3 -0
  7. repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_args.json +66 -0
  8. repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_final.pt +3 -0
  9. repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_args.json +66 -0
  10. repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_final.pt +3 -0
  11. repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_args.json +66 -0
  12. repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_final.pt +3 -0
  13. repro_runs/deterministic_hss372/20260330_071030_8c95_3610_args.json +66 -0
  14. repro_runs/deterministic_hss372/20260330_071030_8c95_3610_final.pt +3 -0
  15. repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_args.json +66 -0
  16. repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_final.pt +3 -0
  17. repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_args.json +66 -0
  18. repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_final.pt +3 -0
  19. repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_args.json +66 -0
  20. repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_final.pt +3 -0
  21. repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_args.json +66 -0
  22. repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_final.pt +3 -0
  23. reproduce.sh +68 -0
  24. reproduce_deterministic.sh +71 -0
  25. s23dr_2026_example/attention.py +0 -85
  26. s23dr_2026_example/cache_scenes.py +0 -195
  27. s23dr_2026_example/color_mappings.py +0 -26
  28. s23dr_2026_example/data.py +3 -13
  29. s23dr_2026_example/losses.py +10 -106
  30. s23dr_2026_example/make_sampled_cache.py +0 -185
  31. s23dr_2026_example/model.py +4 -181
  32. s23dr_2026_example/sinkhorn.py +0 -55
  33. s23dr_2026_example/soft_hss_loss.py +0 -507
  34. s23dr_2026_example/train.py +530 -0
  35. s23dr_2026_example/varifold.py +9 -152
  36. s23dr_2026_example/wire_varifold_kernels.py +2 -295
  37. script.py +5 -5
.gitignore ADDED
@@ -0,0 +1,5 @@
__pycache__/
*.pyc
runs/
*.png
*.log
REPRODUCE.md ADDED
@@ -0,0 +1,194 @@
# Reproducing the Best Checkpoint (HSS=0.382)

## Quick Start

The `checkpoint.pt` in this repo is the final model. To run inference:

```bash
python script.py
```

To reproduce from scratch (~3hr on 1x RTX 4090):

```bash
bash reproduce.sh
```

## Exact Recipe

Architecture (unchanged across all 3 steps):
```
Perceiver: hidden=256, ff=1024, latent_tokens=256, latent_layers=7
encoder_layers=4, decoder_layers=3, cross_attn_interval=4
num_heads=4, kv_heads_cross=2, kv_heads_self=2
qk_norm=True (L2), rms_norm=True, dropout=0.1
segments=64, segment_param=midpoint_dir_len, segment_conf=True
behind_emb_dim=8, vote_features=True, activation=gelu
```
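With `segment_param=midpoint_dir_len`, each of the 64 predicted segments is parametrized as a midpoint, a direction, and a length rather than as raw endpoints. A minimal numpy sketch of the conversion back to endpoint pairs (the actual head layout in `model.py` may differ; the function name here is illustrative):

```python
import numpy as np

def segments_to_endpoints(midpoint, direction, length):
    """Convert (midpoint, direction, length) segments to endpoint pairs.

    midpoint:  (N, 3) segment centers
    direction: (N, 3) direction vectors (normalized here)
    length:    (N,)   segment lengths
    """
    d = direction / np.linalg.norm(direction, axis=-1, keepdims=True)
    half = 0.5 * length[:, None] * d
    return midpoint - half, midpoint + half  # (p0, p1)

# A unit-length segment along x, centered at the origin:
p0, p1 = segments_to_endpoints(
    np.zeros((1, 3)), np.array([[2.0, 0.0, 0.0]]), np.array([1.0]))
```

This parametrization decouples position, orientation, and extent, which is why a separate endpoint loss (Step 3 below) is needed to directly supervise vertex positions.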

All shared config lives in `configs/base.json`.
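Each run's `*_args.json` under `repro_runs/` records an `args_from` key pointing at `configs/base.json`, which suggests the effective arguments are the base config overlaid with step-specific overrides. A hypothetical sketch of that merge (the real CLI parsing lives in `train.py`; `load_args` is an illustrative name, not the actual function):

```python
import json

def load_args(base_path, overrides):
    """Merge step-specific overrides onto the shared base config."""
    with open(base_path) as f:
        args = json.load(f)
    args.update(overrides)           # step-specific keys win
    args["args_from"] = base_path    # record provenance, as the run logs do
    return args
```

For example, Step 2 would override roughly `{"lr": 3e-05, "batch_size": 64, "seq_len": 4096, "steps": 135000, "resume": "..."}` on top of the base.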

### Step 1: 2048 Phase 1 (from scratch) — ~1.5hr

```
Data: hf://usm3d/s23dr-2026-sampled_2048_v2:train (16,508 samples)
Steps: 0 -> 125,000 (242 epochs)
LR: 3e-4, warmup=10,000
Batch size: 32
Optimizer: AdamW, betas=(0.9, 0.95), weight_decay=0.01
Sinkhorn: eps=0.1, iters=20, dustbin=0.3
Conf: weight=0.1, mode=sinkhorn, head_wd=0.1
Endpoint: OFF
Aug: rotate=True, flip=True
Seed: 353
```

Trains the perceiver from random init on 2048-point samples. The Sinkhorn
optimal transport loss learns to match predicted segments to ground truth.
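The matching idea can be sketched generically: run entropic Sinkhorn on a predicted-vs-GT cost matrix augmented with a constant-cost "dustbin" row/column that absorbs unmatched segments. This is a numpy sketch under the stated hyperparameters (eps=0.1, iters=20, dustbin=0.3); the actual cost construction and marginals in `sinkhorn.py`/`losses.py` are assumptions here:

```python
import numpy as np

def _logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))

def sinkhorn_with_dustbin(cost, eps=0.1, iters=20, dustbin=0.3):
    """Soft assignment between P predictions and G GT segments.

    cost: (P, G) pairwise segment distances.
    Returns a (P+1, G+1) transport plan; the extra row/col is the dustbin
    that collects predictions / GT segments with no good match.
    """
    P, G = cost.shape
    C = np.full((P + 1, G + 1), dustbin)   # dustbin entries cost `dustbin`
    C[:P, :G] = cost
    logK = -C / eps                        # entropic kernel, log domain
    log_mu = np.full(P + 1, -np.log(P + 1))  # uniform marginals (assumed)
    log_nu = np.full(G + 1, -np.log(G + 1))
    u = np.zeros(P + 1)
    v = np.zeros(G + 1)
    for _ in range(iters):                 # log-domain Sinkhorn iterations
        u = log_mu - _logsumexp(logK + v[None, :], axis=1)
        v = log_nu - _logsumexp(logK + u[:, None], axis=0)
    return np.exp(logK + u[:, None] + v[None, :])
```

Lower-cost pairs receive more transport mass, giving a differentiable soft matching that the loss can be computed over.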

**Why 2048 first:** Training directly on 4096 overfits (1.47x train/val ratio
vs 1.19x for 2048). The 2048 model learns better-generalized representations.

**Output:** HSS ~0.28.

### Step 2: 4096 finetune (constant LR) — ~15min

```
Resume: Step 1 -> step125000.pt
Data: hf://usm3d/s23dr-2026-sampled_4096_v2:train (15,892 samples)
Steps: 125,001 -> 135,000 (10k steps)
LR: 3e-5 (constant, no cooldown)
Batch size: 64
Endpoint: OFF
```

Switches input from 2048 to 4096 points, increasing structural coverage from
66% to 74%. The gentle LR (3e-5) preserves learned representations while
adapting to the extra input. A higher LR (>1e-4) causes catastrophic forgetting.

HSS jumps from 0.28 to 0.35 in ~5k steps and plateaus by 10k steps.

**Output:** HSS ~0.35.

### Step 3: Cooldown with endpoint loss — ~1hr

```
Resume: Step 2 -> step135000.pt
Data: hf://usm3d/s23dr-2026-sampled_4096_v2:train
Steps: 135,001 -> 170,000 (35k steps)
LR: 3e-5, cooldown_start=150,000, cooldown_steps=20,000
    (constant 3e-5 for 15k steps, then linear decay to ~0 over 20k)
Batch size: 64
Endpoint: weight=0.1
```
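Taken together, the schedule across the three steps is piecewise: linear warmup over the first 10k steps, constant LR after that, then a linear cooldown to ~0 starting at step 150k. A sketch of that schedule (the exact implementation in `train.py` may differ; note `base_lr` is 3e-4 in Step 1 and 3e-5 in Steps 2-3, and the cooldown is only enabled in Step 3):

```python
def lr_at(step, base_lr, warmup=10_000,
          cooldown_start=150_000, cooldown_steps=20_000):
    """Piecewise LR: linear warmup -> constant -> linear cooldown to 0."""
    if step < warmup:
        return base_lr * step / warmup
    if cooldown_steps and step >= cooldown_start:
        frac = min(1.0, (step - cooldown_start) / cooldown_steps)
        return base_lr * (1.0 - frac)
    return base_lr

lr_at(160_000, 3e-5)  # halfway through the cooldown -> 1.5e-5
```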

Adds a symmetric endpoint L1 loss (using the detached sinkhorn assignment) to
tighten vertex precision. The sinkhorn loss alone operates on the segment
midpoint/direction/length parametrization and doesn't directly penalize
endpoint position error.
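The "symmetric" part handles the fact that a segment's two endpoints have no canonical order: the loss scores both pairings and keeps the cheaper one. A numpy sketch of the core term (the real loss in `losses.py` additionally weights matched pairs by the detached sinkhorn assignment; that weighting is omitted here):

```python
import numpy as np

def symmetric_endpoint_l1(pred_p0, pred_p1, gt_p0, gt_p1):
    """L1 endpoint loss, invariant to endpoint ordering.

    Each argument: (N, 3) endpoints of N matched segment pairs.
    For every pair, evaluate both orderings (p0->g0, p1->g1) and
    (p0->g1, p1->g0), and keep the cheaper one.
    """
    direct = (np.abs(pred_p0 - gt_p0) + np.abs(pred_p1 - gt_p1)).sum(-1)
    flipped = (np.abs(pred_p0 - gt_p1) + np.abs(pred_p1 - gt_p0)).sum(-1)
    return np.minimum(direct, flipped).mean()
```

A prediction that merely swaps the two endpoints of a GT segment incurs zero loss, so the model is never penalized for an arbitrary endpoint ordering.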

**Output:** HSS=0.382, F1=0.414.

### Key Numbers

| Stage | Steps | HSS | F1 | What changed |
|-------|-------|-----|-----|-------------|
| After Step 1 | 125k | 0.281 | 0.156 | Learned geometry from 2048 pts |
| After Step 2 | 135k | 0.351 | 0.190 | 74% coverage from 4096 pts (up from 66%) |
| After Step 3 | 170k | **0.382** | **0.411** | Vertex precision from endpoint loss |

## Why This Works

1. **2048 training has low overfitting** (1.19x train/val ratio) — the model
   learns good representations without memorizing training samples.

2. **4096 data has a higher coverage ceiling** (74% vs 66% structural points) —
   more of the building surface is observed, improving vertex recall.

3. **Gentle finetuning preserves representations** — at lr=3e-5, the model
   keeps its learned geometry understanding while adapting to the extra input.

4. **Endpoint loss tightens vertices** — the symmetric endpoint distance
   directly penalizes vertex position errors, which the sinkhorn loss alone
   doesn't do (it operates on the midpoint/direction/length parametrization).

## What Doesn't Work

- **Training 4096 from scratch:** overfits (1.47x train/val gap), peaks at 0.346
- **BuildingWorld pretraining:** representations are orthogonal to S23DR (cosine sim = 0.05)
- **Mixed BW+S23DR training:** BW data hurts due to the domain gap
- **High dropout / weight decay:** prevents overfitting but causes underfitting
- **High finetune LR (>1e-4):** catastrophic forgetting of the 2048 representations
- **Steeper cooldown (1e-5, 20x drop):** slightly worse than 3e-5 for this checkpoint

## Reproduction Results

### End-to-end reproductions

| Model | HSS | F1 | IoU | Notes |
|-------|-----|-----|-----|-------|
| Original | 0.382 | 0.414 | 0.370 | Shipped checkpoint |
| E2E repro #4 | 0.379 | 0.409 | 0.369 | Closest E2E, `repro_runs/e2e_repro4_hss379/` |
| Compiled repro (from submission codebase) | 0.376 | — | — | Best compiled repro from this codebase, `repro_runs/compiled_repro_hss376/` |
| E2E repro #3 | 0.375 | 0.404 | 0.367 | |
| Deterministic E2E | 0.372 | 0.398 | 0.368 | Bit-reproducible, `repro_runs/deterministic_hss372/` |
| E2E repro #5 | 0.349 | 0.373 | — | Outlier (early compile divergence) |

### Partial reproductions (isolating pipeline stages)

| Test | Starting from | HSS | Gap to original |
|------|--------------|-----|-----------------|
| Step 3 from orig Step 2 (run A) | Original step135000.pt | 0.382 | 0.000 |
| Step 3 from orig Step 2 (run B) | Original step135000.pt | 0.384 | +0.002 |
| Step 2+3 from orig Step 1 | Original step125000.pt | 0.377 | -0.005 |
| Step 1 from orig step 100k | Original step100000.pt | 0.285 (Step 1 HSS) | +0.004 vs 0.281 |

Step 3 from the same checkpoint reproduces to within 0.002. The E2E variance
(0.349-0.379) is dominated by torch.compile nondeterminism in Step 1.

### All benchmarks

| Model | Input | HSS | F1 | IoU | Notes |
|-------|-------|-----|-----|-----|-------|
| Handcrafted baseline | raw views | 0.307 | 0.404 | 0.260 | |
| h256+qk+ep (submitted) | 2048 | 0.365 | 0.388 | 0.360 | HSS=0.427 on test |
| Original 3-step | 2048 | 0.373 | 0.404 | 0.363 | |
| Original 3-step | 4096 | 0.382 | 0.414 | 0.370 | Best ever |
| Step3 repro from orig S2 | 4096 | 0.384 | 0.414 | — | Near-exact repro |
| E2E repro #4 | 4096 | 0.379 | 0.409 | 0.369 | |
| Compiled repro (submission codebase) | 4096 | 0.376 | — | — | Best compiled from this exact codebase |
| E2E repro #3 | 4096 | 0.375 | 0.404 | 0.367 | |
| Deterministic E2E | 4096 | 0.372 | 0.398 | 0.368 | Bit-reproducible |

## Code Equivalence Verification

| Test | Result |
|------|--------|
| Forward pass (same checkpoint, same input) | Bit-identical (0.00 diff) |
| Loss computation | Bit-identical (0.00 diff) |
| Gradient computation | 5e-8 max diff |
| Training from same seed | Bit-identical steps 1-44 |
| Step 3 from same checkpoint (2 runs) | HSS=0.382, 0.384 |
| Deterministic mode (2 runs) | Bit-identical (0.00 diff) |

## Reproducibility Notes

**Default mode** (`reproduce.sh`): Uses torch.compile (~3x faster). Each run
gets different Triton kernels, causing ~1e-8 floating-point divergence at a
random step (31-45). This grows through chaotic SGD dynamics, giving HSS
variance of ~0.03 across runs. E2E reproductions land in the 0.349-0.379 range.

**Deterministic mode** (`--deterministic` flag): Disables torch.compile.
Bit-identical across runs with the same seed. HSS=0.372 (slightly lower than
compiled mode because eager-mode kernels follow a different numerical path).

**bad_samples.txt**: The shipped file has 156 entries to match the original
training. (Note: `wc -l` reports 155 because the last line lacks a trailing
newline.) Two additional bad samples (`47b0e0ce19b`, `4b2d56eb3ef`) were
discovered after the original training run. They are legitimately bad
(misaligned GT) but were included in the original training data. Adding them
changes the batch iteration order and costs ~0.005 HSS in deterministic mode
(0.372 -> 0.367) and ~0.04 in compiled mode due to compounded torch.compile
variance. Participants training from scratch may wish to add these 2 entries
for cleaner training data, but should expect slightly different scores due to
the changed iteration order.
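The `wc -l` discrepancy is just POSIX line counting: `wc -l` counts newline characters, so a final line without a trailing newline is not counted. A small sketch of counting entries robustly either way (`count_entries` is an illustrative helper, not part of this repo):

```python
def count_entries(path):
    """Count non-empty lines, whether or not the file ends in a newline."""
    with open(path) as f:
        return sum(1 for line in f if line.strip())
```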

The shipped `checkpoint.pt` is from the original training run (HSS=0.382).
checkpoint.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cc38a61ff512948b1dc92a30129d6efdd093f507948fc5b538050c4a38bfbf6c
- size 106460054
+ oid sha256:1296423a1a2e603ba55860d8ef8fa3a861764a7bbc3de96b776fca59cf5b11ab
+ size 106429791
configs/base.json ADDED
@@ -0,0 +1,39 @@
{
"arch": "perceiver",
"segments": 64,
"hidden": 256,
"ff": 1024,
"num_heads": 4,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_tokens": 256,
"latent_layers": 7,
"decoder_layers": 3,
"cross_attn_interval": 4,
"encoder_layers": 4,
"behind_emb_dim": 8,
"dropout": 0.1,
"activation": "gelu",
"rms_norm": true,
"qk_norm": true,
"qk_norm_type": "l2",
"segment_param": "midpoint_dir_len",
"segment_conf": true,
"vote_features": true,

"adam_betas": "0.9,0.95",
"weight_decay": 0.01,
"warmup": 10000,
"varifold_weight": 0.0,
"sinkhorn_weight": 1.0,
"sinkhorn_eps": 0.1,
"sinkhorn_iters": 20,
"sinkhorn_dustbin": 0.3,
"conf_weight": 0.1,
"conf_mode": "sinkhorn",
"conf_head_wd": 0.1,

"aug_rotate": true,
"aug_flip": true,
"seed": 353
}
repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 32,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 0.0003,
"num_heads": 4,
"out_dir": "runs/validate_155_compiled",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 2048,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 125000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/compiled_repro_hss376/20260408_173614_64c7_4670_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e27c8ae20c291676a2c0b7e080d6d00be86f251ae6bdfe3cc3ff6f27f6646b00
size 106427231
repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "runs/validate_155_compiled",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "runs/validate_155_compiled/20260408_173614_64c7_4670/checkpoints/step125000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 135000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/compiled_repro_hss376/20260408_194447_3061_6284_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89bae7e879900c128bbfe1a05e2f6d4b8430675ed95d47ea2d975b198c71cdad
size 106429599
repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 150000,
"cooldown_steps": 20000,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.1,
"ff": 1024,
"git_dirty": true,
"git_sha": "5b37dfc70c392936631b59d0bab24f20e4a2b0d9",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "runs/validate_155_compiled",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "runs/validate_155_compiled/20260408_194447_3061_6284/checkpoints/step135000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 170000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/compiled_repro_hss376/20260408_201237_4177_7208_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb5779baa160aab1f55b1c698cc39da6a72cae5c9bf5456a2b4063f2342b85a3
size 106429599
repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 32,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": true,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 0.0003,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 2048,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 125000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/deterministic_hss372/20260330_025738_f0c9_3400_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bd05b2b323ced2ed94c4bd382f0e07ce5fae382f19ea35cb456ff631bcd1ac0
size 106423583
repro_runs/deterministic_hss372/20260330_071030_8c95_3610_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": true,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "/workspace/s23dr_2026_example/repro_deterministic/20260330_025738_f0c9_3400/checkpoints/step125000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 135000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/deterministic_hss372/20260330_071030_8c95_3610_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:992dbe9c7bfb4713b72d0f444f480f4a53e10844330d426d3fe0f367cdb96441
size 106425951
repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 150000,
"cooldown_steps": 20000,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": true,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.1,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_deterministic",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "/workspace/s23dr_2026_example/repro_deterministic/20260330_071030_8c95_3610/checkpoints/step135000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 170000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/deterministic_hss372/20260330_073711_fdd2_8901_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd7f508eb05e42ae70efb64fd8b3ab17d036000d49589b9122d4a1a2429c35db
size 106425951
repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 32,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_2048_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 0.0003,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 2048,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 125000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/e2e_repro4_hss379/20260329_213417_ef91_6503_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8ba72f11560b8164c87cc839e3070a37e024a2a618b7820d4819b739902aa2b
size 106427231
repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 0,
"cooldown_steps": 0,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.0,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "/workspace/s23dr_2026_example/repro_e2e_run4/20260329_213417_ef91_6503/checkpoints/step125000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 135000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/e2e_repro4_hss379/20260330_002648_ca92_4553_final.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d2f2f135f8de8b676f35e5fe2693c56c2ed060649967ff877e63385d607992f
size 106429663
repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_args.json ADDED
@@ -0,0 +1,66 @@
{
"activation": "gelu",
"adam_betas": "0.9,0.95",
"arch": "perceiver",
"args_from": "configs/base.json",
"aug_drop": 0.0,
"aug_flip": true,
"aug_jitter": 0.0,
"aug_rotate": true,
"batch_size": 64,
"behind_emb_dim": 8,
"cache_dir": "hf://usm3d/s23dr-2026-sampled_4096_v2:train",
"conf_clamp_min": null,
"conf_head_wd": 0.1,
"conf_mode": "sinkhorn",
"conf_weight": 0.1,
"cooldown_start": 150000,
"cooldown_steps": 20000,
"cosine_decay": false,
"cpu": false,
"cross_attn_interval": 4,
"decoder_input_xattn": false,
"decoder_layers": 3,
"deterministic": false,
"dropout": 0.1,
"ema_decay": 0.0,
"encoder_layers": 4,
"endpoint_warmup": 0,
"endpoint_weight": 0.1,
"ff": 1024,
"git_dirty": true,
"git_sha": "465f2c6eb6ce4be5c2e52e8384961930f5f9f20a",
"hidden": 256,
"kv_heads_cross": 2,
"kv_heads_self": 2,
"latent_layers": 7,
"latent_tokens": 256,
"learnable_fourier": false,
"length_floor": 0.0,
"lr": 3e-05,
"num_heads": 4,
"out_dir": "/workspace/s23dr_2026_example/repro_e2e_run4",
"pre_encoder_layers": 0,
"qk_norm": true,
"qk_norm_type": "l2",
"resume": "/workspace/s23dr_2026_example/repro_e2e_run4/20260330_002648_ca92_4553/checkpoints/step135000.pt",
"rms_norm": true,
"seed": 353,
"segment_conf": true,
"segment_param": "midpoint_dir_len",
"segments": 64,
"seq_len": 4096,
"sinkhorn_dustbin": 0.3,
"sinkhorn_eps": 0.1,
"sinkhorn_eps_schedule": "none",
"sinkhorn_eps_start": null,
"sinkhorn_iters": 20,
"sinkhorn_weight": 1.0,
"steps": 170000,
"val_cache_dir": "",
"varifold_cross_only": false,
"varifold_weight": 0.0,
"vote_features": true,
"warmup": 10000,
"weight_decay": 0.01
}
repro_runs/e2e_repro4_hss379/20260330_005554_dec7_7390_final.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:161888591f9066bd2c3f42400839a46755bec5e72dc5f1245d7b3336c1a7ddc2
3
+ size 106429663
reproduce.sh ADDED
@@ -0,0 +1,68 @@
1
+ #!/bin/bash
2
+ # Reproduce the best checkpoint (HSS=0.382) from scratch.
3
+ #
4
+ # Three stages:
5
+ # 1. Train on 2048-point data (~1.5hr on 1x RTX 4090)
6
+ # 2. Finetune on 4096-point data (~15min)
7
+ # 3. Cooldown with endpoint loss (~1hr)
8
+ #
9
+ # Total: ~3hr on a single GPU (plus ~30min for compilation + data loading).
10
+ # All shared config lives in configs/base.json.
11
+ # Each step only specifies what changes.
12
+ set -e
13
+
14
+ OUT_DIR="${1:-runs}"
15
+ BASE="--args-from configs/base.json"
16
+
17
+ # ============================================================
18
+ # Step 1: Train on 2048-point data (Phase 1)
19
+ # ============================================================
20
+ echo "=== Step 1: Training on 2048 data ==="
21
+ python -m s23dr_2026_example.train $BASE \
22
+ --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train \
23
+ --seq-len 2048 \
24
+ --lr 3e-4 \
25
+ --batch-size 32 \
26
+ --steps 125000 \
27
+ --out-dir "$OUT_DIR"
28
+
29
+ STEP1_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
30
+ echo "Step 1 complete: $STEP1_DIR"
31
+
32
+ # ============================================================
33
+ # Step 2: Finetune on 4096-point data
34
+ # ============================================================
35
+ echo "=== Step 2: Finetuning on 4096 data ==="
36
+ python -m s23dr_2026_example.train $BASE \
37
+ --cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
38
+ --resume "$STEP1_DIR/checkpoints/step125000.pt" \
39
+ --seq-len 4096 \
40
+ --lr 3e-5 \
41
+ --batch-size 64 \
42
+ --steps 135000 \
43
+ --out-dir "$OUT_DIR"
44
+
45
+ STEP2_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
46
+ echo "Step 2 complete: $STEP2_DIR"
47
+
48
+ # ============================================================
49
+ # Step 3: Cooldown with endpoint loss
50
+ # ============================================================
51
+ echo "=== Step 3: Cooldown with endpoint loss ==="
52
+ python -m s23dr_2026_example.train $BASE \
53
+ --cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
54
+ --resume "$STEP2_DIR/checkpoints/step135000.pt" \
55
+ --seq-len 4096 \
56
+ --lr 3e-5 \
57
+ --batch-size 64 \
58
+ --endpoint-weight 0.1 \
59
+ --cooldown-start 150000 \
60
+ --cooldown-steps 20000 \
61
+ --steps 170000 \
62
+ --out-dir "$OUT_DIR"
63
+
64
+ STEP3_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
65
+ echo "Step 3 complete: $STEP3_DIR"
66
+ echo ""
67
+ echo "Final checkpoint: $STEP3_DIR/checkpoints/final.pt"
68
+ echo "Copy to checkpoint.pt for submission."
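The config layering the script's header describes (shared values in configs/base.json, each step overriding only what changes) can be sketched as follows; `resolve_args` and the inline base JSON are hypothetical stand-ins, not the trainer's actual argument parser:

```python
import json

# Stand-in for a few fields of configs/base.json
BASE_JSON = '{"lr": 0.0003, "seq_len": 2048, "steps": 125000, "batch_size": 32}'

def resolve_args(base_text: str, overrides: dict) -> dict:
    # Hypothetical sketch of the layering: start from the shared base
    # config, then apply only the flags a given step passes explicitly.
    args = json.loads(base_text)
    args.update(overrides)
    return args

# Step 2 of reproduce.sh overrides seq_len/lr/steps/batch_size, inherits the rest
step2 = resolve_args(BASE_JSON, {"seq_len": 4096, "lr": 3e-5,
                                 "steps": 135000, "batch_size": 64})
```

This is why the per-run `*_args.json` files in repro_runs/ differ only in a handful of fields (resume path, steps, endpoint/cooldown settings).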
reproduce_deterministic.sh ADDED
@@ -0,0 +1,71 @@
1
+ #!/bin/bash
2
+ # Reproduce the best checkpoint in deterministic mode (bit-reproducible).
3
+ #
4
+ # Same three stages as reproduce.sh, but with --deterministic:
5
+ # 1. Train on 2048-point data (~3hr on 1x RTX 4090)
6
+ # 2. Finetune on 4096-point data (~30min)
7
+ # 3. Cooldown with endpoint loss (~2hr)
8
+ #
9
+ # Total: ~5.5hr on a single GPU (no torch.compile, ~2x slower than reproduce.sh).
10
+ # Deterministic mode disables torch.compile and forces CUDA deterministic ops.
11
+ # Results are bit-identical across runs with the same seed. Expected HSS ~0.372.
12
+ set -e
13
+
14
+ OUT_DIR="${1:-runs}"
15
+ BASE="--args-from configs/base.json"
16
+
17
+ # ============================================================
18
+ # Step 1: Train on 2048-point data (Phase 1)
19
+ # ============================================================
20
+ echo "=== Step 1: Training on 2048 data (deterministic) ==="
21
+ python -m s23dr_2026_example.train $BASE \
22
+ --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train \
23
+ --seq-len 2048 \
24
+ --lr 3e-4 \
25
+ --batch-size 32 \
26
+ --steps 125000 \
27
+ --deterministic \
28
+ --out-dir "$OUT_DIR"
29
+
30
+ STEP1_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
31
+ echo "Step 1 complete: $STEP1_DIR"
32
+
33
+ # ============================================================
34
+ # Step 2: Finetune on 4096-point data
35
+ # ============================================================
36
+ echo "=== Step 2: Finetuning on 4096 data (deterministic) ==="
37
+ python -m s23dr_2026_example.train $BASE \
38
+ --cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
39
+ --resume "$STEP1_DIR/checkpoints/step125000.pt" \
40
+ --seq-len 4096 \
41
+ --lr 3e-5 \
42
+ --batch-size 64 \
43
+ --steps 135000 \
44
+ --deterministic \
45
+ --out-dir "$OUT_DIR"
46
+
47
+ STEP2_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
48
+ echo "Step 2 complete: $STEP2_DIR"
49
+
50
+ # ============================================================
51
+ # Step 3: Cooldown with endpoint loss
52
+ # ============================================================
53
+ echo "=== Step 3: Cooldown with endpoint loss (deterministic) ==="
54
+ python -m s23dr_2026_example.train $BASE \
55
+ --cache-dir hf://usm3d/s23dr-2026-sampled_4096_v2:train \
56
+ --resume "$STEP2_DIR/checkpoints/step135000.pt" \
57
+ --seq-len 4096 \
58
+ --lr 3e-5 \
59
+ --batch-size 64 \
60
+ --endpoint-weight 0.1 \
61
+ --cooldown-start 150000 \
62
+ --cooldown-steps 20000 \
63
+ --steps 170000 \
64
+ --deterministic \
65
+ --out-dir "$OUT_DIR"
66
+
67
+ STEP3_DIR=$(ls -dt "$OUT_DIR"/*/args.json 2>/dev/null | head -1 | xargs dirname)
68
+ echo "Step 3 complete: $STEP3_DIR"
69
+ echo ""
70
+ echo "Final checkpoint: $STEP3_DIR/checkpoints/final.pt"
71
+ echo "Copy to checkpoint.pt for submission."
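The bit-reproducibility claim rests on fixing the RNG seed (the released runs use seed 353, per the args files) and forcing deterministic ops. A generic, stdlib-only illustration of seed-determined reproducibility (not the project's actual training setup, which additionally disables torch.compile and pins CUDA kernels):

```python
import random

def run_pipeline(seed: int) -> list:
    # Generic illustration: with the RNG seeded identically,
    # every run of the same code produces identical outputs.
    rng = random.Random(seed)
    return [rng.random() for _ in range(4)]

a = run_pipeline(353)  # 353 is the seed used by the released runs
b = run_pipeline(353)
assert a == b  # same seed, identical outputs
```

In the real trainer, nondeterministic GPU kernels would break this even with a fixed seed, which is what the `--deterministic` flag guards against at a ~2x speed cost.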
s23dr_2026_example/attention.py CHANGED
@@ -139,88 +139,3 @@ class FeedForward(nn.Module):
139
  def forward(self, x: torch.Tensor) -> torch.Tensor:
140
  x = self.linear1(x)
141
  return self.linear2(self.activation(x))
142
-
143
-
144
- # =============================================================================
145
- # Custom Transformer Block
146
- # =============================================================================
147
-
148
- class TransformerBlock(nn.Module):
149
- """
150
- Single transformer block combining:
151
- - multi-head SDPA (non-causal)
152
- - layernorm + residual
153
- - feed-forward MLP + residual
154
- """
155
- def __init__(
156
- self,
157
- d_model: int,
158
- num_heads: int,
159
- dim_ff: int,
160
- dropout: float = 0.0,
161
- activation: str = "gelu",
162
- kv_heads: int = None,
163
- ):
164
- super().__init__()
165
- self.norm1 = nn.LayerNorm(d_model)
166
- self.norm2 = nn.LayerNorm(d_model)
167
-
168
- self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads)
169
- self.dropout1 = nn.Dropout(dropout)
170
- self.ffn = FeedForward(d_model, dim_ff, activation=activation)
171
- self.dropout2 = nn.Dropout(dropout)
172
-
173
- def forward(
174
- self,
175
- x: torch.Tensor,
176
- memory: torch.Tensor,
177
- memory_key_padding_mask: torch.Tensor | None = None,
178
- ) -> torch.Tensor:
179
- res = x
180
- x = self.norm1(x)
181
- x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
182
- x = res + self.dropout1(x)
183
-
184
- res = x
185
- x = self.norm2(x)
186
- x = self.ffn(x)
187
- return res + self.dropout2(x)
188
-
189
-
190
- class TransformerDecoderSets(nn.Module):
191
- """
192
- A stack of TransformerBlock layers for set-to-set
193
- modeling without causal masks.
194
- """
195
- def __init__(
196
- self,
197
- d_model: int,
198
- num_heads: int,
199
- dim_ff: int,
200
- num_layers: int,
201
- dropout: float = 0.0,
202
- activation: str = "gelu",
203
- kv_heads: int = None,
204
- ):
205
- super().__init__()
206
- self.layers = nn.ModuleList([
207
- TransformerBlock(
208
- d_model,
209
- num_heads,
210
- dim_ff,
211
- dropout=dropout,
212
- activation=activation,
213
- kv_heads=kv_heads,
214
- )
215
- for _ in range(num_layers)
216
- ])
217
-
218
- def forward(
219
- self,
220
- tgt: torch.Tensor,
221
- memory: torch.Tensor,
222
- memory_key_padding_mask: torch.Tensor | None = None,
223
- ) -> torch.Tensor:
224
- for layer in self.layers:
225
- tgt = layer(tgt, memory, memory_key_padding_mask=memory_key_padding_mask)
226
- return tgt
 
139
  def forward(self, x: torch.Tensor) -> torch.Tensor:
140
  x = self.linear1(x)
141
  return self.linear2(self.activation(x))
s23dr_2026_example/cache_scenes.py CHANGED
@@ -23,24 +23,9 @@ Cache format per file (.pt):
23
  """
24
  from __future__ import annotations
25
 
26
- import sys
27
- from pathlib import Path as _Path
28
- if __package__ is None or __package__ == "":
29
- _here = _Path(__file__).resolve().parent
30
- if str(_here.parent) not in sys.path:
31
- sys.path.insert(0, str(_here.parent))
32
- __package__ = _here.name
33
-
34
- import argparse
35
- import time
36
- from concurrent.futures import ProcessPoolExecutor, as_completed
37
- from pathlib import Path
38
-
39
  import numpy as np
40
- import torch
41
 
42
  from .point_fusion import (
43
- FuserConfig, build_compact_scene,
44
  GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
45
  )
46
 
@@ -191,183 +176,3 @@ def _compute_smart_center_scale(xyz, source, mad_k=2.5, percentile=95.0,
191
  return center.astype(np.float32), np.float32(scale)
192
 
193
 
194
- def _process_one(sample, cfg):
195
- """Process a single HF sample into a cache dict. Returns (order_id, dict) or None."""
196
- rng = np.random.RandomState() # worker-local rng
197
-
198
- n_edges = len(sample.get("wf_edges", []))
199
- if n_edges == 0 or n_edges > 64:
200
- return None
201
-
202
- scene = build_compact_scene(sample, cfg, rng=rng)
203
- if scene is None:
204
- return None
205
-
206
- gt_v = scene.get("gt_vertices")
207
- gt_e = scene.get("gt_edges")
208
- if gt_v is None or gt_e is None or len(gt_e) == 0:
209
- return None
210
-
211
- xyz = scene["xyz"]
212
- source = scene["source"]
213
- visible_src = scene["visible_src"]
214
- visible_id = scene["visible_id"]
215
- behind_id = scene["behind_gest_id"]
216
-
217
- group_id, class_id = _compute_group_and_class(
218
- visible_src, visible_id, behind_id, source
219
- )
220
-
221
- center, scale = _compute_smart_center_scale(xyz, source)
222
-
223
- order_id = sample.get("order_id", "unknown")
224
-
225
- return order_id, {
226
- "xyz": xyz.astype(np.float32),
227
- "source": source.astype(np.uint8),
228
- "group_id": group_id,
229
- "class_id": class_id,
230
- "behind_gest_id": behind_id.astype(np.int16),
231
- "visible_src": visible_src.astype(np.uint8),
232
- "visible_id": visible_id.astype(np.int16),
233
- "n_views_voted": scene["n_views_voted"],
234
- "vote_frac": scene["vote_frac"],
235
- "center": center,
236
- "scale": scale,
237
- "gt_vertices": gt_v.astype(np.float32),
238
- "gt_edges": gt_e.astype(np.int32),
239
- }
240
-
241
-
242
- def main():
243
- p = argparse.ArgumentParser(description="Cache compact scenes from HoHo22k")
244
- g = p.add_mutually_exclusive_group(required=True)
245
- g.add_argument("--data-dir", help="Local dir with shards")
246
- g.add_argument("--streaming", action="store_true", help="Stream from HuggingFace")
247
- p.add_argument("--out-dir", required=True, help="Output directory for .pt files")
248
- p.add_argument("--limit", type=int, default=0)
249
- p.add_argument("--depth-per-view", type=int, default=8000)
250
- p.add_argument("--workers", type=int, default=0,
251
- help="Parallel workers (0=sequential)")
252
- p.add_argument("--skip-existing", action="store_true",
253
- help="Skip samples whose .pt already exists in out-dir")
254
- p.add_argument("--shard-start", type=int, default=0,
255
- help="First shard index (for parallel launches)")
256
- p.add_argument("--shard-stride", type=int, default=1,
257
- help="Stride between shards (e.g. 8 means take every 8th shard)")
258
- args = p.parse_args()
259
-
260
- out_dir = Path(args.out_dir)
261
- out_dir.mkdir(parents=True, exist_ok=True)
262
- existing_ids = set(p.stem for p in out_dir.glob("*.pt")) if args.skip_existing else set()
263
-
264
- # Load dataset
265
- from datasets import load_dataset
266
- if args.streaming:
267
- ds = load_dataset(
268
- "usm3d/hoho22k_2026_trainval",
269
- streaming=True, trust_remote_code=True, split="train",
270
- )
271
- else:
272
- data_root = Path(args.data_dir).resolve()
273
- tars = []
274
- for candidate in [data_root / "data" / "train", data_root / "train", data_root]:
275
- if candidate.exists():
276
- tars = sorted(str(p) for p in candidate.glob("*.tar"))
277
- if tars:
278
- break
279
- loader = None
280
- for c in [data_root / "hoho22k_2026_trainval.py"]:
281
- if c.exists():
282
- loader = c
283
- break
284
- if loader is None:
285
- found = list(data_root.rglob("hoho22k_2026_trainval.py"))
286
- loader = found[0] if found else None
287
- if loader is None:
288
- raise FileNotFoundError("Cannot find loader script")
289
- # Shard-level parallelism: each process handles a slice of tars
290
- if args.shard_stride > 1:
291
- tars = tars[args.shard_start::args.shard_stride]
292
- print(f"Shard slice: start={args.shard_start} stride={args.shard_stride} -> {len(tars)} shards")
293
- ds = load_dataset(str(loader), data_files={"train": tars},
294
- streaming=True, trust_remote_code=True, split="train")
295
-
296
- cfg = FuserConfig(depth_points_per_view=args.depth_per_view)
297
-
298
- saved = 0
299
- skipped = 0
300
- t_start = time.perf_counter()
301
-
302
- if args.workers > 0:
303
- # Parallel: collect samples into batches, process in worker pool
304
- # Note: HF streaming datasets can't be shared across workers, so we
305
- # iterate in the main thread and dispatch processing to workers.
306
- with ProcessPoolExecutor(max_workers=args.workers) as pool:
307
- futures = {}
308
- for i, sample in enumerate(ds):
309
- if args.limit > 0 and i >= args.limit:
310
- break
311
- oid = sample.get("order_id", "unknown")
312
- if oid in existing_ids:
313
- skipped += 1
314
- continue
315
- future = pool.submit(_process_one, sample, cfg)
316
- futures[future] = i
317
-
318
- # Drain completed futures to bound memory
319
- if len(futures) >= args.workers * 4:
320
- done = [f for f in futures if f.done()]
321
- for f in done:
322
- result = f.result()
323
- del futures[f]
324
- if result is None:
325
- skipped += 1
326
- continue
327
- order_id, data = result
328
- torch.save(data, out_dir / f"{order_id}.pt")
329
- saved += 1
330
- if saved % 50 == 0:
331
- elapsed = time.perf_counter() - t_start
332
- print(f"Saved {saved} (skipped {skipped}) "
333
- f"[{saved / elapsed:.1f} samples/s]")
334
-
335
- # Drain remaining
336
- for f in as_completed(futures):
337
- result = f.result()
338
- if result is None:
339
- skipped += 1
340
- continue
341
- order_id, data = result
342
- torch.save(data, out_dir / f"{order_id}.pt")
343
- saved += 1
344
- else:
345
- # Sequential
346
- for i, sample in enumerate(ds):
347
- if args.limit > 0 and i >= args.limit:
348
- break
349
- oid = sample.get("order_id", "unknown")
350
- if oid in existing_ids:
351
- skipped += 1
352
- continue
353
-
354
- result = _process_one(sample, cfg)
355
- if result is None:
356
- skipped += 1
357
- continue
358
- order_id, data = result
359
- torch.save(data, out_dir / f"{order_id}.pt")
360
- saved += 1
361
-
362
- if saved % 50 == 0:
363
- elapsed = time.perf_counter() - t_start
364
- print(f"Saved {saved} (skipped {skipped}) "
365
- f"[{saved / elapsed:.1f} samples/s]")
366
-
367
- elapsed = time.perf_counter() - t_start
368
- print(f"Done. Saved {saved}, skipped {skipped} in {elapsed:.0f}s "
369
- f"({saved / elapsed:.1f} samples/s)")
370
-
371
-
372
- if __name__ == "__main__":
373
- main()
 
23
  """
24
  from __future__ import annotations
25
 
26
  import numpy as np
 
27
 
28
  from .point_fusion import (
 
29
  GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
30
  )
31
 
 
176
  return center.astype(np.float32), np.float32(scale)
177
 
178
 
s23dr_2026_example/color_mappings.py CHANGED
@@ -181,29 +181,3 @@ ade20k_color_mapping = {
181
  'clock': (102, 255, 0),
182
  'flag': (92, 0, 255),
183
  }
184
-
185
-
186
- EDGE_CLASSES = {'cornice_return': 0,
187
- 'cornice_strip': 1,
188
- 'eave': 2,
189
- 'flashing': 3,
190
- 'hip': 4,
191
- 'rake': 5,
192
- 'ridge': 6,
193
- 'step_flashing': 7,
194
- 'transition_line': 8,
195
- 'valley': 9}
196
- EDGE_CLASSES_BY_ID = {v: k for k, v in EDGE_CLASSES.items()}
197
-
198
- edge_color_mapping = {
199
- 'cornice_return': (215, 62, 138),
200
- 'cornice_strip': (235, 88, 48),
201
- 'eave': (54, 243, 63),
202
- "flashing": (162, 162, 32),
203
- 'hip': (8, 89, 52),
204
- 'rake': (13, 94, 47),
205
- 'ridge': (214, 251, 248),
206
- "step_flashing": (169, 255, 219),
207
- 'transition_line': (200,0,50),
208
- 'valley': (85, 27, 65),
209
- }
 
181
  'clock': (102, 255, 0),
182
  'flag': (92, 0, 255),
183
  }
s23dr_2026_example/data.py CHANGED
@@ -1,6 +1,7 @@
1
  """Data loading for pre-sampled HF datasets.
2
 
3
- Expects pre-sampled npz blobs with xyz_norm [2048, 3] (not full PCD).
 
4
  Use make_sampled_cache.py to produce these from full point clouds.
5
  """
6
  from __future__ import annotations
@@ -12,7 +13,7 @@ import torch
12
 
13
  from .tokenizer import EdgeDepthSequenceConfig
14
 
15
- # Default token budget (must match make_sampled_cache.py)
16
  SEQ_LEN = 2048
17
  COLMAP_POINTS = 1536
18
  DEPTH_POINTS = 512
@@ -130,9 +131,6 @@ def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False)
130
  result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
131
  if "vote_frac" in d:
132
  result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
133
- if "gt_edge_classes" in d:
134
- result["gt_edge_classes"] = torch.as_tensor(
135
- np.asarray(d["gt_edge_classes"], dtype=np.int64), dtype=torch.long)
136
  return result
137
 
138
 
@@ -161,14 +159,6 @@ def collate(batch):
161
  f"Field '{field}' present in some batch samples but missing in "
162
  f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
163
  out[field] = torch.stack([d[field] for d in batch])
164
- # gt_edge_classes: variable length per sample (like gt_segments), keep as list
165
- if any("gt_edge_classes" in d for d in batch):
166
- missing = [i for i, d in enumerate(batch) if "gt_edge_classes" not in d]
167
- if missing:
168
- raise KeyError(
169
- f"Field 'gt_edge_classes' present in some batch samples but missing in "
170
- f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
171
- out["gt_edge_classes"] = [d["gt_edge_classes"] for d in batch]
172
  return out
173
 
174
 
 
1
  """Data loading for pre-sampled HF datasets.
2
 
3
+ Expects pre-sampled npz blobs with xyz_norm (not full PCD).
4
+ Supports both 2048-point and 4096-point datasets.
5
  Use make_sampled_cache.py to produce these from full point clouds.
6
  """
7
  from __future__ import annotations
 
13
 
14
  from .tokenizer import EdgeDepthSequenceConfig
15
 
16
+ # Default token budget (for 2048-point datasets; 4096 uses 3072/1024)
17
  SEQ_LEN = 2048
18
  COLMAP_POINTS = 1536
19
  DEPTH_POINTS = 512
 
131
  result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
132
  if "vote_frac" in d:
133
  result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
 
 
 
134
  return result
135
 
136
 
 
159
  f"Field '{field}' present in some batch samples but missing in "
160
  f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
161
  out[field] = torch.stack([d[field] for d in batch])
 
 
 
 
 
 
 
 
162
  return out
163
 
164
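The guard retained in `collate` above (a field must be present in all batch samples or none, otherwise the cache mixes data versions) can be sketched standalone; `check_field` is a hypothetical name for illustration:

```python
def check_field(batch, field):
    # Sketch of the guard in collate(): a field present in only some
    # samples indicates mixed data versions in the cache.
    present = [field in d for d in batch]
    if any(present) and not all(present):
        missing = present.count(False)
        raise KeyError(
            f"Field '{field}' present in some batch samples but missing in "
            f"{missing}/{len(batch)}. Mixed data versions in cache?")
    return all(present)

batch = [{"xyz": 1, "vote_frac": 0.5}, {"xyz": 2, "vote_frac": 0.7}]
```

A consistent batch passes cleanly; regenerating the cache is the fix when the KeyError fires.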
 
s23dr_2026_example/losses.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
 
6
  from .varifold import varifold_loss_batch
7
  from .sinkhorn import batched_sinkhorn_loss
8
- from .soft_hss_loss import batched_sinkhorn_vertex_f1, batched_soft_hss_v2
9
 
10
  # Varifold config
11
  VARIANT = "simpson3"
@@ -18,21 +17,12 @@ VARIFOLD_CROSS_ONLY = False # Set to True to drop self-energy (avoids O(S^2) bl
18
  SINKHORN_EPS = 0.05
19
  SINKHORN_ITERS = 10
20
 
21
- # Distance thresholds in meters (divided by per-scene scale at runtime)
22
- VERTEX_THRESH_M = 0.5 # vertex match threshold (mirrors real HSS)
23
- TUBE_RADIUS_M = 0.5 # tube IoU radius (mirrors real HSS)
24
-
25
  # Sinkhorn dustbin cost: controls the OT "not matching" penalty.
26
  # Like tau, this is an OT behavior parameter, NOT a physical distance.
27
  # Must be comparable to typical matching costs in normalized space (~0.1).
28
  # Do NOT divide by scale.
29
  SINKHORN_DUSTBIN = 0.1
30
 
31
- # Sigmoid temperature: controls gradient smoothness, NOT a distance threshold.
32
- # Must stay large enough in normalized space to provide useful gradients.
33
- # Do NOT divide by scale (unlike the thresholds above).
34
- SIGMOID_TAU = 0.05
35
-
36
  MAX_GT = 64 # fixed pad size for compile-friendly shapes
37
 
38
  # Precomputed constants (created once on first call)
@@ -65,7 +55,7 @@ def pad_gt_fixed(gt_list, device, dtype):
65
 
66
 
67
  def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
68
- sigmas, alphas, varifold_w, vertex_f1_w):
69
  """Pure tensor loss -- no Python control flow, no boolean indexing."""
70
  has_gt = (gt_lengths > 0).float()
71
 
@@ -78,77 +68,36 @@ def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
78
  v = loss_batch / gt_lengths.clamp(min=1.0)
79
  v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)
80
 
81
- thresh = VERTEX_THRESH_M / scales
82
- f1 = batched_sinkhorn_vertex_f1(
83
- pred_segments, gt_pad, gt_mask, thresh=thresh, tau=SIGMOID_TAU)
84
- f1 = (f1 * has_gt).sum() / has_gt.sum().clamp(min=1.0)
85
-
86
- total = varifold_w * v + vertex_f1_w * f1
87
- return total, v, f1
88
 
89
 
90
  # Will be replaced with compiled version on CUDA
91
  _loss_fn = _loss_inner
92
 
93
 
94
- def _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales):
95
- """Auxiliary BCE loss: train conf to predict whether each segment matches GT.
96
-
97
- Computes per-segment min-distance to GT, creates soft match target via
98
- sigmoid thresholding, and returns BCE(sigmoid(conf), target).
99
- """
100
- B, S = pred_segments.shape[:2]
101
- # Decoupled cost: midpoint + direction + length (same as sinkhorn)
102
- p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
103
- g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
104
- mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
105
- mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
106
- d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
107
- len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
108
- len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
109
- dir_p = half_p / len_p
110
- dir_g = half_g / len_g
111
- cos_angle = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
112
- d_dir = 1.0 - cos_angle.abs()
113
- d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
114
- cost = d_mid + d_dir + d_len # [B, S, M]
115
-
116
- # Mask invalid GT with high cost
117
- cost = torch.where(gt_mask.unsqueeze(1), cost, cost.new_tensor(1e6))
118
- min_dist = cost.min(dim=2).values # [B, S]
119
-
120
- # Soft target: sigmoid((thresh - dist) / tau), in normalized space
121
- thresh = VERTEX_THRESH_M / scales # [B]
122
- target = torch.sigmoid((thresh[:, None] - min_dist) / SIGMOID_TAU)
123
-
124
- return torch.nn.functional.binary_cross_entropy_with_logits(
125
- conf_logits, target.detach(), reduction="mean")
126
-
127
-
128
  def compute_loss(pred_segments, gt_list, scales, device,
129
- varifold_w, sinkhorn_w, vertex_f1_w=0.0, soft_hss_w=0.0,
130
  endpoint_w=0.0,
131
- conf_logits=None, conf_weight=0.0, conf_mode="match",
132
  sinkhorn_eps=None, sinkhorn_iters=None,
133
  sinkhorn_dustbin=None, conf_clamp_min=None):
134
  """Combined loss with fixed-size GT padding.
135
 
136
- conf_mode: "match" = BCE matching supervision, "sinkhorn" = conf-weighted sinkhorn.
137
  """
138
  if conf_logits is not None and conf_clamp_min is not None:
139
  conf_logits = conf_logits.clamp(min=conf_clamp_min)
140
  gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
141
  c = _get_loss_constants(device, pred_segments.dtype)
142
 
143
- total, v, f1 = _loss_fn(
144
  pred_segments, gt_pad, gt_mask, gt_lengths, scales,
145
- c["sigmas"], c["alphas"], varifold_w, vertex_f1_w)
146
 
147
  terms = {}
148
  if varifold_w > 0:
149
  terms["varifold"] = v.detach()
150
- if vertex_f1_w > 0:
151
- terms["vertex_f1"] = f1.detach()
152
 
153
  if sinkhorn_w > 0:
154
  has_gt = (gt_lengths > 0).float()
@@ -171,28 +120,8 @@ def compute_loss(pred_segments, gt_list, scales, device,
171
  total = total + sinkhorn_w * s
172
  terms["sinkhorn"] = s.detach()
173
 
174
- if soft_hss_w > 0:
175
- has_gt = (gt_lengths > 0).float()
176
- vert_thresh = VERTEX_THRESH_M / scales
177
- edge_thresh = TUBE_RADIUS_M / scales
178
- hss_loss = batched_soft_hss_v2(
179
- pred_segments, gt_pad, gt_mask,
180
- vert_thresh=vert_thresh, edge_thresh=edge_thresh, tau=SIGMOID_TAU)
181
- hs = (hss_loss * has_gt).sum() / has_gt.sum().clamp(min=1.0)
182
- total = total + soft_hss_w * hs
183
- terms["soft_hss"] = hs.detach()
184
-
185
  if conf_logits is not None and conf_weight > 0:
186
- if conf_mode == "match":
187
- # Explicit BCE supervision from nearest-GT distances
188
- cl = _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales)
189
- total = total + conf_weight * cl
190
- terms["conf"] = cl.detach()
191
- elif conf_mode in ("sinkhorn", "sinkhorn_detach"):
192
- # Conf trained through sinkhorn transport gradients (via pred_mass).
193
- # sinkhorn_detach: pred_mass uses detached conf, so OT can't push conf negative.
194
- # Add count regularizer to prevent all-zero conf collapse.
195
- # Normalized by S so magnitude doesn't depend on segment count.
196
  conf_w = torch.sigmoid(conf_logits)
197
  S = conf_logits.shape[1]
198
  gt_counts = gt_mask.sum(dim=1).float()
@@ -200,29 +129,6 @@ def compute_loss(pred_segments, gt_list, scales, device,
200
  reg = (((conf_sum - gt_counts) / S) ** 2).mean()
201
  total = total + conf_weight * reg
202
  terms["conf_reg"] = reg.detach()
203
- elif conf_mode == "varifold":
204
- # Conf-weighted varifold: weight each pred segment's contribution
205
- # by sigmoid(conf). Low-conf segments contribute less to the loss.
206
- # Needs regularizer to prevent all-zero conf collapse.
207
- has_gt = (gt_lengths > 0).float()
208
- conf_w = torch.sigmoid(conf_logits) # [B, S]
209
- sigmas_eff = c["sigmas"] / scales[:, None]
210
- vf_conf = varifold_loss_batch(
211
- pred_segments, gt_pad, gt_mask=gt_mask,
212
- variant=VARIANT, sigmas=sigmas_eff, alpha=c["alphas"],
213
- len_pow=LEN_POW, pred_weights=conf_w,
214
- )
215
- vc = (vf_conf / gt_lengths.clamp(min=1.0))
216
- vc = (vc * has_gt).sum() / has_gt.sum().clamp(min=1.0)
217
- # Regularizer: penalize total conf being far from n_gt
218
- # Normalized by S so magnitude doesn't depend on segment count
219
- S = conf_logits.shape[1]
220
- gt_counts = gt_mask.sum(dim=1).float() # [B]
221
- conf_sum = conf_w.sum(dim=1) # [B]
222
- reg = (((conf_sum - gt_counts) / S) ** 2).mean()
223
- total = total + conf_weight * vc + 0.01 * reg
224
- terms["conf_vf"] = vc.detach()
225
- terms["conf_reg"] = reg.detach()
226
  else:
227
  raise ValueError(f"Unknown conf_mode: {conf_mode}")
228
 
@@ -234,14 +140,12 @@ def compute_loss(pred_segments, gt_list, scales, device,
234
  B, S = pred_segments.shape[:2]
235
  M = gt_pad.shape[1]
236
 
237
- # Compute hard assignment via sinkhorn (detached matching is not trained)
238
  with torch.no_grad():
239
  pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
240
  sink_loss_for_assign = batched_sinkhorn_loss(
241
  pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
242
  pred_mass=pred_mass_ep)
243
- # Re-run sinkhorn to get transport matrix for assignment
244
- # (reuse the cost computation from batched_sinkhorn_loss internals)
245
  p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
246
  g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
247
  mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
 
     from .varifold import varifold_loss_batch
     from .sinkhorn import batched_sinkhorn_loss

 # Varifold config
 VARIANT = "simpson3"

 SINKHORN_EPS = 0.05
 SINKHORN_ITERS = 10

 # Sinkhorn dustbin cost: controls the OT "not matching" penalty.
 # Like tau, this is an OT behavior parameter, NOT a physical distance.
 # Must be comparable to typical matching costs in normalized space (~0.1).
 # Do NOT divide by scale.
 SINKHORN_DUSTBIN = 0.1

 MAX_GT = 64  # fixed pad size for compile-friendly shapes

 # Precomputed constants (created once on first call)

 def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
+                sigmas, alphas, varifold_w):
     """Pure tensor loss -- no Python control flow, no boolean indexing."""
     has_gt = (gt_lengths > 0).float()

     v = loss_batch / gt_lengths.clamp(min=1.0)
     v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)

+    total = varifold_w * v
+    return total, v


 # Will be replaced with compiled version on CUDA
 _loss_fn = _loss_inner


 def compute_loss(pred_segments, gt_list, scales, device,
+                 varifold_w, sinkhorn_w,
                  endpoint_w=0.0,
+                 conf_logits=None, conf_weight=0.0, conf_mode="sinkhorn",
                  sinkhorn_eps=None, sinkhorn_iters=None,
                  sinkhorn_dustbin=None, conf_clamp_min=None):
     """Combined loss with fixed-size GT padding.

+    conf_mode: "sinkhorn" = conf-weighted sinkhorn, "sinkhorn_detach" = detached conf.
     """
     if conf_logits is not None and conf_clamp_min is not None:
         conf_logits = conf_logits.clamp(min=conf_clamp_min)
     gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
     c = _get_loss_constants(device, pred_segments.dtype)

+    total, v = _loss_fn(
         pred_segments, gt_pad, gt_mask, gt_lengths, scales,
+        c["sigmas"], c["alphas"], varifold_w)

     terms = {}
     if varifold_w > 0:
         terms["varifold"] = v.detach()

     if sinkhorn_w > 0:
         has_gt = (gt_lengths > 0).float()

         total = total + sinkhorn_w * s
         terms["sinkhorn"] = s.detach()

     if conf_logits is not None and conf_weight > 0:
+        if conf_mode in ("sinkhorn", "sinkhorn_detach"):
             conf_w = torch.sigmoid(conf_logits)
             S = conf_logits.shape[1]
             gt_counts = gt_mask.sum(dim=1).float()

             reg = (((conf_sum - gt_counts) / S) ** 2).mean()
             total = total + conf_weight * reg
             terms["conf_reg"] = reg.detach()
         else:
             raise ValueError(f"Unknown conf_mode: {conf_mode}")

     B, S = pred_segments.shape[:2]
     M = gt_pad.shape[1]

+    # Compute hard assignment via sinkhorn (detached -- matching is not trained)
     with torch.no_grad():
         pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
         sink_loss_for_assign = batched_sinkhorn_loss(
             pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
             pred_mass=pred_mass_ep)
     p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
     g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
     mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
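The `conf_reg` term in `compute_loss` keeps the total predicted confidence mass close to the number of real GT segments. A minimal standalone sketch of that regularizer (the function name `conf_count_reg` is ours, not part of the repo; it only reproduces the `((conf_sum - gt_counts) / S) ** 2` pattern visible above):

```python
import torch

def conf_count_reg(conf_logits: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
    """Penalize mismatch between total predicted confidence and GT count.

    conf_logits: [B, S] per-slot confidence logits
    gt_mask:     [B, M] 1 for real GT segments, 0 for padding
    """
    S = conf_logits.shape[1]
    conf_sum = torch.sigmoid(conf_logits).sum(dim=1)   # [B] soft count of kept slots
    gt_counts = gt_mask.sum(dim=1).float()             # [B] true segment counts
    return (((conf_sum - gt_counts) / S) ** 2).mean()
```

Because the residual is normalized by the slot count `S`, the penalty stays in a comparable range regardless of how many query slots the model uses.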
s23dr_2026_example/make_sampled_cache.py CHANGED
@@ -24,21 +24,7 @@ the same points. Fine for now; better augmentation can be added later.
 """
 from __future__ import annotations

-import sys
-from pathlib import Path as _Path
-if __package__ is None or __package__ == "":
-    _here = _Path(__file__).resolve().parent
-    if str(_here.parent) not in sys.path:
-        sys.path.insert(0, str(_here.parent))
-    __package__ = _here.name
-
-import argparse
-import io
-import time
-from pathlib import Path
-
 import numpy as np
-import torch


 # Priority sampling (same logic as train.py)
@@ -87,174 +73,3 @@ def _priority_sample(source, group_id, seq_len, colmap_quota, depth_quota):
     return indices[:seq_len], mask


-def process_sample(xyz, source, group_id, class_id, vis_src, vis_id,
-                   center, scale, gt_v, gt_e, behind=None,
-                   n_views_voted=None, vote_frac=None,
-                   gt_edge_classes=None,
-                   seq_len=2048, colmap_q=1536, depth_q=512):
-    """Sample and normalize one scene. Returns dict of numpy arrays."""
-    indices, mask = _priority_sample(source, group_id, seq_len, colmap_q, depth_q)
-    xyz_norm = ((xyz[indices] - center) / scale).astype(np.float32)
-    gt_seg = np.stack([gt_v[gt_e[:, 0]], gt_v[gt_e[:, 1]]], axis=1)
-    gt_seg_norm = ((gt_seg - center) / scale).astype(np.float32)
-
-    result = {
-        "xyz_norm": xyz_norm,
-        "class_id": class_id[indices].astype(np.uint8),
-        "source": source[indices].astype(np.uint8),
-        "mask": mask,
-        "gt_segments": gt_seg_norm,
-        "scale": np.float32(scale),
-        "center": center.astype(np.float32),
-        "gt_vertices": gt_v.astype(np.float32),
-        "gt_edges": gt_e.astype(np.int32),
-        "visible_src": vis_src[indices].astype(np.uint8),
-        "visible_id": vis_id[indices].astype(np.int16),
-    }
-    if behind is not None:
-        result["behind"] = behind[indices].astype(np.int16)
-    if n_views_voted is not None:
-        result["n_views_voted"] = n_views_voted[indices].astype(np.uint8)
-    if vote_frac is not None:
-        result["vote_frac"] = vote_frac[indices].astype(np.float32)
-    if gt_edge_classes is not None:
-        if len(gt_edge_classes) != len(gt_e):
-            raise ValueError(
-                f"gt_edge_classes length {len(gt_edge_classes)} != "
-                f"gt_edges length {len(gt_e)}")
-        result["gt_edge_classes"] = gt_edge_classes.astype(np.int64)
-    return result
-
-
-def _load_edge_classes(path):
-    """Load edge classifications lookup from npz file."""
-    if path is None:
-        return None
-    path = Path(path)
-    if not path.exists():
-        raise FileNotFoundError(f"Edge classifications file not found: {path}")
-    data = np.load(str(path), allow_pickle=False)
-    lookup = {k: data[k] for k in data.files}
-    print(f"Loaded edge classifications for {len(lookup)} orders from {path}")
-    return lookup
-
-
-def main():
-    p = argparse.ArgumentParser()
-    g = p.add_mutually_exclusive_group(required=True)
-    g.add_argument("--in-dir", help="Local directory of .pt files")
-    g.add_argument("--hf-repo", help="HuggingFace dataset repo (e.g. usm3d/s23dr-2026-cached_full_pcd)")
-    p.add_argument("--split", default="train", help="HF dataset split")
-    p.add_argument("--out-dir", required=True)
-    p.add_argument("--edge-classes", default=None,
-                   help="Path to edge_classifications.npz from extract_edge_classes.py")
-    p.add_argument("--seq-len", type=int, default=2048)
-    p.add_argument("--colmap-quota", type=int, default=1536)
-    p.add_argument("--depth-quota", type=int, default=512)
-    p.add_argument("--seed", type=int, default=7)
-    args = p.parse_args()
-
-    out_dir = Path(args.out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    np.random.seed(args.seed)
-
-    edge_cls_lookup = _load_edge_classes(args.edge_classes)
-    n_edge_matched, n_edge_missing = 0, 0
-
-    t_start = time.perf_counter()
-    done = 0
-
-    if args.in_dir:
-        # Local .pt files
-        files = sorted(Path(args.in_dir).glob("*.pt"))
-        print(f"Converting {len(files)} local .pt files...")
-        for f in files:
-            out_f = out_dir / (f.stem + ".npz")
-            if out_f.exists():
-                done += 1
-                continue
-            d = torch.load(f, weights_only=False)
-            behind = np.asarray(d["behind_gest_id"], np.int16) if "behind_gest_id" in d else None
-            n_vv = np.asarray(d["n_views_voted"], np.uint8) if "n_views_voted" in d else None
-            vf = np.asarray(d["vote_frac"], np.float32) if "vote_frac" in d else None
-            gt_ec = None
-            if edge_cls_lookup is not None:
-                order_id = f.stem
-                if order_id in edge_cls_lookup:
-                    gt_ec = edge_cls_lookup[order_id]
-                    n_edge_matched += 1
-                else:
-                    n_edge_missing += 1
-            result = process_sample(
-                np.asarray(d["xyz"], np.float32),
-                np.asarray(d["source"], np.uint8),
-                np.asarray(d["group_id"], np.int8),
-                np.asarray(d["class_id"], np.uint8),
-                np.asarray(d["visible_src"], np.uint8),
-                np.asarray(d["visible_id"], np.int16),
-                np.asarray(d["center"], np.float32),
-                float(d["scale"]),
-                np.asarray(d["gt_vertices"], np.float32),
-                np.asarray(d["gt_edges"], np.int32),
-                behind=behind, n_views_voted=n_vv, vote_frac=vf,
-                gt_edge_classes=gt_ec,
-                seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
-            )
-            np.savez(out_f, **result)
-            done += 1
-            if done % 2000 == 0:
-                print(f"  {done}/{len(files)} [{done/(time.perf_counter()-t_start):.0f}/s]")
-    else:
-        # HF dataset
-        from datasets import load_dataset
-        print(f"Loading {args.hf_repo} split={args.split}...")
-        ds = load_dataset(args.hf_repo, split=args.split)
-        print(f"Converting {len(ds)} samples...")
-        for i, sample in enumerate(ds):
-            order_id = sample["order_id"]
-            out_f = out_dir / f"{order_id}.npz"
-            if out_f.exists():
-                done += 1
-                continue
-            arrays = np.load(io.BytesIO(sample["data"]))
-            behind = arrays["behind_gest_id"] if "behind_gest_id" in arrays else None
-            n_vv = arrays["n_views_voted"] if "n_views_voted" in arrays else None
-            vf = arrays["vote_frac"] if "vote_frac" in arrays else None
-            gt_ec = None
-            if edge_cls_lookup is not None:
-                if order_id in edge_cls_lookup:
-                    gt_ec = edge_cls_lookup[order_id]
-                    n_edge_matched += 1
-                else:
-                    n_edge_missing += 1
-            result = process_sample(
-                arrays["xyz"], arrays["source"], arrays["group_id"],
-                arrays["class_id"], arrays["visible_src"], arrays["visible_id"],
-                arrays["center"], float(arrays["scale"]),
-                arrays["gt_vertices"], arrays["gt_edges"],
-                behind=behind, n_views_voted=n_vv, vote_frac=vf,
-                gt_edge_classes=gt_ec,
-                seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
-            )
-            np.savez(out_f, **result)
-            done += 1
-            if done % 2000 == 0:
-                print(f"  {done}/{len(ds)} [{done/(time.perf_counter()-t_start):.0f}/s]")
-
-    elapsed = time.perf_counter() - t_start
-    print(f"Done: {done} files in {elapsed:.0f}s ({done/max(1,elapsed):.0f}/s)")
-
-    if edge_cls_lookup is not None:
-        print(f"Edge classifications: {n_edge_matched} matched, {n_edge_missing} missing")
-
-    # Report sizes
-    import os
-    npz_files = list(out_dir.glob("*.npz"))
-    if npz_files:
-        sizes = [os.path.getsize(f) for f in npz_files[:100]]
-        print(f"Avg file size: {np.mean(sizes)/1024:.0f}KB")
-        print(f"Est total: {np.mean(sizes)*len(npz_files)/1e9:.1f}GB")
-
-
-if __name__ == "__main__":
-    main()
 
 """
 from __future__ import annotations

 import numpy as np


 # Priority sampling (same logic as train.py)

     return indices[:seq_len], mask
s23dr_2026_example/model.py CHANGED
@@ -239,7 +239,6 @@ class TokenTransformerSegments(nn.Module):
         self.segments = segments
         self.out_vertices = segments * 2
         self.segment_param = segment_param
-        self.length_floor = length_floor
         self.decoder_input_xattn = decoder_input_xattn
         norm_class = norm_class or nn.LayerNorm

@@ -428,167 +427,7 @@ class SelfAttentionEncoderLayer(nn.Module):

 # ---------------------------------------------------------------------------
-# Vanilla transformer: self-attention encoder + segment query decoder
-# ---------------------------------------------------------------------------
-
-class TransformerSegments(nn.Module):
-    """Standard transformer encoder + cross-attention segment decoder.
-
-    Architecture:
-        Input tokens [B, T, D]
-            |
-            v
-        input_proj: Linear -> GELU -> Linear -> Norm  => [B, T, hidden]
-            |
-            v
-        N SelfAttentionEncoderLayers (self-attn over all T tokens)
-            |
-            v
-        Segment decoder (same as Perceiver version):
-            M SegmentDecoderLayers (queries cross-attend to encoded tokens)
-            |
-            v
-        segment_head -> endpoints [B, S, 2, 3] (midpoint_halfvec or midpoint_dir_len)
-    """
-
-    def __init__(
-        self,
-        segments: int = 32,
-        in_dim: int = 128,
-        hidden: int = 128,
-        num_heads: int = 4,
-        kv_heads_cross: int | None = 2,
-        kv_heads_self: int | None = 0,
-        dim_feedforward: int = 256,
-        dropout: float = 0.01,
-        encoder_layers: int = 4,
-        decoder_layers: int = 2,
-        norm_class=None,
-        activation: str = "gelu",
-        segment_conf: bool = False,
-        segment_param: str = "midpoint_halfvec",
-        length_floor: float = 0.0,
-        decoder_input_xattn: bool = False,
-        qk_norm: bool = False,
-        qk_norm_type: str = "l2",
-    ):
-        super().__init__()
-        self.segments = segments
-        self.out_vertices = segments * 2
-        self.segment_param = segment_param
-        self.length_floor = length_floor
-        norm_class = norm_class or nn.LayerNorm
-
-        if kv_heads_cross is not None and kv_heads_cross <= 0:
-            kv_heads_cross = None
-        if kv_heads_self is not None and kv_heads_self <= 0:
-            kv_heads_self = None
-
-        # -- Input projection --
-        self.input_proj = nn.Sequential(
-            nn.Linear(in_dim, dim_feedforward),
-            nn.GELU(),
-            nn.Linear(dim_feedforward, hidden),
-            norm_class(hidden),
-        )
-
-        # -- Self-attention encoder --
-        self.encoder_layers = nn.ModuleList([
-            SelfAttentionEncoderLayer(
-                d_model=hidden,
-                num_heads=num_heads,
-                dim_ff=dim_feedforward,
-                dropout=dropout,
-                activation=activation,
-                kv_heads=kv_heads_self,
-                norm_class=norm_class,
-                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
-            )
-            for _ in range(encoder_layers)
-        ])
-
-        # -- Segment decoder (same structure as Perceiver version) --
-        # Note: for transformer arch, decoder_input_xattn is ignored because
-        # the decoder already cross-attends to the full encoded token sequence.
-        self.query_embed = nn.Embedding(segments, hidden)
-        self.decoder_layers = nn.ModuleList([
-            SegmentDecoderLayer(
-                d_model=hidden,
-                num_heads=num_heads,
-                dim_ff=dim_feedforward,
-                dropout=dropout,
-                activation=activation,
-                kv_heads_cross=kv_heads_cross,
-                kv_heads_self=kv_heads_self,
-                norm_class=norm_class,
-                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
-            )
-            for _ in range(decoder_layers)
-        ])
-
-        # -- Output head (shared logic with Perceiver version) --
-        if segment_param == "midpoint_dir_len":
-            self.segment_head = nn.Linear(hidden, 7)  # mid(3) + dir(3) + len(1)
-        else:
-            self.segment_head = nn.Linear(hidden, 6)  # mid(3) + half(3)
-        self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))
-
-        nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
-        if self.segment_head.bias is not None:
-            nn.init.zeros_(self.segment_head.bias)
-        if segment_param == "midpoint_dir_len":
-            # sigmoid(-2.2) ~ 0.1 default length in normalized space (~3m)
-            self.segment_head.bias.data[6] = -2.2
-        nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)
-
-        self.segment_conf = segment_conf
-        if segment_conf:
-            self.conf_head = nn.Linear(hidden, 1)
-            nn.init.zeros_(self.conf_head.bias)
-
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor | list]:
-        B = tokens.shape[0]
-
-        src = self.input_proj(tokens)
-        pad_mask = ~mask.bool() if mask is not None else None
-
-        # Encode: self-attention over all tokens
-        for layer in self.encoder_layers:
-            src = layer(src, key_padding_mask=pad_mask)
-
-        # Decode: segment queries cross-attend to encoded tokens
-        queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
-        for layer in self.decoder_layers:
-            queries = layer(queries, src)
-
-        # Predict segments -> endpoints
-        if self.segment_param == "midpoint_dir_len":
-            raw = self.segment_head(queries)  # [B, S, 7]
-            mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
-            direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
-            length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
-            half = direction * length * 0.5
-        else:
-            raw = self.segment_head(queries).view(B, self.segments, 2, 3)
-            raw = raw + self.query_offsets.unsqueeze(0)
-            mid, half = raw[:, :, 0], raw[:, :, 1]
-        seg_params = torch.stack([mid - half, mid + half], dim=2)
-
-        vertices = seg_params.reshape(B, self.out_vertices, 3)
-        edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]
-
-        out = {"vertices": vertices, "segments": seg_params, "edges": edges}
-        if self.segment_conf:
-            out["conf"] = self.conf_head(queries).squeeze(-1)
-        return out
-
-
-# ---------------------------------------------------------------------------
-# End-to-end model: tokenizer embeddings + transformer/perceiver
 # ---------------------------------------------------------------------------

 class EdgeDepthSegmentsModel(nn.Module):
@@ -648,25 +487,9 @@ class EdgeDepthSegmentsModel(nn.Module):
         )

         if arch == "transformer":
-            self.segmenter = TransformerSegments(
-                segments=segments,
-                in_dim=self.tokenizer.out_dim,
-                hidden=hidden,
-                num_heads=num_heads,
-                kv_heads_cross=kv_heads_cross,
-                kv_heads_self=kv_heads_self,
-                dim_feedforward=dim_feedforward,
-                dropout=dropout,
-                encoder_layers=encoder_layers,
-                decoder_layers=decoder_layers,
-                norm_class=norm_class,
-                activation=activation,
-                segment_conf=segment_conf,
-                segment_param=segment_param,
-                length_floor=length_floor,
-                decoder_input_xattn=decoder_input_xattn,
-                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
-            )
         else:
             self.segmenter = TokenTransformerSegments(
                 segments=segments,
 
         self.segments = segments
         self.out_vertices = segments * 2
         self.segment_param = segment_param
         self.decoder_input_xattn = decoder_input_xattn
         norm_class = norm_class or nn.LayerNorm

 # ---------------------------------------------------------------------------
+# End-to-end model: tokenizer embeddings + perceiver
 # ---------------------------------------------------------------------------

 class EdgeDepthSegmentsModel(nn.Module):

         )

         if arch == "transformer":
+            raise ValueError(
+                "arch='transformer' is no longer supported. "
+                "TransformerSegments has been removed; use arch='perceiver'.")
         else:
             self.segmenter = TokenTransformerSegments(
                 segments=segments,
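Both the removed `TransformerSegments` and the surviving `TokenTransformerSegments` decode each segment as a midpoint plus half-vector (`seg_params = torch.stack([mid - half, mid + half], dim=2)`). The endpoint reconstruction can be isolated as a tiny helper (the name `halfvec_to_endpoints` is ours, for illustration only):

```python
import torch

def halfvec_to_endpoints(mid: torch.Tensor, half: torch.Tensor) -> torch.Tensor:
    """Convert [..., 3] midpoints and half-vectors into [..., 2, 3]
    endpoint pairs: p0 = mid - half, p1 = mid + half."""
    return torch.stack([mid - half, mid + half], dim=-2)
```

This parameterization makes segment length (`2 * ||half||`) and position (`mid`) independently adjustable, which is why the `midpoint_dir_len` variant can put a softplus floor on length without touching position.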
s23dr_2026_example/sinkhorn.py CHANGED
@@ -10,26 +10,6 @@ to get useful gradients early and precise matching late.
10
  import torch
11
 
12
 
13
- def segment_pair_cost(pred_segments: torch.Tensor, gt_segments: torch.Tensor) -> torch.Tensor:
14
- """Cost between pred and GT segments: midpoint + direction + length (decoupled).
15
- pred_segments: [N, 2, 3], gt_segments: [M, 2, 3] -> [N, M]
16
- """
17
- p0, p1 = pred_segments[:, 0], pred_segments[:, 1]
18
- g0, g1 = gt_segments[:, 0], gt_segments[:, 1]
19
- mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
20
- mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
21
- d_mid = torch.cdist(mid_p, mid_g)
22
- len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
23
- len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
24
- dir_p = half_p / len_p
25
- dir_g = half_g / len_g
26
- cos_angle = (dir_p[:, None, :] * dir_g[None, :, :]).sum(dim=-1)
27
- d_dir = 1.0 - cos_angle.abs()
28
- d_len = (len_p[:, None, :] - len_g[None, :, :]).squeeze(-1).abs()
29
- return d_mid + d_dir + d_len
30
-
31
-
32
-
33
  def batched_sinkhorn_loss(
34
  pred_segments: torch.Tensor,
35
  gt_pad: torch.Tensor,
@@ -144,38 +124,3 @@ def batched_sinkhorn_loss(
144
 
145
  transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
146
  return (transport * cost_pad).sum(dim=(1, 2)) # [B]
147
-
148
-
149
- # Keep the per-sample version for compatibility
150
- def sinkhorn_segment_loss(
151
- pred_segments: torch.Tensor,
152
- gt_segments: torch.Tensor,
153
- eps: float,
154
- iters: int,
155
- dustbin_cost: float,
156
- pred_mass: torch.Tensor | None = None,
157
- ) -> torch.Tensor:
158
- if pred_segments.numel() == 0 or gt_segments.numel() == 0:
159
- return pred_segments.new_tensor(dustbin_cost)
160
- cost = segment_pair_cost(pred_segments, gt_segments)
161
- n, m = cost.shape
162
- if n == 0 or m == 0:
163
- return cost.new_tensor(dustbin_cost)
164
- cost_pad = torch.full((n + 1, m + 1), dustbin_cost, device=cost.device, dtype=cost.dtype)
165
- cost_pad[:n, :m] = cost
166
- cost_pad[-1, -1] = 0.0
167
- denom = float(n + m)
168
- a = torch.full((n + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
169
- b = torch.full((m + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
170
- a[-1] = m / denom
171
- b[-1] = n / denom
172
- log_a = torch.log(a + 1e-9)
173
- log_b = torch.log(b + 1e-9)
174
- log_k = -cost_pad / eps
175
- log_u = torch.zeros_like(a)
176
- log_v = torch.zeros_like(b)
177
- for _ in range(iters):
178
- log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
179
- log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
180
- transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
181
- return torch.sum(transport * cost_pad)
 
 import torch


 def batched_sinkhorn_loss(
     pred_segments: torch.Tensor,
     gt_pad: torch.Tensor,

     transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
     return (transport * cost_pad).sum(dim=(1, 2))  # [B]
s23dr_2026_example/soft_hss_loss.py DELETED
@@ -1,507 +0,0 @@
1
- import torch
2
-
3
-
4
- def _softmin(values: torch.Tensor, dim: int, tau: float) -> torch.Tensor:
5
- tau_t = torch.as_tensor(tau, device=values.device, dtype=values.dtype).clamp_min(1e-8)
6
- return -tau_t * torch.logsumexp(-values / tau_t, dim=dim)
7
-
8
-
9
- def point_segment_distance_squared(
10
- points: torch.Tensor,
11
- seg_a: torch.Tensor,
12
- seg_b: torch.Tensor,
13
- eps: float = 1e-9,
14
- ) -> torch.Tensor:
15
- """
16
- points: (P,3)
17
- seg_a/seg_b: (S,3)
18
- returns dist2: (P,S)
19
- """
20
- ab = seg_b - seg_a # (S,3)
21
- ab2 = (ab * ab).sum(dim=-1).clamp_min(eps) # (S,)
22
- ap = points[:, None, :] - seg_a[None, :, :] # (P,S,3)
23
- t = (ap * ab[None, :, :]).sum(dim=-1) / ab2[None, :] # (P,S)
24
- t = t.clamp(0.0, 1.0)
25
- closest = seg_a[None, :, :] + t[:, :, None] * ab[None, :, :]
26
- diff = points[:, None, :] - closest
27
- return (diff * diff).sum(dim=-1)
28
-
29
-
30
- def distance_to_segments(
31
- points: torch.Tensor,
32
- segments: torch.Tensor,
33
- eps: float = 1e-9,
34
- ) -> torch.Tensor:
35
- """
36
- points: (P,3)
37
- segments: (S,2,3)
38
- returns min distance: (P,)
39
- """
40
- a = segments[:, 0]
41
- b = segments[:, 1]
42
- dist2 = point_segment_distance_squared(points, a, b, eps=eps)
43
- return torch.sqrt(dist2.min(dim=1).values + eps)
44
-
45
-
46
- def soft_vertex_f1(
47
- pred_vertices: torch.Tensor,
48
- gt_vertices: torch.Tensor,
49
- thresh: float,
50
- tau: float = 0.05,
51
- softmin_tau: float = 0.05,
52
- eps: float = 1e-8,
53
- ) -> torch.Tensor:
54
- """
55
- Soft surrogate for the Hungarian-thresholded corner F1 used by HSS.
56
-
57
- Uses (soft) nearest-neighbor distances and a sigmoid threshold.
58
- """
59
- if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
60
- return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
61
-
62
- pred = pred_vertices
63
- gt = gt_vertices
64
-
65
- diff = pred[:, None, :] - gt[None, :, :]
66
- dist = torch.sqrt((diff * diff).sum(dim=-1) + eps) # (P,G)
67
-
68
- d_pred = _softmin(dist, dim=1, tau=softmin_tau) # (P,)
69
- d_gt = _softmin(dist, dim=0, tau=softmin_tau) # (G,)
70
-
71
- tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
72
- thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
73
- p_match = torch.sigmoid((thresh_t - d_pred) / tau_t).mean()
74
- r_match = torch.sigmoid((thresh_t - d_gt) / tau_t).mean()
75
- return 2.0 * p_match * r_match / (p_match + r_match + eps)
76
-
77
-
78
- def soft_tube_iou_mc(
79
- pred_segments: torch.Tensor,
80
- gt_segments: torch.Tensor,
81
- radius: float,
82
- n_samples: int = 4096,
83
- tau: float = 0.05,
84
- seed: int = 0,
85
- eps: float = 1e-8,
86
- ) -> torch.Tensor:
87
- """
88
- Soft surrogate for volumetric tube IoU (edge_thresh in HSS).
89
-
90
- Samples points uniformly in a padded bbox around {pred,gt} endpoints.
91
- Occupancy is sigmoid((radius - d(x, segments))/tau).
92
- IoU is approximated by mean(min(occ_p, occ_g)) / mean(max(occ_p, occ_g)).
93
- """
94
- if pred_segments.numel() == 0 or gt_segments.numel() == 0:
95
- return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
96
-
97
- pts_all = torch.cat([pred_segments.reshape(-1, 3), gt_segments.reshape(-1, 3)], dim=0)
98
- pad = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
99
- lo = pts_all.min(dim=0).values - pad
100
- hi = pts_all.max(dim=0).values + pad
101
-
102
- gen = torch.Generator(device=pts_all.device)
103
- gen.manual_seed(int(seed))
104
- u = torch.rand((int(n_samples), 3), generator=gen, device=pts_all.device, dtype=pts_all.dtype)
105
- x = lo[None, :] + u * (hi - lo)[None, :]
106
-
107
- d_p = distance_to_segments(x, pred_segments, eps=eps)
108
- d_g = distance_to_segments(x, gt_segments, eps=eps)
109
-
110
- tau_t = torch.as_tensor(tau, device=pts_all.device, dtype=pts_all.dtype).clamp_min(1e-8)
111
- rad_t = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
112
- occ_p = torch.sigmoid((rad_t - d_p) / tau_t)
113
- occ_g = torch.sigmoid((rad_t - d_g) / tau_t)
114
-
115
- inter = torch.minimum(occ_p, occ_g).mean()
116
- union = torch.maximum(occ_p, occ_g).mean().clamp_min(eps)
117
- return inter / union
118
-
119
-
120
- def soft_hss(
121
- pred_segments: torch.Tensor,
122
- gt_segments: torch.Tensor,
123
- gt_vertices: torch.Tensor,
124
- vert_thresh: float = 0.5,
125
- edge_thresh: float = 0.5,
126
- tau: float = 0.05,
127
- softmin_tau: float = 0.05,
128
- n_samples: int = 4096,
129
- seed: int = 0,
130
- eps: float = 1e-8,
131
- ):
132
- """
133
- Returns (soft_hss, soft_f1, soft_iou), all scalars in [0,1] (approximately).
134
- """
135
- pred_vertices = pred_segments.reshape(-1, 3)
136
- f1 = soft_vertex_f1(pred_vertices, gt_vertices, thresh=vert_thresh, tau=tau, softmin_tau=softmin_tau, eps=eps)
137
- iou = soft_tube_iou_mc(
138
- pred_segments,
139
- gt_segments,
140
- radius=edge_thresh,
141
- n_samples=n_samples,
142
- tau=tau,
143
- seed=seed,
144
- eps=eps,
145
- )
146
- denom = (f1 + iou).clamp_min(eps)
147
- hss = 2.0 * f1 * iou / denom
148
- return hss, f1, iou
149
-
150
-
151
- # ---------------------------------------------------------------------------
152
- # Improved: Sinkhorn-matched vertex F1
153
- # ---------------------------------------------------------------------------
154
- #
155
- # The original soft_vertex_f1 uses independent softmin nearest-neighbor
156
- # distances, which allows multiple predicted vertices to claim the same GT
157
- # vertex. This inflates precision and fails to penalize duplicate vertices --
158
- # the exact failure mode that requires merge_vertices post-processing.
159
- #
160
- # This version uses Sinkhorn optimal transport to find a soft one-to-one
161
- # assignment between predicted and GT vertices, then computes precision and
162
- # recall from the matched distances. This is a better surrogate for the
163
- # Hungarian matching used by the real HSS metric.
164
-
165
-
166
- def sinkhorn_vertex_f1(
167
- pred_vertices: torch.Tensor,
168
- gt_vertices: torch.Tensor,
169
- thresh: float = 0.5,
170
- tau: float = 0.05,
171
- eps_sinkhorn: float = 0.05,
172
- iters: int = 20,
173
- eps: float = 1e-8,
174
- ) -> torch.Tensor:
175
- """Soft vertex F1 using Sinkhorn matching (better aligned with real HSS).
176
-
177
- Instead of independent nearest-neighbor distances (which allow double-
178
- claiming), this uses optimal transport to find a soft one-to-one assignment
179
- between predicted and GT vertices.
180
-
181
- Returns a differentiable scalar in [0, 1].
182
- """
183
- if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
184
- return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
185
-
186
- P = pred_vertices.shape[0]
187
- G = gt_vertices.shape[0]
188
-
189
- # Pairwise distance matrix (P, G)
190
- dist = torch.cdist(pred_vertices, gt_vertices)
191
-
192
- # Sinkhorn with dustbin: (P+1) x (G+1)
193
- # Dustbin cost = thresh (unmatched vertices are "at threshold distance")
194
- dustbin = thresh
195
- cost_pad = torch.full((P + 1, G + 1), dustbin, device=dist.device, dtype=dist.dtype)
196
- cost_pad[:P, :G] = dist
197
- cost_pad[-1, -1] = 0.0
198
-
199
- # Uniform masses with dustbin slack
200
- denom = float(P + G)
201
- a = torch.full((P + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
202
- b = torch.full((G + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
203
- a[-1] = G / denom # pred dustbin absorbs unmatched GT
204
- b[-1] = P / denom # GT dustbin absorbs unmatched pred
205
-
206
- # Log-domain Sinkhorn
207
- log_a = torch.log(a + 1e-9)
208
- log_b = torch.log(b + 1e-9)
209
- log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
210
- log_u = torch.zeros_like(a)
211
- log_v = torch.zeros_like(b)
212
- for _ in range(iters):
213
- log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
214
- log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
215
-
216
- # Transport plan (P+1, G+1)
217
- transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
218
-
219
- # Extract the non-dustbin transport (P, G) -- these are the soft assignments
220
- T = transport[:P, :G]
221
-
222
- # For each predicted vertex, its matched distance is the transport-weighted
223
- # average distance to GT vertices
224
- # Normalize rows to sum to 1 (how much of this pred is matched vs dustbin)
225
- row_sums = T.sum(dim=1).clamp_min(eps)
226
- matched_dist_pred = (T * dist).sum(dim=1) / row_sums # (P,)
227
- match_weight_pred = row_sums * denom # how much of this pred is matched (0-1 ish)
228
-
229
- # Same for GT vertices (column perspective)
230
- col_sums = T.sum(dim=0).clamp_min(eps)
231
- matched_dist_gt = (T * dist).sum(dim=0) / col_sums # (G,)
232
- match_weight_gt = col_sums * denom
233
-
234
- # Soft precision: fraction of pred vertices that are matched AND within threshold
235
- tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
236
- thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
237
-
238
- prec_per = match_weight_pred * torch.sigmoid((thresh_t - matched_dist_pred) / tau_t)
239
- precision = prec_per.mean()
240
-
241
- # Soft recall: fraction of GT vertices that are matched AND within threshold
242
- rec_per = match_weight_gt * torch.sigmoid((thresh_t - matched_dist_gt) / tau_t)
243
- recall = rec_per.mean()
244
-
245
- return 2.0 * precision * recall / (precision + recall + eps)
246
-
247
-
248
- # ---------------------------------------------------------------------------
249
- # Improved: Segment-sampled tube IoU
250
- # ---------------------------------------------------------------------------
251
- #
252
- # The original soft_tube_iou_mc samples random points in the bounding box,
253
- # wasting most samples in empty space. This version samples along the segments
254
- # themselves, concentrating gradient signal where it matters.
255
-
256
-
257
- def _sample_along_segments(segments: torch.Tensor, n_per_seg: int = 64) -> torch.Tensor:
258
- """Sample n_per_seg points uniformly along each segment.
259
-
260
- segments: (S, 2, 3)
261
- returns: (S * n_per_seg, 3)
262
- """
263
- t = torch.linspace(0, 1, n_per_seg, device=segments.device, dtype=segments.dtype)
264
- # (S, 1, 3) + (1, N, 1) * (S, 1, 3) -> (S, N, 3)
265
- a = segments[:, 0:1, :]
266
- b = segments[:, 1:2, :]
267
- pts = a + t[None, :, None] * (b - a)
268
- return pts.reshape(-1, 3)
269
-
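The interpolation in `_sample_along_segments` is just `a + t * (b - a)` with `t` spanning [0, 1] inclusive; a minimal pure-Python sketch (hypothetical `sample_along_segment` helper, no torch):

```python
def sample_along_segment(a, b, n=5):
    # Uniform t in [0, 1] inclusive, mirroring torch.linspace(0, 1, n).
    pts = []
    for i in range(n):
        t = i / (n - 1)
        pts.append(tuple(ai + t * (bi - ai) for ai, bi in zip(a, b)))
    return pts

pts = sample_along_segment((0.0, 0.0, 0.0), (1.0, 2.0, 0.0), n=5)
# pts[0] and pts[-1] are the segment endpoints; pts[2] is the midpoint
```

Both endpoints are included in the samples, so endpoint supervision falls out of the sampled-point losses for free.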
270
-
271
- def segment_sampled_tube_iou(
272
- pred_segments: torch.Tensor,
273
- gt_segments: torch.Tensor,
274
- radius: float = 0.5,
275
- n_per_seg: int = 64,
276
- tau: float = 0.05,
277
- eps: float = 1e-8,
278
- ) -> torch.Tensor:
279
- """Soft tube IoU by sampling along segments instead of in the bounding box.
280
-
281
- Samples points along predicted and GT segments, then checks what fraction
282
- of each set falls within radius of the other. More sample-efficient than
283
- bbox Monte Carlo and gives better gradients.
284
-
285
- Returns a differentiable scalar in [0, 1].
286
- """
287
- if pred_segments.numel() == 0 or gt_segments.numel() == 0:
288
- return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
289
-
290
- pred_pts = _sample_along_segments(pred_segments, n_per_seg)
291
- gt_pts = _sample_along_segments(gt_segments, n_per_seg)
292
-
293
- tau_t = torch.as_tensor(tau, device=pred_pts.device, dtype=pred_pts.dtype).clamp_min(1e-8)
294
- rad_t = torch.as_tensor(radius, device=pred_pts.device, dtype=pred_pts.dtype)
295
-
296
- # Precision: fraction of pred points within radius of any GT segment
297
- d_pred = distance_to_segments(pred_pts, gt_segments, eps=eps)
298
- prec = torch.sigmoid((rad_t - d_pred) / tau_t).mean()
299
-
300
- # Recall: fraction of GT points within radius of any pred segment
301
- d_gt = distance_to_segments(gt_pts, pred_segments, eps=eps)
302
- rec = torch.sigmoid((rad_t - d_gt) / tau_t).mean()
303
-
304
- # Soft IoU from precision and recall:
305
- # IoU = intersection/union = (P*R) / (P + R - P*R) for occupancy overlap
306
- return prec * rec / (prec + rec - prec * rec + eps)
307
-
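The precision/recall-to-IoU identity in the return is exact when prec = |I|/|P| and rec = |I|/|G| for intersection measure |I|: then IoU = |I|/(|P|+|G|-|I|) = 1/(1/prec + 1/rec - 1) = prec*rec/(prec + rec - prec*rec). A standalone sketch (hypothetical `occupancy_iou`):

```python
def occupancy_iou(prec, rec, eps=1e-8):
    # IoU = |I| / (|P| + |G| - |I|), rewritten in terms of
    # prec = |I|/|P| and rec = |I|/|G|.
    return prec * rec / (prec + rec - prec * rec + eps)
```

Perfect precision and recall give IoU 1; perfect precision with half recall gives IoU 0.5, matching the set-measure intuition.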
308
-
309
- def soft_hss_v2(
310
- pred_segments: torch.Tensor,
311
- gt_segments: torch.Tensor,
312
- gt_vertices: torch.Tensor,
313
- vert_thresh: float = 0.5,
314
- edge_thresh: float = 0.5,
315
- tau: float = 0.05,
316
- sinkhorn_eps: float = 0.05,
317
- sinkhorn_iters: int = 20,
318
- n_per_seg: int = 64,
319
- eps: float = 1e-8,
320
- ):
321
- """Improved soft HSS using Sinkhorn vertex matching + segment-sampled IoU.
322
-
323
- Returns (soft_hss, soft_f1, soft_iou).
324
- """
325
- pred_vertices = pred_segments.reshape(-1, 3)
326
- f1 = sinkhorn_vertex_f1(
327
- pred_vertices, gt_vertices,
328
- thresh=vert_thresh, tau=tau,
329
- eps_sinkhorn=sinkhorn_eps, iters=sinkhorn_iters, eps=eps,
330
- )
331
- iou = segment_sampled_tube_iou(
332
- pred_segments, gt_segments,
333
- radius=edge_thresh, n_per_seg=n_per_seg, tau=tau, eps=eps,
334
- )
335
- denom = (f1 + iou).clamp_min(eps)
336
- hss = 2.0 * f1 * iou / denom
337
- return hss, f1, iou
338
-
339
-
340
-
341
- # ---------------------------------------------------------------------------
342
- # Batched versions for training speed
343
- # ---------------------------------------------------------------------------
344
-
345
-
346
- def batched_sinkhorn_vertex_f1(
347
- pred_segments: torch.Tensor,
348
- gt_pad: torch.Tensor,
349
- gt_mask: torch.Tensor,
350
- thresh: float | torch.Tensor = 0.5,
351
- tau: float | torch.Tensor = 0.05,
352
- eps_sinkhorn: float = 0.05,
353
- iters: int = 10,
354
- eps: float = 1e-8,
355
- ) -> torch.Tensor:
356
- """Batched Sinkhorn vertex F1 loss.
357
-
358
- Args:
359
- pred_segments: [B, S, 2, 3] predicted segments
360
- gt_pad: [B, M, 2, 3] padded GT segments
361
- gt_mask: [B, M] bool mask (True = valid GT segment)
362
- thresh: distance threshold for a vertex match (scalar or [B])
363
- tau: sigmoid temperature (scalar or [B])
364
- Returns:
365
- [B] per-sample (1 - F1) loss
366
- """
367
- B, S = pred_segments.shape[:2]
368
- M = gt_pad.shape[1]
369
- P = S * 2 # pred vertices (both endpoints)
370
-
371
- # Allow per-sample thresh and tau ([B] tensors or scalars)
372
- thresh_t = torch.as_tensor(thresh, device=pred_segments.device, dtype=pred_segments.dtype)
373
- if thresh_t.dim() == 0:
374
- thresh_t = thresh_t.expand(B)
375
- tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
376
- if tau_t.dim() == 0:
377
- tau_t = tau_t.expand(B)
378
- tau_t = tau_t.clamp_min(1e-8)
379
-
380
- pred_verts = pred_segments.reshape(B, P, 3)
381
- gt_verts = gt_pad.reshape(B, M * 2, 3) # will mask invalid ones
382
-
383
- # Build GT vertex mask: each valid segment contributes 2 vertices
384
- gt_vert_mask = gt_mask.unsqueeze(2).expand(B, M, 2).reshape(B, M * 2)
385
- G = M * 2
386
-
387
- # Pairwise distances [B, P, G]
388
- dist = torch.linalg.norm(
389
- pred_verts.unsqueeze(2) - gt_verts.unsqueeze(1), dim=-1)
390
-
391
- # Mask invalid GT with high distance
392
- dist = torch.where(gt_vert_mask.unsqueeze(1), dist, thresh_t[:, None, None] * 10.0)
393
-
394
- # Sinkhorn matching: [B, P+1, G+1]
395
- cost_pad = thresh_t[:, None, None].expand(B, P + 1, G + 1).clone()
396
- cost_pad[:, :P, :G] = dist
397
- cost_pad[:, -1, -1] = 0.0
398
-
399
- gt_counts = gt_vert_mask.sum(dim=1).float() # [B]
400
- n = float(P)
401
- denom = n + gt_counts # [B]
402
-
403
- a = (1.0 / denom).unsqueeze(1).expand(B, P + 1).clone()
404
- a[:, -1] = gt_counts / denom
405
- b = (1.0 / denom).unsqueeze(1).expand(B, G + 1).clone()
406
- b[:, -1] = n / denom
407
- b[:, :G] = b[:, :G] * gt_vert_mask.float()
408
-
409
- log_a = torch.log(a + 1e-9)
410
- log_b = torch.log(b + 1e-9)
411
- log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
412
- log_u = torch.zeros_like(a)
413
- log_v = torch.zeros_like(b)
414
-
415
- for _ in range(iters):
416
- log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
417
- log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
418
-
419
- transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
420
- T = transport[:, :P, :G] # [B, P, G]
421
-
422
- # Matched distances
423
- row_sums = T.sum(dim=2).clamp_min(eps)
424
- matched_d_pred = (T * dist).sum(dim=2) / row_sums # [B, P]
425
- w_pred = row_sums * denom.unsqueeze(1)
426
-
427
- col_sums = T.sum(dim=1).clamp_min(eps)
428
- matched_d_gt = (T * dist).sum(dim=1) / col_sums # [B, G]
429
- w_gt = col_sums * denom.unsqueeze(1)
430
-
431
- precision = (w_pred * torch.sigmoid((thresh_t[:, None] - matched_d_pred) / tau_t[:, None])).mean(dim=1)
432
- recall_raw = w_gt * torch.sigmoid((thresh_t[:, None] - matched_d_gt) / tau_t[:, None])
433
- # Mask invalid GT vertices in recall
434
- recall = (recall_raw * gt_vert_mask.float()).sum(dim=1) / gt_counts.clamp_min(1.0)
435
-
436
- f1 = 2.0 * precision * recall / (precision + recall + eps)
437
- return 1.0 - f1 # return loss (1 - F1)
438
-
439
-
440
-
441
- def batched_segment_sampled_iou(
442
- pred_segments: torch.Tensor,
443
- gt_pad: torch.Tensor,
444
- gt_mask: torch.Tensor,
445
- radius: float | torch.Tensor = 0.5,
446
- n_per_seg: int = 32,
447
- tau: float | torch.Tensor = 0.05,
448
- eps: float = 1e-8,
449
- ) -> torch.Tensor:
450
- """Batched segment-sampled tube IoU loss.
451
-
452
- Returns [B] per-sample (1 - IoU) loss.
453
- """
454
- B, S = pred_segments.shape[:2]
455
- M = gt_pad.shape[1]
456
-
457
- # Allow per-sample radius and tau ([B] tensors or scalars)
458
- rad_t = torch.as_tensor(radius, device=pred_segments.device, dtype=pred_segments.dtype)
459
- if rad_t.dim() == 0:
460
- rad_t = rad_t.expand(B)
461
- tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
462
- if tau_t.dim() == 0:
463
- tau_t = tau_t.expand(B)
464
- tau_t = tau_t.clamp_min(1e-8)
465
-
466
- # Sample points along segments
467
- t = torch.linspace(0, 1, n_per_seg, device=pred_segments.device, dtype=pred_segments.dtype)
468
-
469
- # Pred points: [B, S*n_per_seg, 3]
470
- pa = pred_segments[:, :, 0:1, :] # [B, S, 1, 3]
471
- pb = pred_segments[:, :, 1:2, :]
472
- pred_pts = (pa + t[None, None, :, None] * (pb - pa)).reshape(B, S * n_per_seg, 3)
473
-
474
- # GT points: [B, M*n_per_seg, 3]
475
- ga = gt_pad[:, :, 0:1, :]
476
- gb = gt_pad[:, :, 1:2, :]
477
- gt_pts = (ga + t[None, None, :, None] * (gb - ga)).reshape(B, M * n_per_seg, 3)
478
-
479
- # For each pred point, min distance to any point sampled along GT segments
480
- d_pred_to_gt = torch.cdist(pred_pts, gt_pts) # [B, S*n, M*n]
481
- d_pred = d_pred_to_gt.min(dim=2).values # [B, S*n]
482
- prec = torch.sigmoid((rad_t[:, None] - d_pred) / tau_t[:, None]).mean(dim=1) # [B]
483
-
484
- d_gt_to_pred = d_pred_to_gt.min(dim=1).values # [B, M*n]
485
- # Mask invalid GT points
486
- gt_pt_mask = gt_mask.unsqueeze(2).expand(B, M, n_per_seg).reshape(B, M * n_per_seg)
487
- rec_raw = torch.sigmoid((rad_t[:, None] - d_gt_to_pred) / tau_t[:, None])
488
- rec = (rec_raw * gt_pt_mask.float()).sum(dim=1) / gt_pt_mask.float().sum(dim=1).clamp_min(1.0)
489
-
490
- iou = prec * rec / (prec + rec - prec * rec + eps)
491
- return 1.0 - iou # return loss
492
-
493
-
494
- def batched_soft_hss_v2(pred_segments, gt_pad, gt_mask,
495
- vert_thresh=0.5, edge_thresh=0.5, tau=0.05,
496
- sinkhorn_iters=10, n_per_seg=32):
497
- """Batched soft HSS loss. Returns [B] per-sample (1 - HSS)."""
498
- f1_loss = batched_sinkhorn_vertex_f1(
499
- pred_segments, gt_pad, gt_mask,
500
- thresh=vert_thresh, tau=tau, iters=sinkhorn_iters)
501
- iou_loss = batched_segment_sampled_iou(
502
- pred_segments, gt_pad, gt_mask,
503
- radius=edge_thresh, n_per_seg=n_per_seg, tau=tau)
504
- f1 = 1.0 - f1_loss
505
- iou = 1.0 - iou_loss
506
- hss = 2.0 * f1 * iou / (f1 + iou + 1e-8)
507
- return 1.0 - hss
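For intuition, the log-domain Sinkhorn iteration used in both F1 losses can be sketched in plain Python (hypothetical `sinkhorn_log`; balanced marginals, no dustbin row/column):

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def sinkhorn_log(cost, a, b, eps=0.1, iters=100):
    # Entropic OT in the log domain: returns a transport plan whose
    # row sums approximate a and column sums approximate b.
    n, m = len(cost), len(cost[0])
    log_k = [[-cost[i][j] / eps for j in range(m)] for i in range(n)]
    log_a = [math.log(x) for x in a]
    log_b = [math.log(x) for x in b]
    log_u, log_v = [0.0] * n, [0.0] * m
    for _ in range(iters):
        log_u = [log_a[i] - logsumexp([log_k[i][j] + log_v[j] for j in range(m)])
                 for i in range(n)]
        log_v = [log_b[j] - logsumexp([log_k[i][j] + log_u[i] for i in range(n)])
                 for j in range(m)]
    return [[math.exp(log_u[i] + log_v[j] + log_k[i][j]) for j in range(m)]
            for i in range(n)]

cost = [[0.0, 1.0], [1.0, 0.0]]
P = sinkhorn_log(cost, [0.5, 0.5], [0.5, 0.5])
# Mass concentrates on the cheap diagonal entries.
```

The real losses additionally pad the cost matrix with a dustbin row/column sized so unmatched vertices have somewhere to dump their mass.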
s23dr_2026_example/train.py ADDED
@@ -0,0 +1,530 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Training script for S23DR 2026.
4
+
5
+ Usage:
6
+ python -m s23dr_2026_example.train --cache-dir hf://usm3d/s23dr-2026-sampled_2048_v2:train --steps 80000 --aug-rotate
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import sys
11
+ from pathlib import Path as _Path
12
+ if __package__ is None or __package__ == "":
13
+ _here = _Path(__file__).resolve().parent
14
+ if str(_here.parent) not in sys.path:
15
+ sys.path.insert(0, str(_here.parent))
16
+ __package__ = _here.name
17
+
18
+ import argparse
19
+ import gc
20
+ import json
21
+ import math
22
+ import subprocess
23
+ import time
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from .tokenizer import EdgeDepthSequenceConfig
30
+ from .model import EdgeDepthSegmentsModel
31
+ from .data import build_loader, build_tokens
32
+ from .losses import compute_loss, _loss_inner
33
+
34
+ # Re-export for eval scripts
35
+ from .data import HFCachedDataset, collate as _collate # noqa: F401
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Main
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def main():
43
+ p = argparse.ArgumentParser(description="S23DR 2026 training")
44
+ p.add_argument("--cache-dir", default=None, help="HF dataset path (hf://repo:split)")
45
+ p.add_argument("--val-cache-dir", default="", help="Separate cache for validation")
46
+ p.add_argument("--seq-len", type=int, default=2048,
47
+ help="Input sequence length (2048 or 4096, must match dataset)")
48
+ p.add_argument("--arch", choices=["perceiver", "transformer"], default="perceiver",
49
+ help="perceiver=latent bottleneck, transformer=full self-attention encoder")
50
+ p.add_argument("--segments", type=int, default=32)
51
+ p.add_argument("--hidden", type=int, default=128)
52
+ p.add_argument("--ff", type=int, default=512)
53
+ p.add_argument("--latent-tokens", type=int, default=128)
54
+ p.add_argument("--latent-layers", type=int, default=7)
55
+ p.add_argument("--encoder-layers", type=int, default=4,
56
+ help="Encoder layers (transformer arch only)")
57
+ p.add_argument("--pre-encoder-layers", type=int, default=0,
58
+ help="Self-attn layers on full token sequence before perceiver bottleneck")
59
+ p.add_argument("--decoder-layers", type=int, default=3)
60
+ p.add_argument("--decoder-input-xattn", action="store_true",
61
+ help="Add cross-attention from segment queries to input tokens in each decoder layer")
62
+ p.add_argument("--qk-norm", action="store_true",
63
+ help="Normalize Q and K per-head with learned temperature (stabilizes wide models)")
64
+ p.add_argument("--qk-norm-type", choices=["l2", "rms"], default="l2",
65
+ help="QK-norm type: l2 (unit sphere) or rms (RMSNorm, preserves magnitudes)")
66
+ p.add_argument("--learnable-fourier", action="store_true",
67
+ help="Make Fourier positional encoding learnable (vs fixed random)")
68
+ p.add_argument("--num-heads", type=int, default=4, help="Attention heads")
69
+ p.add_argument("--kv-heads-cross", type=int, default=2,
70
+ help="KV heads for cross-attention (GQA; 0 = standard MHA)")
71
+ p.add_argument("--kv-heads-self", type=int, default=2,
72
+ help="KV heads for self-attention (GQA; 0 = standard MHA)")
73
+ p.add_argument("--cross-attn-interval", type=int, default=4,
74
+ help="Perceiver cross-attention frequency (every N latent layers)")
75
+ p.add_argument("--dropout", type=float, default=0.1)
76
+ p.add_argument("--weight-decay", type=float, default=0.01, help="AdamW weight decay")
77
+ p.add_argument("--steps", type=int, default=5000)
78
+ p.add_argument("--batch-size", type=int, default=32)
79
+ p.add_argument("--lr", type=float, default=3e-4)
80
+ p.add_argument("--adam-betas", default="0.9,0.95", help="AdamW beta1,beta2")
81
+ p.add_argument("--warmup", type=int, default=200, help="LR warmup steps")
82
+ p.add_argument("--cosine-decay", action="store_true",
83
+ help="Cosine decay LR after warmup (to lr*0.01 at end)")
84
+ p.add_argument("--cooldown-start", type=int, default=0,
85
+ help="Step to begin linear cooldown to lr*0.01 (0=disabled, constant LR after warmup)")
86
+ p.add_argument("--cooldown-steps", type=int, default=0,
87
+ help="Number of steps for linear cooldown (0=no cooldown)")
88
+ p.add_argument("--seed", type=int, default=7)
89
+ p.add_argument("--deterministic", action="store_true",
90
+ help="Force deterministic mode (disables torch.compile, slower but bit-reproducible)")
91
+ p.add_argument("--varifold-weight", type=float, default=0.0)
92
+ p.add_argument("--varifold-cross-only", action="store_true",
93
+ help="Drop varifold self-energy (avoids O(S^2) spike, sinkhorn handles repulsion)")
94
+ p.add_argument("--sinkhorn-weight", type=float, default=1.0)
95
+ p.add_argument("--sinkhorn-eps", type=float, default=0.1,
96
+ help="Sinkhorn regularization (larger = softer matching, stronger gradients)")
97
+ p.add_argument("--sinkhorn-eps-start", type=float, default=None,
98
+ help="Starting eps for epsilon annealing (anneals to --sinkhorn-eps). None=no annealing.")
99
+ p.add_argument("--sinkhorn-eps-schedule", choices=["linear", "sqrt", "none"], default="none",
100
+ help="Eps annealing schedule ('sqrt' or linear; only active when --sinkhorn-eps-start is set)")
101
+ p.add_argument("--sinkhorn-iters", type=int, default=20,
102
+ help="Sinkhorn iterations")
103
+ p.add_argument("--sinkhorn-dustbin", type=float, default=0.3,
104
+ help="Sinkhorn dustbin cost in normalized space")
105
+ p.add_argument("--endpoint-weight", type=float, default=0.0,
106
+ help="Weight for endpoint distance loss (sinkhorn-matched, symmetric)")
107
+ p.add_argument("--endpoint-warmup", type=int, default=0,
108
+ help="Steps to linearly warm up endpoint weight from 0 (0=instant)")
109
+ p.add_argument("--aug-rotate", action="store_true")
110
+ p.add_argument("--aug-jitter", type=float, default=0.0,
111
+ help="Point position jitter std in normalized space (0=disabled, try 0.005)")
112
+ p.add_argument("--aug-drop", type=float, default=0.0,
113
+ help="Fraction of points to randomly drop (0=disabled, try 0.1)")
114
+ p.add_argument("--aug-flip", action="store_true",
115
+ help="Random mirror along X axis (50%% chance)")
116
+ p.add_argument("--rms-norm", action="store_true", default=True,
117
+ help="Use RMSNorm (default). Use --no-rms-norm for LayerNorm")
118
+ p.add_argument("--no-rms-norm", dest="rms_norm", action="store_false")
119
+ p.add_argument("--activation", default="gelu", help="FFN activation: gelu, relu, relu_sq")
120
+ p.add_argument("--behind-emb-dim", type=int, default=8,
121
+ help="Behind-gestalt embedding dim (0 to disable)")
122
+ p.add_argument("--vote-features", action="store_true",
123
+ help="Add n_views_voted + vote_frac as raw token features (requires v2 data)")
124
+ p.add_argument("--segment-param", choices=["midpoint_halfvec", "midpoint_dir_len"],
125
+ default="midpoint_halfvec",
126
+ help="Output parameterization: halfvec (default) or decoupled direction+length")
127
+ p.add_argument("--length-floor", type=float, default=0.0,
128
+ help="Minimum segment length for midpoint_dir_len (0=no floor)")
129
+ p.add_argument("--segment-conf", action="store_true",
130
+ help="Add per-segment confidence head (use with --conf-thresh at eval)")
131
+ p.add_argument("--conf-weight", type=float, default=0.0,
132
+ help="Weight for confidence loss (requires --segment-conf)")
133
+ p.add_argument("--conf-mode", choices=["sinkhorn", "sinkhorn_detach"], default="sinkhorn",
134
+ help="Confidence training: 'sinkhorn'=OT mass, 'sinkhorn_detach'=OT mass (gradient detached)")
135
+ p.add_argument("--conf-clamp-min", type=float, default=None,
136
+ help="Clamp conf logits to this minimum before sigmoid (e.g., -5)")
137
+ p.add_argument("--conf-head-wd", type=float, default=None,
138
+ help="Separate weight decay for conf head (default: same as other params)")
139
+ p.add_argument("--ema-decay", type=float, default=0.0,
140
+ help="EMA decay rate (0=disabled, try 0.9999). Saves EMA weights in checkpoints.")
141
+ p.add_argument("--out-dir", default=str(_Path(__file__).resolve().parent / "runs"))
142
+ p.add_argument("--resume", default="")
143
+ p.add_argument("--cpu", action="store_true")
144
+ p.add_argument("--args-from", default=None,
145
+ help="Load defaults from a run's args.json (CLI flags override)")
146
+
147
+ # If --args-from is specified, load defaults from that JSON file first,
148
+ # then let CLI flags override.
149
+ raw_args = p.parse_args()
150
+ if raw_args.args_from is not None:
151
+ import json as _json
152
+ args_path = _Path(raw_args.args_from)
153
+ if not args_path.exists():
154
+ raise FileNotFoundError(f"--args-from file not found: {args_path}")
155
+ saved = _json.loads(args_path.read_text())
156
+ valid_dests = {a.dest for a in p._actions}
157
+ defaults = {}
158
+ for k, v in saved.items():
159
+ if k in valid_dests and k != "args_from":
160
+ defaults[k] = v
161
+ p.set_defaults(**defaults)
162
+ args = p.parse_args()
163
+ print(f"Loaded defaults from {args_path} (CLI flags override)")
164
+ else:
165
+ args = raw_args
166
+
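The two-pass pattern above (parse once, seed `set_defaults` from the saved JSON, parse again so explicit CLI flags win) in isolation, with made-up values standing in for a run's args.json:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--steps", type=int, default=5000)

saved = {"lr": 1e-4, "steps": 80000}  # stand-in for a saved args.json
valid_dests = {a.dest for a in p._actions}
p.set_defaults(**{k: v for k, v in saved.items() if k in valid_dests})

args = p.parse_args(["--steps", "100"])  # explicit flag beats saved default
```

`_actions` is private argparse API, but it is the same attribute the script itself relies on to filter saved keys.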
167
+ # Validate required args
168
+ if not args.cache_dir:
169
+ p.error("--cache-dir is required (either directly or via --args-from)")
170
+
171
+ # Validate arg compatibility
172
+ if args.arch == "transformer":
173
+ perceiver_only = []
174
+ if args.latent_tokens != 128:
175
+ perceiver_only.append(f"--latent-tokens={args.latent_tokens}")
176
+ if args.latent_layers != 7:
177
+ perceiver_only.append(f"--latent-layers={args.latent_layers}")
178
+ if args.pre_encoder_layers != 0:
179
+ perceiver_only.append(f"--pre-encoder-layers={args.pre_encoder_layers}")
180
+ if args.cross_attn_interval != 4:
181
+ perceiver_only.append(f"--cross-attn-interval={args.cross_attn_interval}")
182
+ if perceiver_only:
183
+ raise ValueError(
184
+ f"Args {', '.join(perceiver_only)} have no effect with --arch transformer. "
185
+ f"Use --arch perceiver or remove them.")
186
+ if args.conf_weight > 0 and not args.segment_conf:
187
+ raise ValueError("--conf-weight requires --segment-conf")
188
+ if args.conf_mode in ("sinkhorn", "sinkhorn_detach") and args.sinkhorn_weight == 0:
189
+ raise ValueError("--conf-mode sinkhorn requires --sinkhorn-weight > 0")
190
+ if args.cosine_decay and args.cooldown_start > 0:
191
+ raise ValueError("--cosine-decay and --cooldown-start are mutually exclusive")
192
+
193
+ device = torch.device("cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu"))
194
+ print(f"Device: {device}")
195
+ torch.manual_seed(args.seed)
196
+ np.random.seed(args.seed)
197
+
198
+ # Output
199
+ import hashlib, os
200
+ args_hash = hashlib.md5(json.dumps(vars(args), sort_keys=True).encode()).hexdigest()[:4]
201
+ run_tag = time.strftime("%Y%m%d_%H%M%S") + f"_{args_hash}_{os.getpid() % 10000:04d}"
202
+ out_dir = Path(args.out_dir) / run_tag
203
+ out_dir.mkdir(parents=True, exist_ok=True)
204
+ (out_dir / "checkpoints").mkdir(exist_ok=True)
205
+
206
+ # Tee stdout/stderr to run dir
207
+ import sys as _sys
208
+ _log_path = out_dir / "train.log"
209
+ class _Tee:
210
+ def __init__(self, path, stream):
211
+ self._file = open(path, "a")
212
+ self._stream = stream
213
+ def write(self, data):
214
+ self._stream.write(data)
215
+ self._file.write(data)
216
+ self._file.flush()
217
+ def flush(self):
218
+ self._stream.flush()
219
+ self._file.flush()
220
+ _sys.stdout = _Tee(_log_path, _sys.stdout)
221
+ _sys.stderr = _Tee(_log_path, _sys.stderr)
222
+
223
+ git_sha = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True,
224
+ cwd=str(_Path(__file__).parent)).stdout.strip()
225
+ git_dirty = subprocess.run(["git", "diff", "--quiet"], capture_output=True,
226
+ cwd=str(_Path(__file__).parent)).returncode != 0
227
+ run_info = {**vars(args), "git_sha": git_sha, "git_dirty": git_dirty}
228
+ (out_dir / "args.json").write_text(json.dumps(run_info, indent=2, sort_keys=True) + "\n")
229
+
230
+ # Set varifold cross-only mode before compile
231
+ if args.varifold_cross_only:
232
+ from . import losses as L
233
+ L.VARIFOLD_CROSS_ONLY = True
234
+ print("Varifold: cross-only mode (no self-energy)")
235
+
236
+ # Model
237
+ seq_len = args.seq_len
238
+ norm_class = torch.nn.RMSNorm if args.rms_norm else None
239
+ seq_cfg = EdgeDepthSequenceConfig(seq_len=seq_len)
240
+ model = EdgeDepthSegmentsModel(
241
+ seq_cfg=seq_cfg, segments=args.segments, hidden=args.hidden,
242
+ num_heads=args.num_heads, kv_heads_cross=args.kv_heads_cross,
243
+ kv_heads_self=args.kv_heads_self,
244
+ dim_feedforward=args.ff, dropout=args.dropout,
245
+ latent_tokens=args.latent_tokens, latent_layers=args.latent_layers,
246
+ decoder_layers=args.decoder_layers, cross_attn_interval=args.cross_attn_interval,
247
+ norm_class=norm_class, activation=args.activation,
248
+ segment_conf=args.segment_conf,
249
+ segment_param=args.segment_param,
250
+ length_floor=args.length_floor,
251
+ arch=args.arch, encoder_layers=args.encoder_layers,
252
+ pre_encoder_layers=args.pre_encoder_layers,
253
+ behind_emb_dim=args.behind_emb_dim,
254
+ use_vote_features=args.vote_features,
255
+ decoder_input_xattn=args.decoder_input_xattn,
256
+ qk_norm=args.qk_norm,
257
+ qk_norm_type=args.qk_norm_type,
258
+ learnable_fourier=args.learnable_fourier,
259
+ ).to(device)
260
+
261
+ try:
262
+ from torchinfo import summary
263
+ summary(model.segmenter,
264
+ input_data=[torch.zeros(1, seq_len, model.tokenizer.out_dim, device=device),
265
+ torch.ones(1, seq_len, device=device, dtype=torch.bool)],
266
+ col_names=("input_size", "output_size", "num_params"), verbose=1)
267
+ except ImportError:
268
+ pass
269
+ print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
270
+
271
+ # Compile (skip in deterministic mode for bit-reproducibility)
272
+ torch.set_float32_matmul_precision("high")
273
+ if args.deterministic:
274
+ torch.use_deterministic_algorithms(True)
275
+ torch.backends.cudnn.deterministic = True
276
+ torch.backends.cudnn.benchmark = False
277
+ import os
278
+ os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
279
+ print("Deterministic mode: no torch.compile, bit-reproducible but ~3x slower")
280
+ elif device.type == "cuda":
281
+ model.segmenter = torch.compile(model.segmenter, mode="reduce-overhead", fullgraph=True)
282
+ from . import losses as L
283
+ L._loss_fn = torch.compile(_loss_inner, mode="reduce-overhead", fullgraph=True)
284
+ print("Compiled model + loss (reduce-overhead, fullgraph)")
285
+
286
+ # EMA
287
+ ema_model = None
288
+ if args.ema_decay > 0:
289
+ from copy import deepcopy
290
+ ema_model = deepcopy(model).eval()
291
+ for p_ema in ema_model.parameters():
292
+ p_ema.requires_grad_(False)
293
+ print(f"EMA enabled (decay={args.ema_decay})")
294
+
295
+ # Resume
296
+ start_step = 0
297
+ if args.resume:
298
+ ckpt = torch.load(args.resume, map_location=device, weights_only=False)
299
+ try:
300
+ model.load_state_dict(ckpt["model"])
301
+ except RuntimeError:
302
+ state = {k.replace("segmenter._orig_mod.", "segmenter."): v
303
+ for k, v in ckpt["model"].items()}
304
+ model.load_state_dict(state)
305
+ start_step = ckpt.get("step", 0)
306
+ print(f"Resumed from {args.resume} at step {start_step}")
307
+
308
+ betas = tuple(float(x) for x in args.adam_betas.split(","))
309
+
310
+ # Optimizer: AdamW with optional separate conf_head weight decay
311
+ conf_wd = args.conf_head_wd if args.conf_head_wd is not None else args.weight_decay
312
+ if args.conf_head_wd is not None:
313
+ conf_decay_params = []
314
+ other_params = []
315
+ for name, param in model.named_parameters():
316
+ if not param.requires_grad:
317
+ continue
318
+ if 'conf_head' in name:
319
+ conf_decay_params.append(param)
320
+ else:
321
+ other_params.append(param)
322
+ param_groups = [
323
+ {"params": other_params, "weight_decay": args.weight_decay},
324
+ {"params": conf_decay_params, "weight_decay": conf_wd},
325
+ ]
326
+ print(f"Conf head WD: {conf_wd} ({len(conf_decay_params)} params)")
327
+ else:
328
+ param_groups = model.parameters()
329
+
330
+ opt = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay,
331
+ betas=betas)
332
+ if args.resume and "optimizer" in ckpt:
333
+ opt.load_state_dict(ckpt["optimizer"])
334
+
335
+ # Data
336
+ torch.manual_seed(args.seed + 7919)
337
+ np.random.seed(args.seed + 7919)
338
+ train_loader = build_loader(args.cache_dir, args.batch_size, aug_rotate=args.aug_rotate,
339
+ aug_jitter=args.aug_jitter, aug_drop=args.aug_drop,
340
+ aug_flip=args.aug_flip)
341
+ val_loader = build_loader(args.val_cache_dir, args.batch_size) if args.val_cache_dir else None
342
+ data_iter = iter(train_loader)
343
+
344
+ # Intervals
345
+ log_int = max(1, min(50, args.steps // 20))
346
+ ckpt_int = 5000
347
+ val_int = ckpt_int if val_loader else 0
348
+
349
+ # Training loop
350
+ global_step = start_step
351
+ loss_ema, loss_sq_ema = 0.0, 0.0
352
+ t_start = time.perf_counter()
353
+
354
+ print(f"Training for {args.steps} steps | {args.segments}seg "
355
+ f"{args.hidden}h {args.latent_tokens}x{args.latent_layers}L "
356
+ f"{args.decoder_layers}D")
357
+
358
+ # Pre-fetch first batch
359
+ try:
360
+ next_batch = next(data_iter)
361
+ except StopIteration:
362
+ data_iter = iter(train_loader)
363
+ next_batch = next(data_iter)
364
+
365
+ # Freeze GC after setup to eliminate stalls during training
366
+ gc.collect()
367
+ gc.freeze()
368
+ gc.disable()
369
+
370
+ amp_ctx = torch.autocast(device_type='cuda', dtype=torch.bfloat16,
371
+ enabled=(device.type == 'cuda'))
372
+
373
+ while global_step < args.steps:
374
+ tokens, masks, gt_list, scales, meta = build_tokens(next_batch, model, device)
375
+
376
+ # Epsilon annealing
377
+ if args.sinkhorn_eps_start is not None and args.sinkhorn_eps_start != args.sinkhorn_eps:
378
+ if args.sinkhorn_eps_schedule == "sqrt":
379
+ ratio_sq = (args.sinkhorn_eps_start / args.sinkhorn_eps) ** 2
380
+ t0 = max(args.steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0)
381
+ current_eps = args.sinkhorn_eps_start / math.sqrt(1 + global_step / t0)
382
+ current_eps = max(current_eps, args.sinkhorn_eps)
383
+ else:
384
+ frac = min(global_step / max(args.steps * 0.8, 1), 1.0)
385
+ current_eps = args.sinkhorn_eps_start + frac * (args.sinkhorn_eps - args.sinkhorn_eps_start)
386
+ else:
387
+ current_eps = args.sinkhorn_eps
388
+
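The `sqrt` branch picks `t0` so the annealed eps reaches the target exactly at 80% of training, then clamps; a hedged standalone sketch (hypothetical `sqrt_eps`):

```python
import math

def sqrt_eps(step, steps, eps_start=0.3, eps_min=0.1):
    # t0 solves eps_start / sqrt(1 + 0.8 * steps / t0) == eps_min,
    # so the schedule hits eps_min at 80% of training, then clamps.
    ratio_sq = (eps_start / eps_min) ** 2
    t0 = max(steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0)
    return max(eps_start / math.sqrt(1 + step / t0), eps_min)
```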
389
+ with amp_ctx:
390
+ out = model.forward_tokens(tokens, masks)
391
+ pred = out["segments"]
392
+ conf = out.get("conf")
393
+
394
+ # Endpoint weight warmup
395
+ if args.endpoint_warmup > 0 and global_step < args.endpoint_warmup:
396
+ current_ep_w = args.endpoint_weight * global_step / args.endpoint_warmup
397
+ else:
398
+ current_ep_w = args.endpoint_weight
399
+
400
+ loss, terms = compute_loss(pred, gt_list, scales.to(device), device,
401
+ args.varifold_weight, args.sinkhorn_weight,
402
+ endpoint_w=current_ep_w,
403
+ conf_logits=conf, conf_weight=args.conf_weight,
404
+ conf_mode=args.conf_mode,
405
+ sinkhorn_eps=current_eps,
406
+ sinkhorn_iters=args.sinkhorn_iters,
407
+ sinkhorn_dustbin=args.sinkhorn_dustbin,
408
+ conf_clamp_min=args.conf_clamp_min)
409
+
410
+ loss_val = loss.item()
411
+ # Adaptive loss spike detection
412
+ if global_step < 100:
413
+ loss_ema = loss_val if global_step == start_step else 0.9 * loss_ema + 0.1 * loss_val
414
+ loss_sq_ema = loss_val**2 if global_step == start_step else 0.9 * loss_sq_ema + 0.1 * loss_val**2
415
+ else:
416
+ loss_ema = 0.99 * loss_ema + 0.01 * loss_val
417
+ loss_sq_ema = 0.99 * loss_sq_ema + 0.01 * loss_val**2
418
+ loss_std = max(math.sqrt(max(loss_sq_ema - loss_ema**2, 0)), 1e-6)
419
+ spike_thresh = loss_ema + 5 * loss_std
420
+
421
+ # Skip on total loss spike or NaN
422
+ if not math.isfinite(loss_val) or loss_val > max(spike_thresh, 0.5):
423
+ sample_ids = [m.get("sample_id", "?") for m in meta]
424
+ skip_reason = f"loss={loss_val:.2f} > thresh={spike_thresh:.2f}"
425
+ print(f"Step {global_step}: {skip_reason}, skipping (samples: {sample_ids[:3]})")
426
+ with open(out_dir / "skipped_samples.jsonl", "a") as f:
427
+ f.write(json.dumps({"step": global_step, "reason": skip_reason,
428
+ "samples": sample_ids}) + "\n")
429
+ try:
430
+ next_batch = next(data_iter)
431
+ except StopIteration:
432
+ data_iter = iter(train_loader)
433
+ next_batch = next(data_iter)
434
+ continue
435
+
436
+ opt.zero_grad()
437
+ loss.backward()
438
+
439
+ # Fetch next batch while GPU finishes backward
440
+ try:
441
+ next_batch = next(data_iter)
442
+ except StopIteration:
443
+ data_iter = iter(train_loader)
444
+ next_batch = next(data_iter)
445
+
446
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
447
+
448
+ # LR schedule: warmup -> constant -> optional cooldown or cosine
449
+ if global_step < args.warmup:
450
+ lr = args.lr * (global_step + 1) / max(1, args.warmup)
451
+ elif args.cosine_decay:
452
+ progress = (global_step - args.warmup) / max(1, args.steps - args.warmup)
453
+ lr = args.lr * (0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress)))
454
+ elif args.cooldown_start > 0 and global_step >= args.cooldown_start:
455
+ progress = (global_step - args.cooldown_start) / max(1, args.cooldown_steps)
456
+ lr = args.lr * max(0.01, 1.0 - 0.99 * min(1.0, progress))
457
+ else:
458
+ lr = args.lr
459
+ for pg in opt.param_groups:
460
+ pg["lr"] = lr
461
+ opt.step()
462
+ global_step += 1
463
+
464
+ # EMA update
465
+ if ema_model is not None:
466
+ decay = args.ema_decay
467
+ with torch.no_grad():
468
+ for p_ema, p_model in zip(ema_model.parameters(), model.parameters()):
469
+ p_ema.lerp_(p_model, 1.0 - decay)
470
+
471
+ # Log
472
+ entry = {"step": global_step, "ts": time.time(), "loss": loss.item(), "lr": lr}
473
+ entry.update({k: v.item() for k, v in terms.items()})
474
+ if global_step % log_int == 0:
475
+ grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters()
476
+ if p.grad is not None) ** 0.5
477
+ entry["grad_norm"] = grad_norm
478
+
479
+ if global_step % log_int == 0:
480
+ ms = (time.perf_counter() - t_start) / log_int * 1000
481
+ t_start = time.perf_counter()
482
+ t_str = " ".join(f"{k}={v:.4f}" for k, v in terms.items())
483
+ print(f"[{global_step}/{args.steps}] loss={loss.item():.4f} {t_str} "
484
+ f"lr={lr:.2e} gnorm={entry.get('grad_norm', 0):.3f} [{ms:.0f}ms/step]")
485
+
486
+ if val_int > 0 and global_step % val_int == 0:
487
+ try:
488
+ vl_list = []
489
+ with torch.no_grad(), amp_ctx:
490
+ for vb in val_loader:
491
+ vt, vm, vg, vs, _ = build_tokens(vb, model, device)
492
+ vo = model.forward_tokens(vt, vm)
493
+ vl, _ = compute_loss(vo["segments"], vg, vs.to(device), device,
494
+ args.varifold_weight, args.sinkhorn_weight)
495
+ if math.isfinite(vl.item()):
496
+ vl_list.append(vl.item())
497
+ if vl_list:
498
+ val_loss = float(np.mean(vl_list))
499
+ print(f" val_loss={val_loss:.4f}")
500
+ entry["val_loss"] = val_loss
501
+ except Exception as e:
502
+ print(f" val eval failed: {e}")
503
+
504
+ # Write log entry
505
+ with open(out_dir / "history.jsonl", "a") as f:
506
+ f.write(json.dumps(entry) + "\n")
507
+
508
+ if global_step % ckpt_int == 0:
509
+ try:
510
+ gc.enable(); gc.collect(); gc.freeze(); gc.disable()
511
+ torch.cuda.empty_cache()
512
+ save_dict = {"step": global_step, "model": model.state_dict(),
513
+ "optimizer": opt.state_dict(), "args": vars(args)}
514
+ if ema_model is not None:
515
+ save_dict["ema_model"] = ema_model.state_dict()
516
+ torch.save(save_dict, out_dir / "checkpoints" / f"step{global_step:06d}.pt")
517
+ except Exception as e:
518
+ print(f" checkpoint save failed: {e}")
519
+
520
+ # Final save
521
+ save_dict = {"step": global_step, "model": model.state_dict(),
522
+ "optimizer": opt.state_dict(), "args": vars(args)}
523
+ if ema_model is not None:
524
+ save_dict["ema_model"] = ema_model.state_dict()
525
+ torch.save(save_dict, out_dir / "checkpoints" / "final.pt")
526
+ print(f"Done. {global_step} steps. Output: {out_dir}")
527
+
528
+
529
+ if __name__ == "__main__":
530
+ main()
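As a standalone reference, the "sqrt" branch of the Sinkhorn ε schedule above can be isolated into a pure function. The name `sinkhorn_eps_sqrt` and its argument names are illustrative, not flags of the training script; `t0` is chosen so the inverse-square-root decay reaches the final ε at ~80% of training, then the value is clamped.

```python
import math

def sinkhorn_eps_sqrt(step, eps_start, eps_final, total_steps):
    """Anneal Sinkhorn epsilon from eps_start down to eps_final.

    Mirrors the "sqrt" schedule above: eps(step) = eps_start / sqrt(1 + step/t0),
    with t0 solved so that eps(0.8 * total_steps) == eps_final, then clamped.
    """
    ratio_sq = (eps_start / eps_final) ** 2
    t0 = max(total_steps * 0.8 / max(ratio_sq - 1, 1e-6), 1.0)
    eps = eps_start / math.sqrt(1 + step / t0)
    return max(eps, eps_final)
```

At step 0 this returns `eps_start`, decays monotonically, and sits at `eps_final` from 80% of training onward.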
s23dr_2026_example/varifold.py CHANGED
@@ -1,25 +1,11 @@
  import torch
 
  from .wire_varifold_kernels import (
-     loss_semi_lobatto3,
-     loss_semi_lobatto3_mix,
-     loss_semi_lobatto3_mix_simple,
-     loss_simpson3,
      loss_simpson3_batch,
-     loss_simpson3_mix,
      loss_simpson3_mix_batch,
-     loss_simpson3_lenpow,
-     loss_simpson3_lenpow_mix,
-     loss_semi_legendre,
  )
 
 
- def edges_to_segments(vertices, edges) -> torch.Tensor:
-     verts = torch.as_tensor(vertices, dtype=torch.float32)
-     idx = torch.as_tensor(edges, dtype=torch.long)
-     return torch.stack([verts[idx[:, 0]], verts[idx[:, 1]]], dim=1)
-
-
  def segments_to_vertices_edges(segments: torch.Tensor):
      segs = torch.as_tensor(segments, dtype=torch.float32)
      vertices = segs.reshape(-1, 3)
@@ -27,52 +13,6 @@ def segments_to_vertices_edges(segments: torch.Tensor):
      return vertices, edges
 
 
- def varifold_loss(
-     pred_segments: torch.Tensor,
-     gt_segments: torch.Tensor,
-     sigma: float = 0.1,
-     variant: str = "semi_lobatto3",
-     t_nodes01: torch.Tensor | None = None,
-     t_w: torch.Tensor | None = None,
-     sigmas: torch.Tensor | None = None,
-     alpha: torch.Tensor | None = None,
-     normalize_alpha: bool = True,
-     len_pow: float | None = None,
- ) -> torch.Tensor:
-     p_pred, q_pred = pred_segments[:, 0], pred_segments[:, 1]
-     p_gt, q_gt = gt_segments[:, 0], gt_segments[:, 1]
-
-     if variant == "semi_lobatto3":
-         return loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
-     if variant == "semi_lobatto3_mix":
-         if sigmas is None or alpha is None:
-             raise ValueError("sigmas and alpha are required for semi_lobatto3_mix")
-         return loss_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
-     if variant == "semi_lobatto3_mix_simple":
-         if sigmas is None or alpha is None:
-             raise ValueError("sigmas and alpha are required for semi_lobatto3_mix_simple")
-         return loss_semi_lobatto3_mix_simple(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
-     if variant == "simpson3":
-         if sigmas is not None or alpha is not None:
-             if sigmas is None or alpha is None:
-                 raise ValueError("sigmas and alpha are required for simpson3 mix")
-             return loss_simpson3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
-         return loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
-     if variant == "simpson3_lenpow":
-         if len_pow is None:
-             len_pow = 1.0
-         if sigmas is not None or alpha is not None:
-             if sigmas is None or alpha is None:
-                 raise ValueError("sigmas and alpha are required for simpson3_lenpow mix")
-             return loss_simpson3_lenpow_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, len_pow, normalize_alpha)
-         return loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
-     if variant == "semi_legendre":
-         return loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
-     if variant in ("centers", "segments_varifold", "semi_lobatto1"):
-         return varifold_loss_centers(pred_segments, gt_segments, sigma)
-     raise ValueError(f"Unknown varifold variant: {variant}")
-
-
  def varifold_loss_batch(
      pred_segments: torch.Tensor,
      gt_segments: torch.Tensor,
@@ -102,95 +42,12 @@ def varifold_loss_batch(
      if pred_weights is not None:
          w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
 
-     if variant == "simpson3":
-         if sigmas is not None or alpha is not None:
-             if sigmas is None or alpha is None:
-                 raise ValueError("sigmas and alpha are required for simpson3 mix")
-             return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
-         return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)
-
-     # Fallback to per-sample loop for unsupported variants.
-     losses = []
-     sigmas_t = None
-     if sigmas is not None:
-         sigmas_t = torch.as_tensor(sigmas, device=pred_segments.device, dtype=pred_segments.dtype)
-     for idx in range(pred_segments.shape[0]):
-         gt_b = gt_segments[idx]
-         if gt_mask is not None:
-             gt_b = gt_b[gt_mask[idx]]
-         sigmas_i = sigmas
-         if sigmas_t is not None and sigmas_t.ndim == 2:
-             sigmas_i = sigmas_t[idx]
-         losses.append(
-             varifold_loss(
-                 pred_segments[idx],
-                 gt_b,
-                 sigma=sigma,
-                 variant=variant,
-                 t_nodes01=t_nodes01,
-                 t_w=t_w,
-                 sigmas=sigmas_i,
-                 alpha=alpha,
-                 normalize_alpha=normalize_alpha,
-                 len_pow=len_pow,
-             )
-         )
-     return torch.stack(losses, dim=0)
-
-
- def varifold_loss_centers(
-     pred_segments: torch.Tensor,
-     gt_segments: torch.Tensor,
-     sigma: float = 0.1,
-     normalize_weights: bool = True,
- ) -> torch.Tensor:
-     eps = 1e-8
-     a_p, b_p = pred_segments[:, 0], pred_segments[:, 1]
-     a_g, b_g = gt_segments[:, 0], gt_segments[:, 1]
-
-     v_p = b_p - a_p
-     v_g = b_g - a_g
-     len_p = torch.linalg.norm(v_p, dim=-1)
-     len_g = torch.linalg.norm(v_g, dim=-1)
-
-     x_p = 0.5 * (a_p + b_p)
-     x_g = 0.5 * (a_g + b_g)
-
-     u_p = v_p / (len_p[:, None] + eps)
-     u_g = v_g / (len_g[:, None] + eps)
-
-     w_p = len_p
-     w_g = len_g
-     if normalize_weights:
-         w_p = w_p / (w_p.sum() + eps)
-         w_g = w_g / (w_g.sum() + eps)
-
-     diff_pp = x_p[:, None, :] - x_p[None, :, :]
-     diff_gg = x_g[:, None, :] - x_g[None, :, :]
-     diff_pg = x_p[:, None, :] - x_g[None, :, :]
-     d_pp = (diff_pp * diff_pp).sum(dim=-1)
-     d_gg = (diff_gg * diff_gg).sum(dim=-1)
-     d_pg = (diff_pg * diff_pg).sum(dim=-1)
-
-     inv2s2 = 1.0 / (2.0 * sigma * sigma)
-     k_pp = torch.exp(-d_pp * inv2s2)
-     k_gg = torch.exp(-d_gg * inv2s2)
-     k_pg = torch.exp(-d_pg * inv2s2)
-
-     dot_pp = (u_p[:, None, :] * u_p[None, :, :]).sum(dim=-1)
-     dot_gg = (u_g[:, None, :] * u_g[None, :, :]).sum(dim=-1)
-     dot_pg = (u_p[:, None, :] * u_g[None, :, :]).sum(dim=-1)
-
-     k_pp = k_pp * (dot_pp * dot_pp)
-     k_gg = k_gg * (dot_gg * dot_gg)
-     k_pg = k_pg * (dot_pg * dot_pg)
-
-     wp_row = w_p[:, None]
-     wp_col = w_p[None, :]
-     wg_row = w_g[:, None]
-     wg_col = w_g[None, :]
-
-     a_pp = (wp_row * wp_col * k_pp).sum(dim=-1).sum(dim=-1)
-     a_gg = (wg_row * wg_col * k_gg).sum(dim=-1).sum(dim=-1)
-     a_pg = (w_p[:, None] * w_g[None, :] * k_pg).sum(dim=-1).sum(dim=-1)
-     return a_pp + a_gg - 2.0 * a_pg
+     if variant != "simpson3":
+         raise ValueError(
+             f"Unsupported varifold variant: {variant!r}. "
+             f"Only 'simpson3' is supported in batch mode.")
+     if sigmas is not None or alpha is not None:
+         if sigmas is None or alpha is None:
+             raise ValueError("sigmas and alpha are required for simpson3 mix")
+         return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
+     return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)
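The surviving `segments_to_vertices_edges` helper is just a reshape plus an index ladder: N segments become 2N vertices and N edges (2i, 2i+1). A dependency-free sketch of the same mapping on plain Python lists (illustrative only; the real helper operates on tensors):

```python
def segments_to_vertices_edges(segments):
    """Flatten N segments of 3D endpoint pairs into (vertices, edges).

    segments: iterable of (p, q) endpoint pairs, each a 3-tuple.
    Returns 2N vertices and N edges; edge i connects vertices 2i and 2i+1.
    """
    vertices = []
    edges = []
    for i, (p, q) in enumerate(segments):
        vertices.extend([list(p), list(q)])
        edges.append((2 * i, 2 * i + 1))
    return vertices, edges
```

Note that shared endpoints are duplicated; collapsing them is left to the merge step in post-processing.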
s23dr_2026_example/wire_varifold_kernels.py CHANGED
@@ -1,4 +1,3 @@
- import math
  import torch
 
  # -----------------------------
@@ -46,7 +45,7 @@ def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
      return sigmas_t, alpha_t
 
  # -----------------------------
- # 1) Simpson-3 on both segments (3x3 product rule)
+ # Simpson-3 on both segments (3x3 product rule)
  # -----------------------------
  def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
      if w is None:
@@ -121,227 +120,9 @@ def cross_simpson3(
      return out[0] if not batched else out
 
 
- def cross_simpson3_lenpow(
-     pA,
-     qA,
-     pB,
-     qB,
-     sigma: float | torch.Tensor,
-     len_pow: float,
-     wA: torch.Tensor | None = None,
-     wB: torch.Tensor | None = None,
- ):
-     device, dtype = pA.device, pA.dtype
-     batched = pA.dim() == 3
-     if not batched:
-         pA = pA.unsqueeze(0)
-         qA = qA.unsqueeze(0)
-         pB = pB.unsqueeze(0)
-         qB = qB.unsqueeze(0)
-     nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
-     w2 = LOBATTO3_W2.to(device=device, dtype=dtype)
-
-     bsz, nA, _ = pA.shape
-     nB = pB.shape[1]
-     wA = _prep_weight(wA, nA, bsz, device, dtype)
-     wB = _prep_weight(wB, nB, bsz, device, dtype)
-
-     _, _, ellA, uA = segment_geom(pA, qA)
-     _, _, ellB, uB = segment_geom(pB, qB)
-
-     XA = sample_points(pA, qA, nodes)  # (B,N,3,3)
-     YB = sample_points(pB, qB, nodes)  # (B,M,3,3)
-
-     ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
-     lenfac = (ellA[:, :, None] * ellB[:, None, :]).pow(len_pow)
-     if wA is not None or wB is not None:
-         if wA is None:
-             wA = torch.ones((bsz, nA), device=device, dtype=dtype)
-         if wB is None:
-             wB = torch.ones((bsz, nB), device=device, dtype=dtype)
-         lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])
-
-     diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
-     r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3)
-     sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
-     if sigma_t.ndim == 0:
-         inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
-     else:
-         if sigma_t.shape[0] != bsz:
-             raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
-         inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
-     K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3)
-
-     spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
-     out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
-     return out[0] if not batched else out
-
-
- # -----------------------------
- # 2/3) Semi-analytic in s, quadrature in t
- #      - Lobatto-3 (endpoints+midpoint)
- #      - Gauss-Legendre Q (nodes/weights passed in)
- # -----------------------------
- def cross_semi_analytic(pA, qA, pB, qB, sigma: float, t_nodes01: torch.Tensor, t_w: torch.Tensor):
-     """
-     Gaussian k_x. Integrate s exactly along A, integrate t numerically along B.
-     t_nodes01, t_w: (Q,) nodes/weights on [0,1] (constants you pass in)
-     """
-     device, dtype = pA.device, pA.dtype
-     t = t_nodes01.to(device=device, dtype=dtype)  # (Q,)
-     w = t_w.to(device=device, dtype=dtype)  # (Q,)
-
-     dA, aA, ellA, uA = segment_geom(pA, qA)
-     dB, _, ellB, uB = segment_geom(pB, qB)
-
-     # (N,M) factors
-     ang = (uA @ uB.t()).pow(2)
-     lenfac = ellA[:, None] * ellB[None, :]
-
-     # r0: (N,M,3)
-     r0 = pA[:, None, :] - pB[None, :, :]
-
-     # r(t): (N,M,Q,3)
-     r = r0[:, :, None, :] - t[None, None, :, None] * dB[None, :, None, :]
-
-     # beta, r2: (N,M,Q)
-     beta = (r * dA[:, None, None, :]).sum(dim=-1)
-     r2 = (r * r).sum(dim=-1)
-
-     # semi-analytic constants per A segment: shapes broadcast to (N,1,1)
-     a = aA.clamp_min(1e-12)
-     inv_a = (1.0 / a).view(-1, 1, 1)
-     denom = (torch.sqrt(2.0 * a) * sigma).view(-1, 1, 1)
-     pref = (math.sqrt(math.pi) * sigma / torch.sqrt(2.0 * a)).view(-1, 1, 1)
-
-     # J(t): (N,M,Q)
-     exp_term = torch.exp(-(r2 - (beta * beta) * inv_a) / (2.0 * sigma * sigma))
-     erf1 = torch.special.erf((a.view(-1, 1, 1) + beta) / denom)
-     erf0 = torch.special.erf(beta / denom)
-     J = pref * (erf1 - erf0) * exp_term
-
-     # integrate over t: (N,M)
-     spatial = (J * w.view(1, 1, -1)).sum(dim=-1)
-     return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
-
-
- def cross_semi_lobatto3(pA, qA, pB, qB, sigma: float):
-     device, dtype = pA.device, pA.dtype
-     t = LOBATTO3_NODES.to(device=device, dtype=dtype)
-     w = LOBATTO3_W.to(device=device, dtype=dtype)
-     return cross_semi_analytic(pA, qA, pB, qB, sigma, t, w)
-
-
- def cross_semi_lobatto3_mix(
-     pA,
-     qA,
-     pB,
-     qB,
-     sigmas,
-     alpha,
-     normalize_alpha: bool = True,
- ):
-     """
-     Semi-analytic in s (along A), Lobatto-3 in t (along B), with a sigma mixture.
-     """
-     device, dtype = pA.device, pA.dtype
-     t_nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
-     t_w = LOBATTO3_W.to(device=device, dtype=dtype)
-
-     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
-
-     dA, aA, ellA, uA = segment_geom(pA, qA)
-     dB, _, ellB, uB = segment_geom(pB, qB)
-
-     ang = (uA @ uB.t()).pow(2)
-     lenfac = ellA[:, None] * ellB[None, :]
-
-     r0 = pA[:, None, :] - pB[None, :, :]
-
-     a = aA.clamp_min(1e-12)
-     inv_a = (1.0 / a).view(-1, 1)
-     sqrt_a = torch.sqrt(2.0 * a).clamp_min(1e-12)
-
-     denom = (sqrt_a[:, None] * sigmas_t[None, :]).clamp_min(1e-12)
-     pref = math.sqrt(math.pi) * sigmas_t[None, :] / sqrt_a[:, None]
-     inv2s2 = (1.0 / (2.0 * sigmas_t * sigmas_t)).view(1, 1, -1)
-
-     denom_nmS = denom[:, None, :]
-     pref_nmS = pref[:, None, :]
-     alpha_nmS = alpha_t.view(1, 1, -1)
-     a_nm1 = a[:, None, None]
-
-     spatial = torch.zeros((pA.shape[0], pB.shape[0]), device=device, dtype=dtype)
-     for tk, wk in zip(t_nodes, t_w):
-         r = r0 - tk * dB[None, :, :]
-         beta = (r * dA[:, None, :]).sum(dim=-1)
-         r2 = (r * r).sum(dim=-1)
-         core = r2 - (beta * beta) * inv_a
-
-         exp_term = torch.exp(-core[:, :, None] * inv2s2)
-         erf1 = torch.special.erf((a_nm1 + beta[:, :, None]) / denom_nmS)
-         erf0 = torch.special.erf(beta[:, :, None] / denom_nmS)
-         J = pref_nmS * (erf1 - erf0) * exp_term
-         spatial = spatial + wk * (J * alpha_nmS).sum(dim=-1)
-
-     return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
-
-
  # -----------------------------
- # Full losses (self + self - 2 cross)
+ # Batch losses
  # -----------------------------
- # def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
- #     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
- #     # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma)
- #     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
- #     # return s_pred + s_gt - 2.0 * cross
- #     return s_pred - 2.0 * cross
-
-
- def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
-     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
-     # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma)
-     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
-     # return s_pred + s_gt - 2.0 * cross
-     return s_pred - 2.0 * cross
-
-
- def loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma: float, len_pow: float):
-     s_pred = cross_simpson3_lenpow(p_pred, q_pred, p_pred, q_pred, sigma, len_pow)
-     # s_gt = cross_simpson3_lenpow(p_gt, q_gt, p_gt, q_gt, sigma, len_pow)
-     cross = cross_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
-     # return s_pred + s_gt - 2.0 * cross
-     return s_pred - 2.0 * cross
-
-
- def loss_simpson3_mix(
-     p_pred,
-     q_pred,
-     p_gt,
-     q_gt,
-     sigmas,
-     alpha,
-     normalize_alpha: bool = True,
- ):
-     device, dtype = p_pred.device, p_pred.dtype
-     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
-     losses = [loss_simpson3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
-     return (torch.stack(losses) * alpha_t).sum()
-
-
- # def loss_simpson3_batch(
- #     p_pred: torch.Tensor,
- #     q_pred: torch.Tensor,
- #     p_gt: torch.Tensor,
- #     q_gt: torch.Tensor,
- #     sigma: float | torch.Tensor,
- #     w_gt: torch.Tensor | None = None,
- # ) -> torch.Tensor:
- #     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
- #     # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma, wA=w_gt, wB=w_gt)
- #     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wB=w_gt)
- #     # return s_pred + s_gt - 2.0 * cross
- #     return s_pred - 2.0 * cross
-
 
  def loss_simpson3_batch(
      p_pred: torch.Tensor,
@@ -385,77 +166,3 @@ def loss_simpson3_mix_batch(
          losses = [loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i], w_gt=w_gt, w_pred=w_pred, cross_only=cross_only) for i in range(sigmas_t.shape[1])]
          return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
      raise ValueError("sigmas must be 1D or 2D for batch loss")
-
-
- def loss_simpson3_lenpow_mix(
-     p_pred,
-     q_pred,
-     p_gt,
-     q_gt,
-     sigmas,
-     alpha,
-     len_pow: float,
-     normalize_alpha: bool = True,
- ):
-     device, dtype = p_pred.device, p_pred.dtype
-     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
-     losses = [loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, s, len_pow) for s in sigmas_t]
-     return (torch.stack(losses) * alpha_t).sum()
-
-
- def loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma: float):
-     s_pred = cross_semi_lobatto3(p_pred, q_pred, p_pred, q_pred, sigma)
-     # s_gt = cross_semi_lobatto3(p_gt, q_gt, p_gt, q_gt, sigma)
-     cross = cross_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
-     # return s_pred + s_gt - 2.0 * cross
-     return s_pred - 2.0 * cross
-
-
- def loss_semi_lobatto3_mix(
-     p_pred,
-     q_pred,
-     p_gt,
-     q_gt,
-     sigmas,
-     alpha,
-     normalize_alpha: bool = True,
- ):
-     s_pred = cross_semi_lobatto3_mix(p_pred, q_pred, p_pred, q_pred, sigmas, alpha, normalize_alpha)
-     # s_gt = cross_semi_lobatto3_mix(p_gt, q_gt, p_gt, q_gt, sigmas, alpha, normalize_alpha)
-     cross = cross_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
-     # return s_pred + s_gt - 2.0 * cross
-     return s_pred - 2.0 * cross
-
-
- def loss_semi_lobatto3_mix_simple(
-     p_pred,
-     q_pred,
-     p_gt,
-     q_gt,
-     sigmas,
-     alpha,
-     normalize_alpha: bool = True,
- ):
-     device, dtype = p_pred.device, p_pred.dtype
-     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
-     losses = [loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
-     return (torch.stack(losses) * alpha_t).sum()
-
-
- def loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma: float, t_nodes01, t_w):
-     s_pred = cross_semi_analytic(p_pred, q_pred, p_pred, q_pred, sigma, t_nodes01, t_w)
-     s_gt = cross_semi_analytic(p_gt, q_gt, p_gt, q_gt, sigma, t_nodes01, t_w)
-     cross = cross_semi_analytic(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
-     return s_pred + s_gt - 2.0 * cross
-
-
- # -----------------------------
- # torch.compile usage
- # -----------------------------
- # For Legendre: generate nodes/weights ONCE outside compile and pass them in.
- # Example:
- #   import numpy as np
- #   x, w = np.polynomial.legendre.leggauss(Q)
- #   t_nodes = torch.tensor(0.5*(x+1.0), device=device, dtype=dtype)
- #   t_w = torch.tensor(0.5*w, device=device, dtype=dtype)
- #
- # compiled_loss = torch.compile(loss_semi_lobatto3, fullgraph=True)
- # compiled_loss_leg = torch.compile(lambda pp,qp,pg,qg,s: loss_semi_legendre(pp,qp,pg,qg,s,t_nodes,t_w),
- #                                   fullgraph=True)
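For intuition about what these kernels compute, the removed `varifold_loss_centers` fallback summarized each segment by its midpoint, unit direction, and length, then compared the two sets with a Gaussian-times-squared-cosine kernel: loss = ⟨pred,pred⟩ + ⟨gt,gt⟩ − 2⟨pred,gt⟩, which is zero when the two kernel embeddings coincide. A pure-Python sketch of that idea (loop-based rather than vectorized, so it only illustrates the math, not the tensor implementation):

```python
import math

def center_varifold_loss(pred, gt, sigma=0.1):
    """Center-point varifold distance between two sets of 3D segments.

    Each segment (p, q) becomes (midpoint x, unit direction u, length w).
    Kernel: exp(-|x_i - x_j|^2 / (2 sigma^2)) * (u_i . u_j)^2, weighted by lengths.
    """
    def feats(segs):
        out = []
        for p, q in segs:
            v = [b - a for a, b in zip(p, q)]
            ell = math.sqrt(sum(c * c for c in v)) or 1e-8  # guard zero-length
            x = [0.5 * (a + b) for a, b in zip(p, q)]
            u = [c / ell for c in v]
            out.append((x, u, ell))
        return out

    def inner(A, B):
        inv2s2 = 1.0 / (2.0 * sigma * sigma)
        total = 0.0
        for xa, ua, wa in A:
            for xb, ub, wb in B:
                d2 = sum((a - b) ** 2 for a, b in zip(xa, xb))
                dot = sum(a * b for a, b in zip(ua, ub))
                total += wa * wb * math.exp(-d2 * inv2s2) * dot * dot
        return total

    P, G = feats(pred), feats(gt)
    return inner(P, P) + inner(G, G) - 2.0 * inner(P, G)
```

The `(u_i · u_j)²` factor makes the kernel orientation-invariant (a segment and its flipped copy embed identically), which is why no endpoint ordering is needed.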
script.py CHANGED
@@ -34,14 +34,14 @@ from s23dr_2026_example.make_sampled_cache import _priority_sample
  # Tokenizer / model imports
  from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
  from s23dr_2026_example.model import EdgeDepthSegmentsModel
- from s23dr_2026_example.segment_postprocess import merge_vertices, merge_vertices_iterative
+ from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
  from s23dr_2026_example.varifold import segments_to_vertices_edges
  from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
 
- SEQ_LEN = 2048
- COLMAP_QUOTA = 1536
- DEPTH_QUOTA = 512
- CONF_THRESH = 0.7
+ SEQ_LEN = 4096
+ COLMAP_QUOTA = 3072
+ DEPTH_QUOTA = 1024
+ CONF_THRESH = 0.5
  MERGE_THRESH = 0.4
  SNAP_RADIUS = 0.5
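Note that the new constants keep `COLMAP_QUOTA + DEPTH_QUOTA == SEQ_LEN` (3072 + 1024 = 4096). A hypothetical sketch of how such per-source quotas might be filled with spill-over, purely to illustrate the budget split; the repo's actual `_priority_sample` may allocate differently:

```python
def fill_quotas(n_colmap, n_depth, colmap_quota=3072, depth_quota=1024):
    """Split a fixed token budget between two point sources.

    Take up to each source's quota; if one source cannot fill its quota,
    hand the leftover budget to the other source so the sequence stays full
    whenever enough points exist.
    """
    take_colmap = min(n_colmap, colmap_quota)
    take_depth = min(n_depth, depth_quota)
    # Spill unused budget to the COLMAP side first, then to depth.
    spare = (colmap_quota - take_colmap) + (depth_quota - take_depth)
    take_colmap += min(n_colmap - take_colmap, spare)
    spare = (colmap_quota + depth_quota) - take_colmap - take_depth
    take_depth += min(n_depth - take_depth, spare)
    return take_colmap, take_depth
```

With abundant points on both sides this returns exactly (3072, 1024); when one source is short, the other absorbs the slack up to the 4096-token sequence length.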