| """ |
| Model Manager for real-time motion generation (HF Space version) |
| Loads model from Hugging Face Hub instead of local checkpoints. |
| """ |
|
|
| import threading |
| import time |
| from collections import deque |
|
|
| import numpy as np |
| import torch |
|
|
| from motion_process import StreamJointRecovery263 |
|
|
|
|
class FrameBuffer:
    """
    Thread-safe FIFO buffer of generated motion frames.

    A background generation thread appends frames while a consumer pops
    them; every access is guarded by a single lock. The buffer is bounded:
    once full, the oldest frames are silently dropped by the deque.
    """

    def __init__(self, target_buffer_size=4, max_buffer_size=100):
        """
        Args:
            target_buffer_size: low-water mark; ``needs_generation()`` is
                True while fewer frames than this are buffered.
            max_buffer_size: hard capacity of the deque (previously a
                hard-coded 100); oldest frames are evicted when exceeded.
        """
        self.buffer = deque(maxlen=max_buffer_size)
        self.target_size = target_buffer_size
        self.lock = threading.Lock()

    def add_frame(self, joints):
        """Append one frame (joint data) to the tail of the buffer."""
        with self.lock:
            self.buffer.append(joints)

    def get_frame(self):
        """Pop and return the oldest frame, or None if the buffer is empty."""
        with self.lock:
            if self.buffer:
                return self.buffer.popleft()
            return None

    def size(self):
        """Return the current number of buffered frames."""
        with self.lock:
            return len(self.buffer)

    def clear(self):
        """Discard all buffered frames."""
        with self.lock:
            self.buffer.clear()

    def needs_generation(self):
        """Return True when the buffer is below its target fill level."""
        return self.size() < self.target_size
