# coding=utf-8 # Copyright 2026 ERNIE Team and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Eureka-Audio model.""" import os import logging from copy import deepcopy from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import ( PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM, ) from transformers.models.whisper.configuration_whisper import WhisperConfig from transformers.models.whisper.modeling_whisper import WhisperEncoder as TransformersWhisperEncoder from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import logging as transformers_logging from .configuration_eureka_audio import EurekaAudioConfig logger = transformers_logging.get_logger(__name__) class TokenType: """Token type identifiers for multimodal inputs.""" text = 0 audio = 3 class WhisperEncoder(nn.Module): """ Whisper-based audio encoder for extracting audio features. Args: config: Whisper configuration dictionary """ def __init__(self, config: dict): super().__init__() whisper_config = WhisperConfig(**config) whisper_config._attn_implementation = 'flash_attention_2' self.speech_encoder = TransformersWhisperEncoder(whisper_config) def forward( self, mel_batch: torch.Tensor = None, ) -> torch.Tensor: """ Encode mel spectrogram to audio features. Args: mel_batch: Precomputed mel spectrogram [B, 128, 3000] Returns: Audio features [1, T', D] where T' = B * 1500 and D = d_model """ if mel_batch is None: raise ValueError("mel_batch must be provided") encoder_out = self.speech_encoder(mel_batch, return_dict=True).last_hidden_state # Concatenate all chunks into single sequence final_audio_embedding = torch.cat([x for x in encoder_out], dim=0).unsqueeze(0) return final_audio_embedding class AudioNanoExpert(nn.Module): """ Mixture of Experts adaptor for audio features. This module transforms audio encoder outputs to match the LLM hidden dimension using a sparse mixture of experts architecture. Args: config: EurekaAudioConfig containing nano_expert settings """ def __init__(self, config: EurekaAudioConfig): super().__init__() cfg = config.audio_config["nano_expert"] self.input_dim = cfg["input_dim"] self.expert_dim = cfg["expert_dim"] self.num_experts = cfg["num_experts"] self.k = cfg["k"] self.num_shared = cfg.get("num_shared_experts", 2) # Expert output dimension should match backbone hidden_size (2048) # The out_dim in config (1280) is actually the expert intermediate dim self.backbone_hidden_size = config.llm_config.get("hidden_size", 2048) self.output_dim = self.backbone_hidden_size self.proj_hidden = cfg.get("proj_hidden", 2560) # Output projection: Linear(2048->2560) -> SiLU -> Linear(2560->2048) -> RMSNorm self.proj = nn.Sequential( nn.Linear(self.output_dim, self.proj_hidden), nn.SiLU(), nn.Linear(self.proj_hidden, self.backbone_hidden_size), nn.RMSNorm(self.backbone_hidden_size) ) assert self.k > 0 and self.num_experts > self.num_shared # Gating network for routing self.w_gating = nn.Linear(self.input_dim, self.num_experts - self.num_shared) # Expert networks: RMSNorm(5120) -> Linear(5120->1280) -> SiLU -> Linear(1280->2048) -> RMSNorm(2048) self.experts = nn.ModuleList([ nn.Sequential( nn.RMSNorm(self.input_dim), nn.Linear(self.input_dim, self.expert_dim), nn.SiLU(), nn.Linear(self.expert_dim, self.output_dim), nn.RMSNorm(self.output_dim) ) for _ in range(self.num_experts) ]) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through MoE. Args: x: Input features [*, input_dim] Returns: Transformed features matching LLM hidden dimension """ flat_x = x.reshape(-1, x.shape[-1]) N = flat_x.shape[0] # Compute gating scores logits = self.w_gating(flat_x) topk_vals, topk_idx = torch.topk(logits, self.k, dim=1) topk_scores = F.softmax(topk_vals, dim=1) topk_idx_shifted = topk_idx + self.num_shared # Build routing weights W_flat = torch.zeros(N, self.num_experts, device=flat_x.device, dtype=topk_scores.dtype) W_flat.scatter_(1, topk_idx_shifted, topk_scores) # Dispatch to experts dispatched = (W_flat.t().unsqueeze(-1) * flat_x.unsqueeze(0)) expert_out = torch.stack( [self.experts[e](dispatched[e]) for e in range(self.num_experts)], dim=0 ) # Combine routed expert outputs routed_out = (W_flat.unsqueeze(-1) * expert_out.permute(1, 0, 2)).sum(dim=1) # Add shared expert outputs shared_out = sum(self.experts[e](flat_x) for e in range(self.num_shared)) out = routed_out + shared_out out = out.view(-1, self.output_dim) out = self.proj(out) return out class EurekaAudioModel(PreTrainedModel): """ Base Eureka-Audio model outputting raw hidden-states. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model. Args: config ([`EurekaAudioConfig`]): Model configuration class with all the parameters of the model. """ config_class = EurekaAudioConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoder", "AudioNanoExpert"] def __init__(self, config: EurekaAudioConfig, **kwargs): super().__init__(config, **kwargs) self.config = config # Build LLM backbone self.backbone = self._build_llm_backbone() # Build audio encoder self.audio_encoder = self._build_audio_encoder() # Build audio adaptor self.audio_moe_adaptor = AudioNanoExpert(deepcopy(config)) def _build_llm_backbone(self) -> nn.Module: """Build LLM backbone from config.""" llm_config = self.config.llm_config # Create config directly from dict config_obj = AutoConfig.for_model(**llm_config) # Create model with bfloat16 dtype to support flash_attention_2 backbone = AutoModelForCausalLM.from_config( config_obj, attn_implementation="flash_attention_2", ).to(torch.bfloat16) return backbone def _build_audio_encoder(self) -> nn.Module: """Build Whisper audio encoder.""" audio_encoder_config = self.config.audio_encoder_config audio_encoder = WhisperEncoder(config=audio_encoder_config) return audio_encoder.to(torch.bfloat16) def get_input_embeddings(self): return self.backbone.model.embed_tokens def set_input_embeddings(self, value): self.backbone.model.embed_tokens = value def _audio_embedding_forward( self, token_type_ids: torch.Tensor, inputs_embeds: torch.Tensor, continuous_audio_features: torch.Tensor, ) -> torch.Tensor: """ Inject audio features into input embeddings. Args: token_type_ids: Token type IDs indicating audio positions inputs_embeds: Text embeddings from backbone continuous_audio_features: Audio features from Whisper encoder Returns: Modified embeddings with audio features injected """ understand_mask = token_type_ids == TokenType.audio b, s, d = continuous_audio_features.shape assert s % 4 == 0, "continuous_audio_features frames must be divisible by 4" # Downsample: 4 encoder frames -> 1 audio token continuous_audio_features = continuous_audio_features.view(b, s // 4, d * 4) if continuous_audio_features.size(0) == 1: continuous_audio_features = continuous_audio_features.squeeze(0) # Transform through MoE adaptor exp_feat = self.audio_moe_adaptor( continuous_audio_features.to(inputs_embeds.dtype) ) inputs_embeds[understand_mask] = exp_feat.to(inputs_embeds.dtype) return inputs_embeds def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_type_ids: Optional[torch.Tensor] = None, mel_batch_list: Optional[torch.Tensor] = None, **kwargs, ): """ Forward pass of the base model. Args: input_ids: Input token IDs attention_mask: Attention mask position_ids: Position IDs past_key_values: Past key values for caching inputs_embeds: Pre-computed input embeddings use_cache: Whether to use caching output_attentions: Whether to output attentions output_hidden_states: Whether to output hidden states return_dict: Whether to return a dict token_type_ids: Token type IDs (text=0, audio=3) mel_batch_list: Mel spectrogram batch [B, 128, 3000] Returns: Model outputs with hidden states """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Handle token_type_ids shape if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1: token_type_ids_inputs = token_type_ids[..., :-1] else: token_type_ids_inputs = token_type_ids # Get text embeddings if inputs_embeds is None: inputs_embeds = self.backbone.model.embed_tokens(input_ids) # Process audio features (only when mel_batch_list is provided) if mel_batch_list is not None and token_type_ids_inputs is not None: continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list) # Trim to actual audio frame count real_frames = (token_type_ids_inputs == TokenType.audio).sum() continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :] # Inject audio into embeddings inputs_embeds = self._audio_embedding_forward( token_type_ids_inputs, inputs_embeds, continuous_audio_features, ) # Forward through backbone outputs = self.backbone.model( position_ids=position_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=use_cache, past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=True, ) return outputs class EurekaAudioForCausalLM(EurekaAudioModel, GenerationMixin): """ Eureka-Audio Model with a language modeling head for causal LM. This model supports both text-only generation and audio understanding tasks. Example: ```python >>> from transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained( ... "cslys1999/Eureka-Audio-Instruct", ... trust_remote_code=True ... ) ``` """ _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: EurekaAudioConfig, **kwargs): super().__init__(config, **kwargs) def get_output_embeddings(self): return self.backbone.lm_head def set_output_embeddings(self, new_embeddings): self.backbone.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, **kwargs, ): """Prepare inputs for generation step.""" model_inputs = super().prepare_inputs_for_generation( input_ids, **kwargs, ) # Extend token_type_ids - get from model_inputs (updated by parent), not kwargs token_type_ids = model_inputs['token_type_ids'] token_type_ids = torch.cat([ token_type_ids, torch.zeros((token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device), ], dim=-1) model_inputs['token_type_ids'] = token_type_ids return model_inputs def _update_model_kwargs_for_generation( self, outputs, model_kwargs, is_encoder_decoder: bool = False, ): """Update model kwargs for next generation step.""" model_kwargs = super()._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, ) # Clear audio_input_ids and mel_batch_list after first forward pass model_kwargs['audio_input_ids'] = None model_kwargs['mel_batch_list'] = None return model_kwargs def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_type_ids: Optional[torch.Tensor] = None, mel_batch_list: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Forward pass for causal language modeling. Args: input_ids: Input token IDs [batch_size, seq_len] attention_mask: Attention mask [batch_size, seq_len] position_ids: Position IDs past_key_values: Past key values for caching inputs_embeds: Pre-computed input embeddings labels: Labels for computing the language modeling loss use_cache: Whether to use caching output_attentions: Whether to output attentions output_hidden_states: Whether to output hidden states return_dict: Whether to return a dict token_type_ids: Token type IDs (text=0, audio=3) mel_batch_list: Mel spectrogram batch [num_chunks, 128, 3000] Returns: CausalLMOutputWithPast with loss (if labels provided), logits, past_key_values, hidden_states, and attentions. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Handle token_type_ids shape # When token_type_ids.shape[-1] == input_ids.shape[-1] + 1, slice it # Otherwise use it as is (for compatibility with different calling patterns) if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1: token_type_ids_inputs = token_type_ids[..., :-1] else: token_type_ids_inputs = token_type_ids # Get text embeddings inputs_embeds = self.backbone.model.embed_tokens(input_ids) # Process audio features (only on first forward pass when mel_batch_list is provided) if mel_batch_list is not None and token_type_ids is not None: continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list) # Use full token_type_ids for real_frames calculation real_frames = (token_type_ids == TokenType.audio).sum() continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :] inputs_embeds = self._audio_embedding_forward( token_type_ids_inputs, inputs_embeds, continuous_audio_features, ) # Forward through backbone outputs = self.backbone( position_ids=position_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=use_cache, past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=True, ) hidden_states = outputs.hidden_states[-1] logits = self.backbone.lm_head(hidden_states) loss = None if labels is not None: # Shift for next token prediction shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) # Register the model with AutoModel EurekaAudioConfig.register_for_auto_class() EurekaAudioModel.register_for_auto_class("AutoModel") EurekaAudioForCausalLM.register_for_auto_class("AutoModelForCausalLM")