Long Hoang commited on
Commit
3048b6a
·
1 Parent(s): dcfb1de

feat: add PLY to GLB conversion and update UI for dual display and download

Browse files
Files changed (1) hide show
  1. app.py +61 -6
app.py CHANGED
@@ -16,6 +16,8 @@ subprocess.run(
16
 
17
  import gradio as gr # import AFTER the pip install above
18
  from huggingface_hub import HfApi
 
 
19
 
20
  # install custom wheels for gaussian splatting
21
  subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
@@ -57,6 +59,43 @@ def get_dust3r_args_parser():
57
  return parser
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @spaces.GPU(duration=150)
61
  def process(inputfiles, input_path=None):
62
 
@@ -219,11 +258,18 @@ def process(inputfiles, input_path=None):
219
  print("Failed to upload PLY to hub:", e)
220
  ply_url = f"LOCAL:{output_ply_path}"
221
 
 
 
 
 
 
222
  # return:
223
  # 1) video path (for gr.Video)
224
  # 2) ply URL (for API + textbox)
225
- # 3) local ply path (for gr.Model3D viewer)
226
- return output_video_path, ply_url, output_ply_path
 
 
227
  ##################################################################################################################################################
228
 
229
 
@@ -269,8 +315,13 @@ with block:
269
  with gr.Tab("Output"):
270
  with gr.Column(scale=2):
271
  with gr.Group():
272
- output_model = gr.Model3D(
273
- label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
 
 
 
 
 
274
  interactive=False,
275
  camera_position=[0.5, 0.5, 1],
276
  )
@@ -285,17 +336,21 @@ with block:
285
  label="PLY download URL",
286
  interactive=False,
287
  )
 
 
 
 
288
  with gr.Column(scale=1):
289
  output_video = gr.Video(label="video")
290
 
291
- button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
292
 
293
  gr.Examples(
294
  examples=[
295
  "sora-santorini-3-views",
296
  ],
297
  inputs=[input_path],
298
- outputs=[output_video, output_file, output_model],
299
  fn=lambda x: process(inputfiles=None, input_path=x),
300
  cache_examples=True,
301
  label='Sparse-view Examples'
 
16
 
17
  import gradio as gr # import AFTER the pip install above
18
  from huggingface_hub import HfApi
19
+ import trimesh
20
+ from plyfile import PlyData
21
 
22
  # install custom wheels for gaussian splatting
23
  subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
 
59
  return parser
60
 
61
 
62
+ def convert_ply_to_glb(ply_path, glb_path):
63
+ try:
64
+ plydata = PlyData.read(ply_path)
65
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
66
+ np.asarray(plydata.elements[0]["y"]),
67
+ np.asarray(plydata.elements[0]["z"])), axis=1)
68
+
69
+ # Extract DC features (colors)
70
+ # f_dc_0, f_dc_1, f_dc_2 are SH coefficients for R, G, B
71
+ f_dc_0 = np.asarray(plydata.elements[0]["f_dc_0"])
72
+ f_dc_1 = np.asarray(plydata.elements[0]["f_dc_1"])
73
+ f_dc_2 = np.asarray(plydata.elements[0]["f_dc_2"])
74
+
75
+ # SH2RGB: sh * C0 + 0.5
76
+ # C0 = 0.28209479177387814
77
+ C0 = 0.28209479177387814
78
+ r = f_dc_0 * C0 + 0.5
79
+ g = f_dc_1 * C0 + 0.5
80
+ b = f_dc_2 * C0 + 0.5
81
+
82
+ colors = np.stack((r, g, b), axis=1)
83
+ # Clip to [0, 1]
84
+ colors = np.clip(colors, 0, 1)
85
+ # Convert to uint8 for trimesh
86
+ colors = (colors * 255).astype(np.uint8)
87
+
88
+ # Create PointCloud
89
+ pcd = trimesh.points.PointCloud(vertices=xyz, colors=colors)
90
+
91
+ # Export
92
+ pcd.export(glb_path)
93
+ return True
94
+ except Exception as e:
95
+ print(f"Error converting PLY to GLB: {e}")
96
+ return False
97
+
98
+
99
  @spaces.GPU(duration=150)
100
  def process(inputfiles, input_path=None):
101
 
 
258
  print("Failed to upload PLY to hub:", e)
259
  ply_url = f"LOCAL:{output_ply_path}"
260
 
261
+ # Convert PLY to GLB for visualization
262
+ output_glb_path = output_ply_path.replace('.ply', '.glb')
263
+ if not convert_ply_to_glb(output_ply_path, output_glb_path):
264
+ output_glb_path = None
265
+
266
  # return:
267
  # 1) video path (for gr.Video)
268
  # 2) ply URL (for API + textbox)
269
+ # 3) ply file path (for gr.File download)
270
+ # 4) ply file path (for gr.Model3D viewer)
271
+ # 5) glb file path (for gr.Model3D viewer)
272
+ return output_video_path, ply_url, output_ply_path, output_ply_path, output_glb_path
273
  ##################################################################################################################################################
274
 
275
 
 
315
  with gr.Tab("Output"):
316
  with gr.Column(scale=2):
317
  with gr.Group():
318
+ output_model_glb = gr.Model3D(
319
+ label="3D Model (GLB Point Cloud)",
320
+ interactive=False,
321
+ camera_position=[0.5, 0.5, 1],
322
+ )
323
+ output_model_ply = gr.Model3D(
324
+ label="Original PLY (Gaussian Splat)",
325
  interactive=False,
326
  camera_position=[0.5, 0.5, 1],
327
  )
 
336
  label="PLY download URL",
337
  interactive=False,
338
  )
339
+ output_download = gr.File(
340
+ label="Download PLY",
341
+ interactive=False,
342
+ )
343
  with gr.Column(scale=1):
344
  output_video = gr.Video(label="video")
345
 
346
+ button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_download, output_model_ply, output_model_glb])
347
 
348
  gr.Examples(
349
  examples=[
350
  "sora-santorini-3-views",
351
  ],
352
  inputs=[input_path],
353
+ outputs=[output_video, output_file, output_download, output_model_ply, output_model_glb],
354
  fn=lambda x: process(inputfiles=None, input_path=x),
355
  cache_examples=True,
356
  label='Sparse-view Examples'