Safetensors
Emu3p5VisionVQ
custom_code
aikx committed on
Commit
aceec64
·
verified ·
1 Parent(s): 91789bb

Upload 4 files

Browse files

Add files for HF version

config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Emu3p5VisionVQModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_emu3p5visionvq.Emu3p5VisionVQConfig",
7
+ "AutoModel": "modeling_emu3p5visionvq.Emu3p5VisionVQModel"
8
+ },
9
+ "attn_resolutions": [
10
+ 16
11
+ ],
12
+ "ch": 256,
13
+ "ch_mult": [
14
+ 1,
15
+ 1,
16
+ 2,
17
+ 2,
18
+ 4
19
+ ],
20
+ "codebook_size": 131072,
21
+ "double_z": false,
22
+ "dropout": 0.0,
23
+ "embed_dim": 256,
24
+ "in_channels": 3,
25
+ "model_type": "Emu3p5VisionVQ",
26
+ "num_res_blocks": 4,
27
+ "out_ch": 3,
28
+ "resolution": 256,
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.51.0",
31
+ "z_channels": 256
32
+ }
configuration_emu3p5visionvq.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3p5VisionVQ model configuration """
16
+
17
from typing import List, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class Emu3p5VisionVQConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Emu3p5VisionVQModel`]. It is used to
    instantiate a VQ model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a configuration matching the VQ model
    presented in the Emu3.5 paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        double_z (`bool`, *optional*, defaults to `False`):
            Whether to double the output dim of the encoder.
        z_channels (`int`, *optional*, defaults to 256):
            Dimension of the output channel of the encoder and the input channel of the decoder.
        resolution (`int`, *optional*, defaults to 256):
            Spatial resolution the model is configured for.
        in_channels (`int`, *optional*, defaults to 3):
            Input channel count of the encoder.
        out_ch (`int`, *optional*, defaults to 3):
            Output channel count of the decoder.
        ch (`int`, *optional*, defaults to 256):
            Base channel number of the intermediate blocks.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
            Channel scaling factor of the intermediate blocks, one entry per resolution stage.
        num_res_blocks (`int`, *optional*, defaults to 4):
            Number of residual blocks in each stage.
        attn_resolutions (`List[int]`, *optional*, defaults to `[16]`):
            Feature-map resolutions at which attention blocks are applied.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        codebook_size (`int`, *optional*, defaults to 131072):
            Codebook size of the VQ model.
        embed_dim (`int`, *optional*, defaults to 256):
            Dimension of the quantized vectors in the codebook.

    ```python
    >>> from configuration_emu3p5visionvq import Emu3p5VisionVQConfig
    >>> from modeling_emu3p5visionvq import Emu3p5VisionVQModel

    >>> # Initializing a VQ model configuration
    >>> configuration = Emu3p5VisionVQConfig()

    >>> # Initializing a model from the configuration
    >>> model = Emu3p5VisionVQModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "Emu3p5VisionVQ"

    def __init__(
        self,
        double_z: bool = False,
        z_channels: int = 256,
        resolution: int = 256,
        in_channels: int = 3,
        out_ch: int = 3,
        ch: int = 256,
        ch_mult: Optional[List[int]] = None,
        num_res_blocks: int = 4,
        attn_resolutions: Optional[List[int]] = None,
        dropout: float = 0.0,
        codebook_size: int = 131072,
        embed_dim: int = 256,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.double_z = double_z
        self.z_channels = z_channels
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = out_ch
        self.ch = ch
        # Avoid mutable default arguments (shared across instances); fall back to the
        # paper defaults when the caller does not provide a value.
        self.ch_mult = [1, 1, 2, 2, 4] if ch_mult is None else ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = [16] if attn_resolutions is None else attn_resolutions
        self.dropout = dropout

        self.codebook_size = codebook_size
        self.embed_dim = embed_dim
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a808dc617e489c37be7f51a63a0e7cb77a97a54ceb59f674255d2e7bc7b2c080
3
+ size 1821405084
modeling_emu3p5visionvq.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3p5VisionVQ model """
16
+
17
+
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import nn, einsum
23
+ from torch.nn import functional as F
24
+ from transformers.modeling_utils import PreTrainedModel
25
+
26
+ from .configuration_emu3p5visionvq import Emu3p5VisionVQConfig
27
+
28
+
29
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
32
+
33
+
34
def Emu3p5VisionVQNormalize(in_channels):
    """Build the affine GroupNorm (32 groups, eps=1e-6) used throughout the VQ model."""
    return nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
