MapleF9 committed
Commit 1f07a2a · 1 Parent(s): 9a6e346

update vtp-l

Files changed (3)
  1. README.md +246 -0
  2. config.json +48 -0
  3. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,246 @@
<div align="center">

<img src="figures/logo.png" alt="Logo" width="200"/>

<h2> Towards Scalable Pre-training of Visual Tokenizers for Generation </h2>

[Jingfeng Yao](https://github.com/JingfengYao)<sup>1</sup>, [Yuda Song](https://github.com/IDKiro)<sup>2</sup>, Yucong Zhou<sup>2</sup>, [Xinggang Wang](https://xwcv.github.io/)<sup>1,*</sup>

<sup>1</sup>Huazhong University of Science and Technology
<sup>2</sup>MiniMax
<sup>*</sup>Corresponding author: [email protected]

***Work still in progress.***

[![arXiv](https://img.shields.io/badge/📖_paper-VTP-FF4040?style=flat-square&labelColor=2C3E50)](https://arxiv.org/abs/2512.13687)

<img src="figures/abs.png" alt="Abstract Figure" width="900"/>

</div>

## News

- **[2025.12.16]** We release our [technical report](https://arxiv.org/abs/2512.13687) on arXiv. Weights will be released very soon.

## Takeaways

By integrating contrastive, self-supervised, and reconstruction learning, we train visual tokenizers from scratch at a range of scales, aiming to uncover the scaling behavior that links understanding, generation, and reconstruction.

- **At the same DiT training FLOPs, scaling up VTP yields better generation.**

- **Traditional auto-encoders CANNOT be scaled up for diffusion generative models.**

- **Understanding is the key driver of improved learnability and scaling.**

- **Parameter, data, and training scalability emerge once representation learning is involved.**

<div align="center">
<img src="figures/scaling_v2.png" alt="Overview Figure" width="900"/>
</div>

## Get Checkpoints

| Checkpoints |
|-------------|
| [VTP-S/16](pretrained/vtp-s-hf) |
| [VTP-B/16](pretrained/vtp-b-hf) |
| [VTP-L/16](pretrained/vtp-l-hf) |

Weights will be released very soon.

<details>
<summary><b style="font-size: 1.1em;">🚀 Click Here for a Quick Start </b></summary>

```bash
pip install -r requirements.txt
```

```python
import os

import torch
from PIL import Image
from torchvision import transforms

from vtp.models.vtp_hf import VTPConfig, VTPModel
from vtp.tokenizers import get_tokenizer

model = VTPModel.from_pretrained("pretrained/vtp-l-hf")
model.eval()

# print model parameters
def count_params(m): return sum(p.numel() for p in m.parameters()) / 1e6
print(f"Vision Encoder: {count_params(model.trunk):.1f}M")
print(f"Pixel Decoder: {count_params(model.pixel_decoder):.1f}M")
print(f"Text Encoder: {count_params(model.text_transformer):.1f}M")

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = preprocess(Image.open("figures/dog.png").convert("RGB")).unsqueeze(0)

# ---------------------------------------------------------------------------------------
# use it as auto-encoder; rFID=0.36
# ---------------------------------------------------------------------------------------
denormalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)
with torch.no_grad(), torch.autocast("cuda"):
    latents = model.get_reconstruction_latents(image)   # encode
    recon = model.get_latents_decoded_images(latents)   # decode
recon_image = denormalize(recon[0]).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
os.makedirs("output", exist_ok=True)
Image.fromarray((recon_image * 255).astype("uint8")).save("output/reconstructed.png")

# ---------------------------------------------------------------------------------------
# use it as clip; zero-shot 78.2
# ---------------------------------------------------------------------------------------
tokenizer = get_tokenizer('ViT-B-32', context_length=model.config.text_context_length)
text = tokenizer(["a diagram", "a dog", "a cat", "a person"])
with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.get_clip_image_feature(image, normalize=True)
    text_features = model.get_clip_text_feature(text, normalize=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print("Label probs:", [f"{p:.4f}" for p in text_probs[0].tolist()])

# ---------------------------------------------------------------------------------------
# use it as ssl feature extractor; linear probing 85.7
# ---------------------------------------------------------------------------------------
with torch.no_grad(), torch.autocast("cuda"):
    # get last layer features (cls token + patch tokens)
    features = model.get_last_layer_feature(image)
    cls_token = features['cls_token']        # (B, 1024)
    patch_tokens = features['patch_tokens']  # (B, 256, 1024) for a 256x256 image

    # or get intermediate layer features for linear probing
    intermediate = model.get_intermediate_layers_feature(
        image, n=4, return_class_token=True
    )  # returns 4 x (patch_tokens, cls_token), each cls_token is (B, 1024)
    for i in range(1, 5):
        print('Last %d layers:' % i)
        print('Patch tokens shape:', intermediate[-i][0].shape)
        print('Cls token shape:', intermediate[-i][1].shape)
```

</details>

## Performance

<table>
  <tr>
    <th rowspan="2">Model</th>
    <th colspan="2" style="text-align: center;">Understanding</th>
    <th colspan="1" style="text-align: center;">Reconstruction</th>
    <th colspan="1" style="text-align: center;">Generation</th>
  </tr>
  <tr>
    <th style="text-align: center;">Zero-shot Acc.</th>
    <th style="text-align: center;">Linear Probing</th>
    <th style="text-align: center;">rFID</th>
    <th style="text-align: center;">LightningDiT-XL, 80 epochs<br>FID-50K (no CFG)</th>
  </tr>
  <tr><td><a href="https://github.com/mlfoundations/open_clip">OpenCLIP</a></td><td style="text-align: center;">74.0</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/openai/CLIP">CLIP</a></td><td style="text-align: center;">75.5</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/google-research/big_vision">SigLIP</a></td><td style="text-align: center;"><strong>80.5</strong></td><td style="text-align: center;">-</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/facebookresearch/mae">MAE</a></td><td style="text-align: center;">-</td><td style="text-align: center;">85.9</td><td style="text-align: center;">-</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/facebookresearch/dinov2">DINOv2</a></td><td style="text-align: center;">-</td><td style="text-align: center;"><strong>86.7</strong></td><td style="text-align: center;">-</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/FoundationVision/UniTok">UniTok</a></td><td style="text-align: center;">70.8</td><td style="text-align: center;">-</td><td style="text-align: center;">0.41</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/mit-han-lab/vila-u">VILA-U</a></td><td style="text-align: center;">73.3</td><td style="text-align: center;">-</td><td style="text-align: center;">1.80</td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/hustvl/LightningDiT">VA-VAE-f16d32</a></td><td style="text-align: center;">-</td><td style="text-align: center;">-</td><td style="text-align: center;">0.28</td><td style="text-align: center;">4.29</td></tr>
  <tr><td><a href="https://github.com/hustvl/LightningDiT">VA-VAE-f16d64</a></td><td style="text-align: center;">-</td><td style="text-align: center;">-</td><td style="text-align: center;"><strong>0.15</strong></td><td style="text-align: center;">-</td></tr>
  <tr><td><a href="https://github.com/bytetriper/RAE">RAE-f16d768</a></td><td style="text-align: center;">-</td><td style="text-align: center;">84.5</td><td style="text-align: center;">0.57</td><td style="text-align: center;">4.28</td></tr>
  <tr><td><b>VTP-S-f16d64 (ours)</b></td><td style="text-align: center;">66.7</td><td style="text-align: center;">77.5</td><td style="text-align: center;">0.98</td><td style="text-align: center;">5.46</td></tr>
  <tr><td><b>VTP-B-f16d64 (ours)</b></td><td style="text-align: center;">73.2</td><td style="text-align: center;">81.0</td><td style="text-align: center;">0.74</td><td style="text-align: center;">3.88</td></tr>
  <tr><td><b>VTP-L-f16d64 (ours)</b></td><td style="text-align: center;">78.2</td><td style="text-align: center;">85.7</td><td style="text-align: center;">0.36</td><td style="text-align: center;"><strong>2.81</strong></td></tr>
</table>

## Introduction

The quality of the latent space in visual tokenizers (e.g., VAEs) is crucial for modern generative models. However, the standard reconstruction-based training paradigm produces a latent space biased towards low-level information, leading to a foundational flaw: better pixel-level accuracy does not lead to higher-quality generation.
This implies that pouring extensive compute into visual tokenizer pre-training translates poorly into improved generation performance.

We identify this as the **"pre-training scaling problem"** and suggest a necessary shift: to be effective for generation, a latent space must concisely represent high-level semantics.
We present **VTP**, a unified visual tokenizer pre-training framework that pioneers the joint optimization of image-text contrastive, self-supervised, and reconstruction losses. Our large-scale study reveals two principal findings: (1) understanding is a key driver of generation, and (2) the joint objective yields much better scaling properties, where generative performance scales effectively with the compute, parameters, and data allocated to tokenizer pre-training. After large-scale pre-training, our tokenizer delivers a competitive profile (78.2 zero-shot accuracy, 0.36 rFID) and 3× faster convergence on generation compared to advanced distillation methods. More importantly, it scales effectively: without modifying the standard DiT training recipe, solely investing more FLOPs in VTP pre-training achieves a 65.8% FID improvement in downstream generation, while a conventional autoencoder stagnates very early, at roughly 1/10 of the FLOPs.

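As a rough illustration of how the three objectives fit together in one pre-training loop, the sketch below combines a CLIP-style contrastive term, a self-supervised term, and a pixel reconstruction term. The loss weights, the temperature, and the `ssl_loss` helper are assumptions for illustration, not the released training code; only the `get_*` methods shown in the quick start are taken from the actual API.

```python
import torch
import torch.nn.functional as F

def vtp_joint_loss(model, images, texts, aug_images,
                   w_clip=1.0, w_ssl=1.0, w_rec=1.0):
    """Hypothetical sketch of the joint VTP objective."""
    # (1) image-text contrastive term (CLIP-style InfoNCE)
    img_feat = model.get_clip_image_feature(images, normalize=True)
    txt_feat = model.get_clip_text_feature(texts, normalize=True)
    logits = img_feat @ txt_feat.T / 0.07  # fixed temperature, an assumption
    targets = torch.arange(images.shape[0], device=logits.device)
    loss_clip = 0.5 * (F.cross_entropy(logits, targets) +
                       F.cross_entropy(logits.T, targets))

    # (2) self-supervised term, e.g. view consistency between two augmentations
    loss_ssl = model.ssl_loss(images, aug_images)  # hypothetical helper

    # (3) reconstruction term through the f16d64 bottleneck
    latents = model.get_reconstruction_latents(images)
    recon = model.get_latents_decoded_images(latents)
    loss_rec = F.mse_loss(recon, images)  # perceptual/GAN terms omitted

    return w_clip * loss_clip + w_ssl * loss_ssl + w_rec * loss_rec
```
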
<div align="center">
<img src="figures/overview.png" alt="Overview Figure" width="900"/>
</div>

## Evaluation

#### Installation

```bash
conda create -n vtp python=3.10
conda activate vtp
git submodule update --init --recursive
pip install -r requirements.txt
```

#### Zero-shot Classification

Modify the corresponding paths in ``scripts/test_zero_shot_hf.sh``. Run:
```bash
bash scripts/test_zero_shot_hf.sh
```

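For intuition, zero-shot classification scores each image against text embeddings of the candidate class names. Below is a minimal sketch built only from the quick-start API; the prompt template and the three-class list are placeholders (the script above is the reference evaluation over the full ImageNet label set).

```python
import torch
from vtp.models.vtp_hf import VTPModel
from vtp.tokenizers import get_tokenizer

model = VTPModel.from_pretrained("pretrained/vtp-l-hf").eval()
tokenizer = get_tokenizer('ViT-B-32', context_length=model.config.text_context_length)

classnames = ["tench", "goldfish", "great white shark"]  # placeholder; use the full class list
text = tokenizer([f"a photo of a {c}" for c in classnames])

with torch.no_grad():
    text_features = model.get_clip_text_feature(text, normalize=True)

@torch.no_grad()
def classify(image_batch):
    # image_batch: preprocessed (B, 3, 256, 256) tensor, as in the quick start
    image_features = model.get_clip_image_feature(image_batch, normalize=True)
    return (image_features @ text_features.T).argmax(dim=-1)  # predicted class indices
```
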
#### Linear Probing Classification

Modify the corresponding paths in ``scripts/test_linear_probing_hf.sh``. Run:
```bash
bash scripts/test_linear_probing_hf.sh
```

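Conceptually, linear probing freezes the VTP encoder and fits a linear classifier on its features. The sketch below trains a probe on the class token from `get_last_layer_feature`; the probe dimensionality, optimizer, and use of a single feature layer are simplifications of the script above, not its exact recipe.

```python
import torch
import torch.nn as nn

from vtp.models.vtp_hf import VTPModel

model = VTPModel.from_pretrained("pretrained/vtp-l-hf").eval()
probe = nn.Linear(1024, 1000)  # cls-token dim -> ImageNet-1k classes
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)

def probe_step(images, labels):
    """One optimization step of the linear probe on frozen VTP features."""
    with torch.no_grad():
        feats = model.get_last_layer_feature(images)['cls_token']  # (B, 1024)
    logits = probe(feats.float())
    loss = nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
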
#### ImageNet Reconstruction

Modify the corresponding paths in ``scripts/test_reconstruction_hf.sh``. Run:
```bash
bash scripts/test_reconstruction_hf.sh
```

#### ImageNet Generation

We use the [LightningDiT](https://github.com/hustvl/LightningDiT) codebase to evaluate generation performance.

Feature extraction:
```bash
bash generation/scripts/extract_features_vtp.sh generation/configs/train_vtp_l_dit_xl.yaml
```

LightningDiT training:
```bash
bash generation/scripts/train_lightningdit_vtp.sh generation/configs/train_vtp_l_dit_xl.yaml
```

LightningDiT sampling:
```bash
bash generation/scripts/inference_lightningdit_vtp.sh generation/configs/train_vtp_l_dit_xl.yaml
```

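The feature-extraction step pre-computes VTP latents for the training images so that DiT training does not have to re-run the encoder. A rough sketch of what it amounts to, using only the encoding call from the quick start (the batching, device placement, and on-disk format here are illustrative assumptions; the script and YAML config above are authoritative):

```python
import torch
from vtp.models.vtp_hf import VTPModel

model = VTPModel.from_pretrained("pretrained/vtp-l-hf").eval().cuda()

@torch.no_grad()
def extract_latents(image_batch):
    """Encode a preprocessed (B, 3, 256, 256) batch into VTP latents for DiT training."""
    return model.get_reconstruction_latents(image_batch.cuda()).cpu()

# e.g., iterate an ImageNet dataloader and cache latents shard by shard:
# torch.save(extract_latents(batch), "latents/shard_000.pt")  # illustrative path
```
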
## Acknowledgements

Our pre-training code is built upon [OpenCLIP](https://github.com/mlfoundations/open_clip) and [DINOv2](https://github.com/facebookresearch/dinov2). Our final model variant uses the [DINOv3](https://github.com/facebookresearch/dinov3) architecture.

We use [LightningDiT](https://github.com/hustvl/LightningDiT) for generation evaluation.

Thanks for their great work.

## Citation

```bibtex
@article{vtp,
  title={Towards Scalable Pre-training of Visual Tokenizers for Generation},
  author={Yao, Jingfeng and Song, Yuda and Zhou, Yucong and Wang, Xinggang},
  journal={arXiv preprint arXiv:2512.13687},
  year={2025}
}
```

## Contact Us

Contact us at [email protected].
config.json ADDED
@@ -0,0 +1,48 @@
{
  "architectures": [
    "VTPModel"
  ],
  "decoder_depth": 24,
  "decoder_embed_dim": 1024,
  "decoder_ffn_layer": "swiglu",
  "decoder_init_values": null,
  "decoder_norm_layer": "layernorm",
  "decoder_num_heads": 16,
  "decoder_use_qk_norm": false,
  "dtype": "float32",
  "image_size": 256,
  "init_logit_bias": null,
  "init_logit_scale": null,
  "model_type": "vtp",
  "nonscalar_logit_scale": false,
  "text_context_length": 77,
  "text_depth": 12,
  "text_embed_cls": false,
  "text_embed_dim": 768,
  "text_ls_init_value": null,
  "text_mlp_ratio": 4.0,
  "text_no_causal_mask": false,
  "text_num_heads": 12,
  "text_output_tokens": false,
  "text_pad_id": 0,
  "text_pool_type": "argmax",
  "text_proj_bias": false,
  "text_proj_type": "linear",
  "text_quick_gelu": false,
  "text_vocab_size": 49408,
  "train_clip": true,
  "train_reconstruction": true,
  "transformers_version": "4.56.0.dev0",
  "vision_bottleneck_ae_only": true,
  "vision_clip_feat": "cls",
  "vision_depth": 24,
  "vision_embed_dim": 1024,
  "vision_feature_bottleneck": 64,
  "vision_ffn_layer": "swiglu",
  "vision_init_values": null,
  "vision_mlp_ratio": 4,
  "vision_norm_layer": "rmsnorm",
  "vision_num_heads": 16,
  "vision_patch_size": 16,
  "vision_use_qk_norm": false
}
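
The tokenizer geometry follows directly from this config: with `image_size` 256 and `vision_patch_size` 16, the encoder yields (256 / 16)² = 256 patch tokens, and `vision_feature_bottleneck` 64 compresses each token to 64 channels, i.e. the f16d64 latent referenced in the README. A quick sanity check reading the raw JSON (standard-library only; the exact layout of the latent tensor returned by the model is not assumed here):

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

tokens_per_side = cfg["image_size"] // cfg["vision_patch_size"]  # 256 // 16 = 16
num_tokens = tokens_per_side ** 2                                # 16 x 16 = 256
print(f"f{cfg['vision_patch_size']}d{cfg['vision_feature_bottleneck']} tokenizer: "
      f"{num_tokens} latent tokens x {cfg['vision_feature_bottleneck']} dims per 256x256 image")
# -> f16d64 tokenizer: 256 latent tokens x 64 dims per 256x256 image
```
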
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de5df2006083a9536c4d3ea36c6ae2181ec604e06550f1b2d7cece3b16aac32f
size 2926368188