|
|
|
|
class ModelManager:
    """
    Manages model loading from HF Hub and real-time frame generation.

    Owns a background generation thread that repeatedly runs one diffusion
    "stream step", decodes the latent chunk with the VAE, converts each
    decoded frame to joint positions, and pushes the result into both a
    consumer FrameBuffer and a spectator broadcast ring buffer.

    NOTE(review): ``is_generating`` / ``should_stop`` / ``current_text`` are
    read and written from both the control thread and the generation thread
    without a lock; this appears to rely on the GIL making single attribute
    accesses atomic — confirm this is intentional.
    """

    def __init__(self, model_name):
        """Load the models from the HF Hub and initialize streaming state.

        Args:
            model_name: Hugging Face Hub repo id of the motion model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.vae, self.model = self._load_models(model_name)

        # Snapshot the model's default sampling configuration so that
        # reset()/start_generation() can restore it after runtime tweaks.
        self._base_schedule_config = {
            "chunk_size": self.model.chunk_size,
            "steps": self.model.noise_steps,
        }
        self._base_cfg_config = {
            "cfg_scale": self.model.cfg_scale,
        }

        # Consumer-facing frame queue; generation runs while its size is
        # below the target (see _generation_loop).
        self.frame_buffer = FrameBuffer(target_buffer_size=16)

        # Ring buffer of (id, joints) pairs for spectators. Ids increase
        # monotonically so clients can ask for "frames after id N".
        self.broadcast_frames = deque(maxlen=200)
        self.broadcast_id = 0
        self.broadcast_lock = threading.Lock()

        # EMA smoothing factor used by the 263-dim feature -> joints
        # recovery (see reset() docstring for the value's semantics).
        self.smoothing_alpha = 0.5
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )

        # Generation-thread control state.
        self.current_text = ""
        self.is_generating = False
        self.generation_thread = None
        self.should_stop = False

        # first_chunk: tells the VAE stream decoder a new stream starts;
        # _model_first_chunk: the same flag for the diffusion stream step.
        self.first_chunk = True
        self._model_first_chunk = True
        self.history_length = 30

        print("ModelManager initialized successfully")

    def _patch_attention_sdpa(self, model_name):
        """Patch flash_attention() to include SDPA fallback for GPUs without flash-attn (e.g., T4).

        Rewrites the cached remote-code file ``attention.py`` on disk (string
        replace) so that when neither flash-attn 2 nor 3 is available, the
        function falls back to ``torch.nn.functional.scaled_dot_product_attention``.
        Idempotent: a file already containing "SDPA fallback" is skipped.
        """
        import glob
        import os

        # The remote code may be cached either under the hub snapshot dir
        # or under transformers' dynamic-module cache; search both layouts.
        hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
        patterns = [
            os.path.join(
                hf_cache, "hub", "models--" + model_name.replace("/", "--"),
                "snapshots", "*", "ldf_models", "tools", "attention.py",
            ),
            os.path.join(
                hf_cache, "modules", "transformers_modules", model_name,
                "*", "ldf_models", "tools", "attention.py",
            ),
        ]

        # Anchor text inside the cached attention.py; the replacement keeps
        # the anchor and inserts the SDPA fallback branch right after it.
        target = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # params\n"
        )
        replacement = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
            "    if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
            "        out_dtype = q.dtype\n"
            "        b, lq, nq, c = q.shape\n"
            "        lk = k.size(1)\n"
            "        q = q.transpose(1, 2).to(dtype)\n"
            "        k = k.transpose(1, 2).to(dtype)\n"
            "        v = v.transpose(1, 2).to(dtype)\n"
            "        attn_mask = None\n"
            "        is_causal_flag = causal\n"
            "        if k_lens is not None:\n"
            "            k_lens = k_lens.to(q.device)\n"
            "            valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
            "            attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
            "            is_causal_flag = False\n"
            "            if causal:\n"
            "                cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
            "                attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
            "        out = torch.nn.functional.scaled_dot_product_attention(\n"
            "            q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
            "        )\n"
            "        return out.transpose(1, 2).contiguous().to(out_dtype)\n"
            "\n"
            "    # params\n"
        )

        for pattern in patterns:
            for filepath in glob.glob(pattern):
                with open(filepath, "r") as f:
                    content = f.read()
                if "SDPA fallback" in content:
                    print(f"Already patched: {filepath}")
                    continue
                # Only replace the first occurrence of the anchor.
                if target in content:
                    content = content.replace(target, replacement, 1)
                    with open(filepath, "w") as f:
                        f.write(content)
                    print(f"Patched with SDPA fallback: {filepath}")

    def _load_models(self, model_name):
        """Load VAE and diffusion models from HF Hub.

        Downloads the repo snapshot, patches the cached attention code for
        GPUs without flash-attn, loads the remote-code model, runs a short
        warm-up forward pass, and returns the (vae, ldf_model) pair in
        eval mode.
        """
        torch.set_float32_matmul_precision("high")

        # Download first so the cached files exist before patching them.
        print(f"Downloading model from HF Hub: {model_name}")
        from huggingface_hub import snapshot_download
        snapshot_download(model_name)

        # Must happen before AutoModel imports the remote code.
        self._patch_attention_sdpa(model_name)

        print("Loading model...")
        from transformers import AutoModel

        hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        hf_model.to(self.device)

        # Warm-up: triggers lazy initialization / compilation paths.
        print("Warming up model...")
        _ = hf_model("test", length=1)

        # The wrapper exposes the two sub-models used for streaming.
        model = hf_model.ldf_model
        vae = hf_model.vae

        model.eval()
        vae.eval()

        print("Models loaded successfully")
        return vae, model

    def start_generation(self, text, history_length=None):
        """Start or update generation with new text.

        If generation is not currently running, all streaming state is reset
        (buffers, VAE cache, sampling config) before the background thread
        starts. If it is running, only the prompt text is updated.

        Args:
            text: prompt driving the motion generation.
            history_length: optional new history window length.
        """
        self.current_text = text

        if history_length is not None:
            self.history_length = history_length

        if not self.is_generating:
            # Full state reset for a fresh stream.
            self.frame_buffer.clear()
            self.stream_recovery.reset()
            self.vae.clear_cache()
            self.first_chunk = True
            self._model_first_chunk = True
            # Restore the base sampling configuration captured at load time.
            self.model.chunk_size = self._base_schedule_config["chunk_size"]
            self.model.noise_steps = self._base_schedule_config["steps"]
            self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
            self.model.init_generated(self.history_length, batch_size=1)
            print(
                f"Model initialized with history length: {self.history_length}"
            )

            # Launch the background generation loop as a daemon thread.
            self.should_stop = False
            self.generation_thread = threading.Thread(target=self._generation_loop)
            self.generation_thread.daemon = True
            self.generation_thread.start()
            self.is_generating = True

    def update_text(self, text):
        """Update text without resetting state (continuous generation with new text)."""
        if text != self.current_text:
            old_text = self.current_text
            self.current_text = text
            # The generation loop reads current_text on its next step, so
            # the new prompt takes effect without restarting the stream.
            print(f"Text updated: '{old_text}' -> '{text}' (continuous generation)")

    def pause_generation(self):
        """Pause generation (keeps all state)."""
        self.should_stop = True
        if self.generation_thread:
            # Bounded join so a wedged loop cannot hang the caller.
            self.generation_thread.join(timeout=2.0)
        self.is_generating = False
        print("Generation paused (state preserved)")

    def resume_generation(self):
        """Resume generation from paused state."""
        if self.is_generating:
            print("Already generating, ignoring resume")
            return

        # No state reset: the model continues from its cached history.
        self.should_stop = False
        self.generation_thread = threading.Thread(target=self._generation_loop)
        self.generation_thread.daemon = True
        self.generation_thread.start()
        self.is_generating = True
        print("Generation resumed")

    def reset(self, history_length=None, smoothing_alpha=None):
        """Reset generation state completely

        Args:
            history_length: History window length for the model
            smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
                - 1.0 = no smoothing (default)
                - 0.0 = infinite smoothing
                - Recommended: 0.3-0.7 for visible smoothing
        """
        # Stop the background thread first (state itself is preserved by
        # pause; the resets below do the actual clearing).
        if self.is_generating:
            self.pause_generation()

        # NOTE(review): broadcast_frames / broadcast_id are not cleared
        # here — spectators keep seeing monotonically increasing ids across
        # resets. Confirm this is intentional.
        self.frame_buffer.clear()
        self.vae.clear_cache()
        self.first_chunk = True

        if history_length is not None:
            self.history_length = history_length

        # Clamp to the valid EMA range before applying.
        if smoothing_alpha is not None:
            self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
            print(f"Smoothing alpha updated to: {self.smoothing_alpha}")

        # Fresh recovery object drops any accumulated EMA state.
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )

        # Restore base sampling config captured at load time.
        self.model.chunk_size = self._base_schedule_config["chunk_size"]
        self.model.noise_steps = self._base_schedule_config["steps"]
        self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
        self._model_first_chunk = True

        # Re-initialize the model's generation history window.
        self.model.init_generated(self.history_length, batch_size=1)
        print(
            f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
        )

    def _generation_loop(self):
        """Main generation loop that runs in background thread.

        Repeats until ``should_stop``: when the frame buffer is below its
        target size, run one diffusion stream step for the current text,
        VAE-decode the chunk, convert each frame to joints, and publish to
        both the frame buffer and the broadcast ring. Otherwise sleep
        briefly to avoid busy-waiting.
        """
        print("Generation loop started")

        step_count = 0
        total_gen_time = 0

        with torch.no_grad():
            while not self.should_stop:
                # Only generate when the consumer buffer is running low.
                if self.frame_buffer.needs_generation():
                    try:
                        step_start = time.time()

                        # Prompt is re-read every step so update_text()
                        # takes effect mid-stream.
                        x = {"text": [self.current_text]}

                        # One autoregressive diffusion step; first_chunk
                        # flags a fresh stream exactly once.
                        output = self.model.stream_generate_step(
                            x, first_chunk=self._model_first_chunk
                        )
                        self._model_first_chunk = False
                        generated = output["generated"]

                        # Step may legitimately yield no new latents.
                        if generated[0].shape[0] == 0:
                            continue

                        # Decode latents to 263-dim motion features
                        # ([None, :] adds the batch dim, [0] removes it).
                        decoded = self.vae.stream_decode(
                            generated[0][None, :], first_chunk=self.first_chunk
                        )[0]

                        self.first_chunk = False

                        # Publish each decoded frame as joint positions.
                        for i in range(decoded.shape[0]):
                            frame_data = decoded[i].cpu().numpy()
                            joints = self.stream_recovery.process_frame(frame_data)
                            self.frame_buffer.add_frame(joints)
                            # Mirror into the spectator broadcast ring.
                            with self.broadcast_lock:
                                self.broadcast_id += 1
                                self.broadcast_frames.append(
                                    (self.broadcast_id, joints)
                                )

                        step_time = time.time() - step_start
                        total_gen_time += step_time
                        step_count += 1

                        # Periodic perf log (every 10 steps).
                        if step_count % 10 == 0:
                            avg_time = total_gen_time / step_count
                            fps = decoded.shape[0] / avg_time
                            print(
                                f"[Generation] Step {step_count}: {step_time * 1000:.1f}ms, "
                                f"Avg: {avg_time * 1000:.1f}ms, "
                                f"FPS: {fps:.1f}, "
                                f"Buffer: {self.frame_buffer.size()}"
                            )

                    # Broad catch keeps the loop alive across transient
                    # model/decoding errors; the traceback is logged.
                    except Exception as e:
                        print(f"Error in generation: {e}")
                        import traceback

                        traceback.print_exc()
                        time.sleep(0.1)
                else:
                    # Buffer is full enough; yield briefly.
                    time.sleep(0.01)

        print("Generation loop stopped")

    def get_next_frame(self):
        """Get the next frame from buffer (None when empty)."""
        return self.frame_buffer.get_frame()

    def get_broadcast_frames(self, after_id, count=8):
        """Get frames from broadcast buffer after the given ID (for spectators).

        Args:
            after_id: last frame id the client has already seen.
            count: maximum number of frames to return.

        Returns:
            Up to ``count`` (frame_id, joints) tuples with id > after_id,
            oldest first.
        """
        with self.broadcast_lock:
            frames = [
                (fid, joints)
                for fid, joints in self.broadcast_frames
                if fid > after_id
            ]
        return frames[:count]

    def get_buffer_status(self):
        """Return a status dict describing buffers, flags and live sampling config."""
        return {
            "buffer_size": self.frame_buffer.size(),
            "target_size": self.frame_buffer.target_size,
            "is_generating": self.is_generating,
            "current_text": self.current_text,
            "smoothing_alpha": self.smoothing_alpha,
            "history_length": self.history_length,
            # Live values (may differ from the base config if tweaked).
            "schedule_config": {
                "chunk_size": self.model.chunk_size,
                "steps": self.model.noise_steps,
            },
            "cfg_config": {
                "cfg_scale": self.model.cfg_scale,
            },
        }
|
|
|
|
| |
# Process-wide singleton holder; populated lazily by get_model_manager().
_model_manager = None


def get_model_manager(model_name=None):
    """Return the global ModelManager singleton, creating it on first call."""
    global _model_manager
    manager = _model_manager
    if manager is None:
        manager = ModelManager(model_name)
        _model_manager = manager
    return manager
|
|