36
+
37
+
38
class Emu3p5VisionVQUpsample(nn.Module):
    """2x nearest-neighbour spatial upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
54
+
55
+
56
class Emu3p5VisionVQDownsample(nn.Module):
    """2x spatial downsampling via a stride-2 3x3 convolution with asymmetric padding."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Zero-pad one pixel on the right and bottom only (VQGAN-style), so the
        # stride-2 valid convolution exactly halves the spatial size.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
73
+
74
+
75
class Emu3p5VisionVQResnetBlock(nn.Module):
    """Pre-activation residual block.

    Two (GroupNorm -> swish -> 3x3 conv) stages plus a skip connection; when the
    channel count changes, the skip path is projected by a learned 3x3 conv
    (``conv_shortcut=True``) or a 1x1 conv (default).
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Emu3p5VisionVQNormalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Emu3p5VisionVQNormalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            # Project the skip path so it can be added to the transformed path.
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = x

        h = self.conv1(nonlinearity(self.norm1(x)))
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + h
145
+
146
+
147
class Emu3p5VisionVQAttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map,
    with a residual connection (the non-local block of VQGAN-style autoencoders).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Emu3p5VisionVQNormalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        batch, channels, height, width = q.shape
        tokens = height * width

        # Attention weights: attn[b, i, j] = <q_i, k_j> / sqrt(c), softmax over j.
        q = q.reshape(batch, channels, tokens).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(batch, channels, tokens)                   # b, c, hw
        attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
        attn = F.softmax(attn, dim=2)

        # Aggregate values: out[b, c, j] = sum_i v[b, c, i] * attn[b, j, i].
        v = v.reshape(batch, channels, tokens)
        out = torch.bmm(v, attn.permute(0, 2, 1)).reshape(batch, channels, height, width)

        return x + self.proj_out(out)
185
+
186
+
187
class Emu3p5VisionVQEncoder(nn.Module):
    """Convolutional encoder of the VQ model.

    Downsamples the input by a factor of ``2 ** (len(ch_mult) - 1)`` through a
    stack of residual stages (with self-attention at the resolutions listed in
    ``config.attn_resolutions``), then applies a middle block and projects to
    ``z_channels`` channels (doubled when ``config.double_z``).
    """

    def __init__(self, config: Emu3p5VisionVQConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.in_channels = config.in_channels
        self.resolution = config.resolution

        # downsampling
        self.conv_in = nn.Conv2d(
            self.in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        curr_res = self.resolution

        # Prepend 1 so that stage i consumes ch * in_ch_mult[i] channels and
        # produces ch * ch_mult[i] channels.
        in_ch_mult = (1, ) + tuple(config.ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = config.ch * in_ch_mult[i_level]
            block_out = config.ch * config.ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    Emu3p5VisionVQResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=config.dropout,
                    ),
                )
                block_in = block_out
                # One attention block per residual block, only at the
                # feature-map resolutions listed in config.attn_resolutions.
                if curr_res in config.attn_resolutions:
                    attn.append(Emu3p5VisionVQAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Every stage except the last halves the spatial resolution.
                down.downsample = Emu3p5VisionVQDownsample(block_in)
                curr_res = curr_res // 2

            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Emu3p5VisionVQResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=config.dropout,
        )
        self.mid.attn_1 = Emu3p5VisionVQAttnBlock(block_in)
        self.mid.block_2 = Emu3p5VisionVQResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=config.dropout,
        )

        # end
        self.norm_out = Emu3p5VisionVQNormalize(block_in)
        self.conv_out = nn.Conv2d(
            block_in,
            2 * config.z_channels if config.double_z else config.z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode pixels ``x`` (B, in_channels, H, W) into a latent feature map."""
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                # attn is either empty or has one entry per residual block.
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)

            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
