# TRELLIS.2 / trellis2/pipelines/trellis2_image_to_3d.py
# Last commit: a1e3f5f ("Finalize") by JeffreyXiang
import contextlib
from typing import *

import torch
import torch.nn as nn
import numpy as np
from PIL import Image

from .base import Pipeline
from . import samplers, rembg
from ..modules.sparse import SparseTensor
from ..modules import image_feature_extractor
from ..representations import Mesh, MeshWithVoxel
class Trellis2ImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis2 image-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        shape_slat_sampler (samplers.Sampler): The sampler for the structured latent.
        tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
        sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler.
        shape_slat_sampler_params (dict): The parameters for the structured latent sampler.
        tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
        shape_slat_normalization (dict): The normalization parameters ('mean', 'std')
            for the structured latent.
        tex_slat_normalization (dict): The normalization parameters ('mean', 'std')
            for the texture latent.
        image_cond_model (Callable): The image conditioning model.
        rembg_model (Callable): The model for removing background.
        low_vram (bool): Whether to use low-VRAM mode. In this mode each model is moved
            to the device only while it is being used and is moved back to the CPU afterwards.
        default_pipeline_type (str): The pipeline type `run` uses when none is given.
    """

    # Channel layout of the per-voxel attributes returned by the texture decoder.
    # `decode_latent` passes it to MeshWithVoxel as `layout`.
    _DEFAULT_PBR_ATTR_LAYOUT = {
        'base_color': slice(0, 3),
        'metallic': slice(3, 4),
        'roughness': slice(4, 5),
        'alpha': slice(5, 6),
    }

    # Flow models each pipeline type needs. `run` checks them in this order.
    _PIPELINE_REQUIRED_MODELS = {
        '512': ('shape_slat_flow_model_512', 'tex_slat_flow_model_512'),
        '1024': ('shape_slat_flow_model_1024', 'tex_slat_flow_model_1024'),
        '1024_cascade': ('shape_slat_flow_model_512', 'shape_slat_flow_model_1024', 'tex_slat_flow_model_1024'),
        '1536_cascade': ('shape_slat_flow_model_512', 'shape_slat_flow_model_1024', 'tex_slat_flow_model_1024'),
    }

    # Human-readable model names used in the "model not found" assertion messages.
    _FLOW_MODEL_DESCRIPTIONS = {
        'shape_slat_flow_model_512': "512 resolution shape SLat flow model",
        'shape_slat_flow_model_1024': "1024 resolution shape SLat flow model",
        'tex_slat_flow_model_512': "512 resolution texture SLat flow model",
        'tex_slat_flow_model_1024': "1024 resolution texture SLat flow model",
    }

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        shape_slat_sampler: samplers.Sampler = None,
        tex_slat_sampler: samplers.Sampler = None,
        sparse_structure_sampler_params: dict = None,
        shape_slat_sampler_params: dict = None,
        tex_slat_sampler_params: dict = None,
        shape_slat_normalization: dict = None,
        tex_slat_normalization: dict = None,
        image_cond_model: Callable = None,
        rembg_model: Callable = None,
        low_vram: bool = True,
        default_pipeline_type: str = '1024_cascade',
    ):
        # `from_pretrained` builds an empty instance and then replaces its __dict__,
        # so all initialization is skipped when no models are given.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.shape_slat_sampler = shape_slat_sampler
        self.tex_slat_sampler = tex_slat_sampler
        self.sparse_structure_sampler_params = sparse_structure_sampler_params
        self.shape_slat_sampler_params = shape_slat_sampler_params
        self.tex_slat_sampler_params = tex_slat_sampler_params
        self.shape_slat_normalization = shape_slat_normalization
        self.tex_slat_normalization = tex_slat_normalization
        self.image_cond_model = image_cond_model
        self.rembg_model = rembg_model
        self.low_vram = low_vram
        self.default_pipeline_type = default_pipeline_type
        self.pbr_attr_layout = dict(self._DEFAULT_PBR_ATTR_LAYOUT)
        self._device = 'cpu'

    @staticmethod
    def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either a local path or a Hugging Face repository.

        Returns:
            Trellis2ImageTo3DPipeline: The loaded pipeline, placed on the CPU.
        """
        def build(module, spec: dict):
            # Instantiate `module.<spec['name']>(**spec['args'])`.
            return getattr(module, spec['name'])(**spec['args'])

        pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path)
        new_pipeline = Trellis2ImageTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        new_pipeline.sparse_structure_sampler = build(samplers, args['sparse_structure_sampler'])
        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
        new_pipeline.shape_slat_sampler = build(samplers, args['shape_slat_sampler'])
        new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
        new_pipeline.tex_slat_sampler = build(samplers, args['tex_slat_sampler'])
        new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']

        new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
        new_pipeline.tex_slat_normalization = args['tex_slat_normalization']

        new_pipeline.image_cond_model = build(image_feature_extractor, args['image_cond_model'])
        new_pipeline.rembg_model = build(rembg, args['rembg_model'])

        new_pipeline.low_vram = args.get('low_vram', True)
        new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
        new_pipeline.pbr_attr_layout = dict(Trellis2ImageTo3DPipeline._DEFAULT_PBR_ATTR_LAYOUT)
        new_pipeline._device = 'cpu'
        return new_pipeline

    def to(self, device: torch.device) -> None:
        """
        Set the pipeline's target device.

        The image conditioning and background-removal models are always moved.
        The remaining models are moved right away only when low-VRAM mode is off;
        otherwise they are moved on demand during inference.
        """
        self._device = device
        if not self.low_vram:
            super().to(device)
        self.image_cond_model.to(device)
        if self.rembg_model is not None:
            self.rembg_model.to(device)

    @contextlib.contextmanager
    def _on_device(self, model, toggle_low_vram_flag: bool = False) -> Iterator[None]:
        """
        Keep `model` on `self.device` for the duration of a `with` block in low-VRAM mode.

        In low-VRAM mode the model is moved to the device on entry. On exit it is moved
        back to the CPU, even when the block raises, so a failure does not leave weights
        on the device. Outside low-VRAM mode this does nothing.

        Args:
            model: An object with `.to()` / `.cpu()` methods, or a list/tuple of such
                objects (every element is moved).
            toggle_low_vram_flag (bool): Also set `model.low_vram = True` inside the block
                and reset it to False on exit. The shape SLat decoder uses this flag.
        """
        if not self.low_vram:
            yield
            return
        modules = list(model) if isinstance(model, (list, tuple)) else [model]
        for module in modules:
            module.to(self.device)
        if toggle_low_vram_flag:
            model.low_vram = True
        try:
            yield
        finally:
            for module in modules:
                module.cpu()
            if toggle_low_vram_flag:
                model.low_vram = False

    @staticmethod
    def _normalize(slat, normalization: dict):
        """Return `(slat - mean) / std`, with a leading axis added to the stored stats."""
        std = torch.tensor(normalization['std'])[None].to(slat.device)
        mean = torch.tensor(normalization['mean'])[None].to(slat.device)
        return (slat - mean) / std

    @staticmethod
    def _denormalize(slat, normalization: dict):
        """Return `slat * std + mean`, with a leading axis added to the stored stats."""
        std = torch.tensor(normalization['std'])[None].to(slat.device)
        mean = torch.tensor(normalization['mean'])[None].to(slat.device)
        return slat * std + mean

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image.

        Steps:
        1. Images whose longest side exceeds 1024 px are downscaled with LANCZOS.
        2. If the image is RGBA with at least one non-opaque pixel, its alpha channel is
           used as the mask. Otherwise the background is removed with `self.rembg_model`.
        3. A square crop is taken, centered on the pixels with alpha > 0.8.
        4. RGB is premultiplied by alpha, which composites the foreground onto black.

        Args:
            input (Image.Image): The input image.

        Returns:
            Image.Image: The preprocessed 3-channel image.

        Raises:
            ValueError: If no pixel has alpha above 0.8, i.e. no foreground was found.
        """
        # If the image already has a meaningful alpha channel, use it directly;
        # otherwise, remove the background.
        has_alpha = input.mode == 'RGBA' and not np.all(np.array(input)[:, :, 3] == 255)

        max_size = max(input.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)

        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            with self._on_device(self.rembg_model):
                output = self.rembg_model(input)

        # NOTE(review): assumes the background-removal output is a 4-channel (RGBA) PIL image.
        alpha = np.array(output)[:, :, 3]
        foreground = np.argwhere(alpha > 0.8 * 255)
        if foreground.size == 0:
            raise ValueError("No foreground found in the input image: no pixel has alpha above 0.8.")
        x_min, x_max = foreground[:, 1].min(), foreground[:, 1].max()
        y_min, y_max = foreground[:, 0].min(), foreground[:, 0].max()
        center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
        half = int(max(x_max - x_min, y_max - y_min)) // 2
        output = output.crop((center_x - half, center_y - half, center_x + half, center_y + half))  # type: ignore

        rgba = np.array(output).astype(np.float32) / 255
        rgb = rgba[:, :, :3] * rgba[:, :, 3:4]
        return Image.fromarray((rgb * 255).astype(np.uint8))

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
            resolution (int): Value assigned to `image_cond_model.image_size` before encoding.
            include_neg_cond (bool): Whether to also return an all-zeros negative condition.

        Returns:
            dict: `{'cond': ...}`, plus `'neg_cond'` (zeros shaped like `cond`) when
            `include_neg_cond` is True.
        """
        self.image_cond_model.image_size = resolution
        with self._on_device(self.image_cond_model):
            cond = self.image_cond_model(image)
        if not include_neg_cond:
            return {'cond': cond}
        return {
            'cond': cond,
            'neg_cond': torch.zeros_like(cond),
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        resolution: int,
        num_samples: int = 1,
        sampler_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            resolution (int): The resolution of the sparse structure. When it is smaller
                than the decoder's output grid, occupancy is max-pooled down by the ratio.
                NOTE(review): assumes it evenly divides the decoded grid size.
            num_samples (int): The number of samples to generate.
            sampler_params (dict, optional): Overrides for `self.sparse_structure_sampler_params`.

        Returns:
            torch.Tensor: int32 coordinates of occupied voxels, one row per voxel:
            the batch index followed by the three spatial indices.
        """
        # Sample the sparse structure latent.
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **(sampler_params or {})}
        with self._on_device(flow_model):
            z_s = self.sparse_structure_sampler.sample(
                flow_model,
                noise,
                **cond,
                **sampler_params,
                verbose=True,
                tqdm_desc="Sampling sparse structure",
            ).samples

        # Decode the latent into a boolean occupancy grid.
        decoder = self.models['sparse_structure_decoder']
        with self._on_device(decoder):
            decoded = decoder(z_s) > 0

        if resolution != decoded.shape[2]:
            ratio = decoded.shape[2] // resolution
            decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
        # Drop column 1 (the channel axis) from the 5-D argwhere result.
        coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
        return coords

    def sample_shape_slat(
        self,
        cond: dict,
        flow_model,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
    ) -> SparseTensor:
        """
        Sample a shape structured latent (SLat) on the given sparse coordinates.

        Args:
            cond (dict): The conditioning information.
            flow_model: The shape SLat flow model to sample with.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict, optional): Overrides for `self.shape_slat_sampler_params`.

        Returns:
            SparseTensor: The sampled shape SLat, de-normalized with `shape_slat_normalization`.
        """
        noise = SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.shape_slat_sampler_params, **(sampler_params or {})}
        with self._on_device(flow_model):
            slat = self.shape_slat_sampler.sample(
                flow_model,
                noise,
                **cond,
                **sampler_params,
                verbose=True,
                tqdm_desc="Sampling shape SLat",
            ).samples
        return self._denormalize(slat, self.shape_slat_normalization)

    def sample_shape_slat_cascade(
        self,
        lr_cond: dict,
        cond: dict,
        flow_model_lr,
        flow_model,
        lr_resolution: int,
        resolution: int,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
        max_num_tokens: int = 49152,
    ) -> Tuple[SparseTensor, int]:
        """
        Sample a shape SLat in two stages: low resolution, then high resolution.

        The low-resolution SLat is sampled on `coords`. The shape SLat decoder then
        upsamples it into finer coordinates. Those are re-quantized for the target
        resolution and used to sample the high-resolution SLat.

        If the re-quantized coordinates contain `max_num_tokens` or more tokens, the
        target resolution is lowered in steps of 128. This continues until the budget is
        met or the resolution is at most 1024; targets at or below 1024 are never lowered.

        Args:
            lr_cond (dict): Conditioning for the low-resolution stage.
            cond (dict): Conditioning for the high-resolution stage.
            flow_model_lr: Shape SLat flow model for the low-resolution stage.
            flow_model: Shape SLat flow model for the high-resolution stage.
            lr_resolution (int): Resolution of the low-resolution stage.
            resolution (int): Requested resolution of the high-resolution stage.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict, optional): Overrides for `self.shape_slat_sampler_params`,
                applied to both stages.
            max_num_tokens (int): Token budget for the high-resolution stage.

        Returns:
            Tuple[SparseTensor, int]: The high-resolution shape SLat, and the resolution
            actually used (may be lower than `resolution`).
        """
        # Low-resolution stage.
        lr_slat = self.sample_shape_slat(lr_cond, flow_model_lr, coords, sampler_params)

        # Upsample the low-resolution SLat into finer coordinates.
        decoder = self.models['shape_slat_decoder']
        with self._on_device(decoder, toggle_low_vram_flag=True):
            hr_coords = decoder.upsample(lr_slat, upsample_times=4)

        # Keep the first column (presumably the batch index) and rescale the spatial
        # columns. NOTE(review): `// 16` presumably maps the output resolution to the
        # SLat grid resolution; confirm against the decoder.
        hr_resolution = resolution
        while True:
            quant_coords = torch.cat([
                hr_coords[:, :1],
                ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
            ], dim=1)
            coords = quant_coords.unique(dim=0)
            num_tokens = coords.shape[0]
            # `<=` (not `==`) guarantees the loop stops at or below 1024 even when
            # `resolution` is not 1024 plus a multiple of 128.
            if num_tokens < max_num_tokens or hr_resolution <= 1024:
                if hr_resolution != resolution:
                    print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.")
                break
            hr_resolution -= 128

        # High-resolution stage.
        slat = self.sample_shape_slat(cond, flow_model, coords, sampler_params)
        return slat, hr_resolution

    def decode_shape_slat(
        self,
        slat: SparseTensor,
        resolution: int,
    ) -> Tuple[List[Mesh], List[SparseTensor]]:
        """
        Decode a shape structured latent into meshes.

        Args:
            slat (SparseTensor): The shape structured latent.
            resolution (int): The output resolution, passed to the decoder's `set_resolution`.

        Returns:
            List[Mesh]: The decoded meshes.
            List[SparseTensor]: The decoded substructures, which guide texture decoding.
        """
        decoder = self.models['shape_slat_decoder']
        decoder.set_resolution(resolution)
        with self._on_device(decoder, toggle_low_vram_flag=True):
            ret = decoder(slat, return_subs=True)
        return ret

    def sample_tex_slat(
        self,
        cond: dict,
        flow_model,
        shape_slat: SparseTensor,
        sampler_params: Optional[dict] = None,
    ) -> SparseTensor:
        """
        Sample a texture structured latent, conditioned on a shape SLat.

        Args:
            cond (dict): The conditioning information.
            flow_model: The texture SLat flow model, or a list/tuple of models (the input
                channel count is then read from the first one).
            shape_slat (SparseTensor): The (de-normalized) structured latent for shape.
                It is re-normalized and passed to the sampler as `concat_cond`.
            sampler_params (dict, optional): Overrides for `self.tex_slat_sampler_params`.

        Returns:
            SparseTensor: The sampled texture SLat, de-normalized with `tex_slat_normalization`.
        """
        shape_slat = self._normalize(shape_slat, self.shape_slat_normalization)
        in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
        # The noise only fills the channels not supplied by the concatenated shape latent.
        noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
        sampler_params = {**self.tex_slat_sampler_params, **(sampler_params or {})}
        with self._on_device(flow_model):
            slat = self.tex_slat_sampler.sample(
                flow_model,
                noise,
                concat_cond=shape_slat,
                **cond,
                **sampler_params,
                verbose=True,
                tqdm_desc="Sampling texture SLat",
            ).samples
        return self._denormalize(slat, self.tex_slat_normalization)

    def decode_tex_slat(
        self,
        slat: SparseTensor,
        subs: List[SparseTensor],
    ) -> SparseTensor:
        """
        Decode a texture structured latent into per-voxel attributes.

        Args:
            slat (SparseTensor): The texture structured latent.
            subs (List[SparseTensor]): Substructures from `decode_shape_slat`, passed to
                the decoder as `guide_subs`.

        Returns:
            SparseTensor: The decoder output rescaled by `x * 0.5 + 0.5`. This presumably
            maps [-1, 1] to [0, 1]; confirm against the decoder.
        """
        decoder = self.models['tex_slat_decoder']
        with self._on_device(decoder):
            ret = decoder(slat, guide_subs=subs) * 0.5 + 0.5
        return ret

    @torch.no_grad()
    def decode_latent(
        self,
        shape_slat: SparseTensor,
        tex_slat: SparseTensor,
        resolution: int,
    ) -> List[MeshWithVoxel]:
        """
        Decode the latent codes into textured meshes.

        Args:
            shape_slat (SparseTensor): The structured latent for shape.
            tex_slat (SparseTensor): The structured latent for texture.
            resolution (int): The resolution of the output.

        Returns:
            List[MeshWithVoxel]: One mesh per sample. Holes are filled, and per-voxel
            attributes are laid out according to `self.pbr_attr_layout`.
        """
        meshes, subs = self.decode_shape_slat(shape_slat, resolution)
        tex_voxels = self.decode_tex_slat(tex_slat, subs)
        out_mesh = []
        for m, v in zip(meshes, tex_voxels):
            m.fill_holes()
            out_mesh.append(
                MeshWithVoxel(
                    m.vertices, m.faces,
                    origin=[-0.5, -0.5, -0.5],
                    voxel_size=1 / resolution,
                    coords=v.coords[:, 1:],
                    attrs=v.feats,
                    voxel_shape=torch.Size([*v.shape, *v.spatial_shape]),
                    layout=self.pbr_attr_layout,
                )
            )
        return out_mesh

    @torch.no_grad()
    def run(
        self,
        image: Image.Image,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        shape_slat_sampler_params: Optional[dict] = None,
        tex_slat_sampler_params: Optional[dict] = None,
        preprocess_image: bool = True,
        return_latent: bool = False,
        pipeline_type: Optional[str] = None,
        max_num_tokens: int = 49152,
    ) -> Union[List[MeshWithVoxel], Tuple[List[MeshWithVoxel], Tuple[SparseTensor, SparseTensor, int]]]:
        """
        Run the pipeline.

        Args:
            image (Image.Image): The image prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed, passed to `torch.manual_seed` after preprocessing.
            sparse_structure_sampler_params (dict, optional): Additional parameters for the sparse structure sampler.
            shape_slat_sampler_params (dict, optional): Additional parameters for the shape SLat sampler.
            tex_slat_sampler_params (dict, optional): Additional parameters for the texture SLat sampler.
            preprocess_image (bool): Whether to preprocess the image.
            return_latent (bool): Whether to also return the latent codes.
            pipeline_type (str, optional): The type of the pipeline. Options: '512', '1024',
                '1024_cascade', '1536_cascade'. Defaults to `self.default_pipeline_type`.
            max_num_tokens (int): The maximum number of tokens to use (cascade pipelines only).

        Returns:
            The list of meshes. When `return_latent` is True, the tuple
            `(meshes, (shape_slat, tex_slat, resolution))`.

        Raises:
            ValueError: If `pipeline_type` is not one of the options above.
            AssertionError: If a flow model required by `pipeline_type` is missing.
        """
        pipeline_type = pipeline_type or self.default_pipeline_type
        if pipeline_type not in self._PIPELINE_REQUIRED_MODELS:
            raise ValueError(f"Invalid pipeline type: {pipeline_type}")
        for key in self._PIPELINE_REQUIRED_MODELS[pipeline_type]:
            assert key in self.models, f"No {self._FLOW_MODEL_DESCRIPTIONS[key]} found."

        if preprocess_image:
            image = self.preprocess_image(image)
        torch.manual_seed(seed)
        cond_512 = self.get_cond([image], 512)
        cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None

        # The sparse structure is always sampled with the 512 conditioning.
        ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
        coords = self.sample_sparse_structure(
            cond_512, ss_res,
            num_samples, sparse_structure_sampler_params,
        )

        if pipeline_type == '512':
            shape_slat = self.sample_shape_slat(
                cond_512, self.models['shape_slat_flow_model_512'],
                coords, shape_slat_sampler_params,
            )
            tex_slat = self.sample_tex_slat(
                cond_512, self.models['tex_slat_flow_model_512'],
                shape_slat, tex_slat_sampler_params,
            )
            res = 512
        else:
            if pipeline_type == '1024':
                shape_slat = self.sample_shape_slat(
                    cond_1024, self.models['shape_slat_flow_model_1024'],
                    coords, shape_slat_sampler_params,
                )
                res = 1024
            else:
                target_resolution = 1024 if pipeline_type == '1024_cascade' else 1536
                shape_slat, res = self.sample_shape_slat_cascade(
                    cond_512, cond_1024,
                    self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
                    512, target_resolution,
                    coords, shape_slat_sampler_params,
                    max_num_tokens,
                )
            tex_slat = self.sample_tex_slat(
                cond_1024, self.models['tex_slat_flow_model_1024'],
                shape_slat, tex_slat_sampler_params,
            )

        torch.cuda.empty_cache()
        out_mesh = self.decode_latent(shape_slat, tex_slat, res)
        if return_latent:
            return out_mesh, (shape_slat, tex_slat, res)
        return out_mesh