285
+
286
+
287
class Emu3p5VisionVQDecoder(nn.Module):
    """Convolutional decoder of the VQ model.

    Mirror of the encoder: projects the latent map from ``z_channels``, applies a
    middle block, then upsamples by ``2 ** (len(ch_mult) - 1)`` through residual
    stages (with self-attention at the resolutions in ``config.attn_resolutions``)
    and projects to ``out_ch`` output channels.
    """

    def __init__(self, config: Emu3p5VisionVQConfig):
        super().__init__()
        self.ch = config.ch
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks

        self.resolution = config.resolution

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1, ) + tuple(config.ch_mult)  # NOTE(review): unused here, kept from the reference implementation
        block_in = config.ch * config.ch_mult[self.num_resolutions-1]

        curr_res = config.resolution // 2 ** (self.num_resolutions - 1)
        # Latent shape (for a batch of 1) expected at the decoder input.
        self.z_shape = (1, config.z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            config.z_channels,
            block_in,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = Emu3p5VisionVQResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=config.dropout,
        )
        self.mid.attn_1 = Emu3p5VisionVQAttnBlock(block_in)
        self.mid.block_2 = Emu3p5VisionVQResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=config.dropout,
        )

        # upsampling (built from the lowest resolution upwards)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = config.ch * config.ch_mult[i_level]
            # One extra residual block per stage compared to the encoder.
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    Emu3p5VisionVQResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=config.dropout,
                    ),
                )
                block_in = block_out
                if curr_res in config.attn_resolutions:
                    attn.append(Emu3p5VisionVQAttnBlock(block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Every stage except the highest-resolution one doubles the size.
                up.upsample = Emu3p5VisionVQUpsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Emu3p5VisionVQNormalize(block_in)
        self.conv_out = nn.Conv2d(
            block_in,
            config.out_ch,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, z):
        """Decode a latent map ``z`` (B, z_channels, h, w) back to pixel space."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)

            if i_level != 0:
                h = self.up[i_level].upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)

        return h
387
+
388
+
389
class Emu3p5VisionVQVectorQuantizer(nn.Module):
    """Codebook quantizer: for every spatial position, picks the codebook row with
    the largest dot product against the latent vector and substitutes it in.
    """

    def __init__(self, config):
        super().__init__()

        self.n_e = config.codebook_size
        self.e_dim = config.embed_dim

        self.embedding = nn.Embedding(self.n_e, self.e_dim)

    def forward(self, z):
        """Quantize ``z`` of shape (b, d, h, w); returns (z_q, flat token indices)."""
        codebook = self.embedding.weight  # (n, d)

        # Dot-product similarity between each latent vector and each codebook row.
        scores = torch.einsum("b d h w, n d -> b n h w", z, codebook)
        indices = scores.argmax(dim=1)  # (b, h, w)

        # Replace each latent vector by the selected codebook row: (b, h, w, d) -> (b, d, h, w).
        quantized = codebook[indices].permute(0, 3, 1, 2).contiguous()

        return quantized, indices.flatten()

    def get_codebook_entry(self, indices, shape=None):
        """Look up codebook rows for token ``indices``; ``shape``, if given, is (b, h, w)."""
        quantized = self.embedding(indices)

        if shape is not None:
            # A (b, h, w) shape is extended with the embedding dim before the view.
            if len(shape) == 3:
                shape = shape + (self.e_dim, )
            quantized = quantized.view(shape)

        # b h w d -> b d h w, matching the quantizer's forward output layout.
        return quantized.permute(0, 3, 1, 2).contiguous()
431
+
432
+
433
class Emu3p5VisionVQPretrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Emu3p5VisionVQConfig
    base_model_prefix = "emu3p5visionvq"
    main_input_name = "pixel_values"
    _no_split_modules = ["Emu3p5VisionVQResnetBlock", "Emu3p5VisionVQAttnBlock"]

    def _init_weights(self, module):
        """Initialize conv, linear and normalization layers."""
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            return
        if isinstance(module, nn.Linear):
            # Mirrors `reset_parameters` of torch's `nn.Linear`.
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
            return
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
457
+
458
+
459
class Emu3p5VisionVQModel(Emu3p5VisionVQPretrainedModel):
    """VQ autoencoder: encoder -> 1x1 quant conv -> vector quantizer on the way in,
    1x1 post-quant conv -> decoder on the way out.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = Emu3p5VisionVQEncoder(config)
        self.decoder = Emu3p5VisionVQDecoder(config)
        self.quantize = Emu3p5VisionVQVectorQuantizer(config)

        # 1x1 convs mapping between the encoder/decoder width and the codebook dim.
        self.quant_conv = nn.Conv2d(config.z_channels, config.embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(config.embed_dim, config.z_channels, 1)

        self.post_init()

    def encode(self, x: torch.Tensor):
        """Encode pixels; returns (quantized embedding, None, (None, None, token_ids))."""
        hidden = self.quant_conv(self.encoder(x))
        quant_embed, token_ids = self.quantize(hidden)
        return quant_embed, None, (None, None, token_ids)

    def decode(self, x: torch.Tensor):
        """Decode quantized embeddings back to pixel space."""
        return self.decoder(self.post_quant_conv(x))

    def decode_code(self, code_b, shape=None):
        """Decode token ids directly; `shape` specifies (batch, height, width, channel)."""
        quant_b = self.quantize.get_codebook_entry(code_b, shape=shape)
        return self.decode(quant_b)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype