| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """PyTorch Eureka-Audio model.""" |
| |
|
| | import os |
| | import logging |
| | from copy import deepcopy |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import ( |
| | PreTrainedModel, |
| | GenerationMixin, |
| | AutoConfig, |
| | AutoModelForCausalLM, |
| | ) |
| | from transformers.models.whisper.configuration_whisper import WhisperConfig |
| | from transformers.models.whisper.modeling_whisper import WhisperEncoder as TransformersWhisperEncoder |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.utils import logging as transformers_logging |
| |
|
| | from .configuration_eureka_audio import EurekaAudioConfig |
| |
|
| |
|
# Module-level logger routed through transformers' logging wrapper so its
# verbosity follows the library-wide configuration.
logger = transformers_logging.get_logger(__name__)
| |
|
| |
|
class TokenType:
    """Integer tags distinguishing modalities in ``token_type_ids``.

    ``text`` marks ordinary language-model tokens; ``audio`` marks the
    placeholder positions whose embeddings are replaced by audio features.
    """

    text, audio = 0, 3
|
class WhisperEncoder(nn.Module):
    """
    Whisper-based audio encoder for extracting audio features.

    Wraps the Hugging Face Whisper encoder and flattens the per-chunk
    encoder outputs into a single feature sequence.

    Args:
        config: Whisper configuration dictionary (kwargs for `WhisperConfig`)
    """

    def __init__(self, config: dict):
        super().__init__()
        whisper_config = WhisperConfig(**config)
        # Force flash-attention-2; assumes a compatible GPU at runtime.
        whisper_config._attn_implementation = 'flash_attention_2'
        self.speech_encoder = TransformersWhisperEncoder(whisper_config)

    def forward(
        self,
        mel_batch: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Encode mel spectrogram to audio features.

        Args:
            mel_batch: Precomputed mel spectrogram [B, 128, 3000]

        Returns:
            Audio features [1, T', D] where T' = B * 1500 and D = d_model

        Raises:
            ValueError: If `mel_batch` is not provided.
        """
        if mel_batch is None:
            raise ValueError("mel_batch must be provided")

        encoder_out = self.speech_encoder(mel_batch, return_dict=True).last_hidden_state

        # Fold the batch dimension into time: [B, T, D] -> [1, B*T, D].
        # Equivalent to concatenating the per-chunk sequences along time, but
        # done as a single reshape instead of a Python-level torch.cat loop.
        return encoder_out.reshape(1, -1, encoder_out.size(-1))
| |
|
| | class AudioNanoExpert(nn.Module): |
| | """ |
| | Mixture of Experts adaptor for audio features. |
| | |
| | This module transforms audio encoder outputs to match the LLM hidden dimension |
| | using a sparse mixture of experts architecture. |
| | |
| | Args: |
| | config: EurekaAudioConfig containing nano_expert settings |
| | """ |
| |
|
| | def __init__(self, config: EurekaAudioConfig): |
| | super().__init__() |
| | cfg = config.audio_config["nano_expert"] |
| |
|
| | self.input_dim = cfg["input_dim"] |
| | self.expert_dim = cfg["expert_dim"] |
| | self.num_experts = cfg["num_experts"] |
| | self.k = cfg["k"] |
| | self.num_shared = cfg.get("num_shared_experts", 2) |
| | |
| | |
| | self.backbone_hidden_size = config.llm_config.get("hidden_size", 2048) |
| | self.output_dim = self.backbone_hidden_size |
| | self.proj_hidden = cfg.get("proj_hidden", 2560) |
| |
|
| | |
| | self.proj = nn.Sequential( |
| | nn.Linear(self.output_dim, self.proj_hidden), |
| | nn.SiLU(), |
| | nn.Linear(self.proj_hidden, self.backbone_hidden_size), |
| | nn.RMSNorm(self.backbone_hidden_size) |
| | ) |
| |
|
| | assert self.k > 0 and self.num_experts > self.num_shared |
| |
|
| | |
| | self.w_gating = nn.Linear(self.input_dim, self.num_experts - self.num_shared) |
| |
|
| | |
| | self.experts = nn.ModuleList([ |
| | nn.Sequential( |
| | nn.RMSNorm(self.input_dim), |
| | nn.Linear(self.input_dim, self.expert_dim), |
| | nn.SiLU(), |
| | nn.Linear(self.expert_dim, self.output_dim), |
| | nn.RMSNorm(self.output_dim) |
| | ) for _ in range(self.num_experts) |
| | ]) |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Forward pass through MoE. |
| | |
| | Args: |
| | x: Input features [*, input_dim] |
| | |
| | Returns: |
| | Transformed features matching LLM hidden dimension |
| | """ |
| | flat_x = x.reshape(-1, x.shape[-1]) |
| | N = flat_x.shape[0] |
| |
|
| | |
| | logits = self.w_gating(flat_x) |
| | topk_vals, topk_idx = torch.topk(logits, self.k, dim=1) |
| | topk_scores = F.softmax(topk_vals, dim=1) |
| | topk_idx_shifted = topk_idx + self.num_shared |
| |
|
| | |
| | W_flat = torch.zeros(N, self.num_experts, device=flat_x.device, dtype=topk_scores.dtype) |
| | W_flat.scatter_(1, topk_idx_shifted, topk_scores) |
| |
|
| | |
| | dispatched = (W_flat.t().unsqueeze(-1) * flat_x.unsqueeze(0)) |
| | expert_out = torch.stack( |
| | [self.experts[e](dispatched[e]) for e in range(self.num_experts)], |
| | dim=0 |
| | ) |
| |
|
| | |
| | routed_out = (W_flat.unsqueeze(-1) * expert_out.permute(1, 0, 2)).sum(dim=1) |
| |
|
| | |
| | shared_out = sum(self.experts[e](flat_x) for e in range(self.num_shared)) |
| |
|
| | out = routed_out + shared_out |
| | out = out.view(-1, self.output_dim) |
| | out = self.proj(out) |
| | return out |
| |
|
| |
|
class EurekaAudioModel(PreTrainedModel):
    """
    Base Eureka-Audio model outputting raw hidden-states.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation
    for the generic methods the library implements for all its model.

    Args:
        config ([`EurekaAudioConfig`]): Model configuration class with all the parameters of the model.
    """

    config_class = EurekaAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep the encoder and the MoE adaptor each on a single device when the
    # model is sharded across devices.
    _no_split_modules = ["WhisperEncoder", "AudioNanoExpert"]

    def __init__(self, config: EurekaAudioConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config

        # Causal-LM text backbone built from config.llm_config.
        self.backbone = self._build_llm_backbone()

        # Whisper encoder producing continuous audio features.
        self.audio_encoder = self._build_audio_encoder()

        # MoE adaptor mapping audio features to the LLM embedding space.
        # deepcopy guards the shared config object against mutation.
        self.audio_moe_adaptor = AudioNanoExpert(deepcopy(config))

    def _build_llm_backbone(self) -> nn.Module:
        """Build LLM backbone from config.

        Returns a randomly initialized causal LM (weights come later via the
        usual checkpoint-loading path).
        """
        llm_config = self.config.llm_config

        # llm_config is a plain dict; AutoConfig.for_model dispatches on its
        # "model_type" entry.
        config_obj = AutoConfig.for_model(**llm_config)

        # flash-attention-2 + bf16: assumes a compatible GPU at runtime.
        backbone = AutoModelForCausalLM.from_config(
            config_obj,
            attn_implementation="flash_attention_2",
        ).to(torch.bfloat16)
        return backbone

    def _build_audio_encoder(self) -> nn.Module:
        """Build Whisper audio encoder (kept in bf16 to match the backbone)."""
        audio_encoder_config = self.config.audio_encoder_config
        audio_encoder = WhisperEncoder(config=audio_encoder_config)
        return audio_encoder.to(torch.bfloat16)

    def get_input_embeddings(self):
        # The token embedding table lives on the backbone's inner decoder.
        return self.backbone.model.embed_tokens

    def set_input_embeddings(self, value):
        self.backbone.model.embed_tokens = value

    def _audio_embedding_forward(
        self,
        token_type_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        continuous_audio_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Inject audio features into input embeddings.

        Args:
            token_type_ids: Token type IDs indicating audio positions
            inputs_embeds: Text embeddings from backbone (mutated in place)
            continuous_audio_features: Audio features from Whisper encoder

        Returns:
            Modified embeddings with audio features injected

        Note:
            The number of audio-typed positions must equal the number of
            4-frame-stacked feature rows, or the masked assignment fails.
        """
        understand_mask = token_type_ids == TokenType.audio

        b, s, d = continuous_audio_features.shape
        assert s % 4 == 0, "continuous_audio_features frames must be divisible by 4"

        # Stack every 4 consecutive frames into one feature row:
        # [b, s, d] -> [b, s/4, 4d], i.e. a 4x temporal downsampling so one
        # row corresponds to one audio token in the sequence.
        continuous_audio_features = continuous_audio_features.view(b, s // 4, d * 4)
        if continuous_audio_features.size(0) == 1:
            continuous_audio_features = continuous_audio_features.squeeze(0)

        # Project the stacked frames to the LLM hidden size via the MoE
        # adaptor (cast to the embedding dtype first, e.g. bf16).
        exp_feat = self.audio_moe_adaptor(
            continuous_audio_features.to(inputs_embeds.dtype)
        )
        # Overwrite the placeholder embeddings at audio positions; this
        # mutates inputs_embeds in place.
        inputs_embeds[understand_mask] = exp_feat.to(inputs_embeds.dtype)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mel_batch_list: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Forward pass of the base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            position_ids: Position IDs
            past_key_values: Past key values for caching
            inputs_embeds: Pre-computed input embeddings
            use_cache: Whether to use caching
            output_attentions: Whether to output attentions
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return a dict
            token_type_ids: Token type IDs (text=0, audio=3)
            mel_batch_list: Mel spectrogram batch [B, 128, 3000]

        Returns:
            Model outputs with hidden states
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # token_type_ids may carry one extra trailing position (appended per
        # decode step by prepare_inputs_for_generation); trim it so it lines
        # up with input_ids.
        if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1:
            token_type_ids_inputs = token_type_ids[..., :-1]
        else:
            token_type_ids_inputs = token_type_ids

        # Text path: embed tokens unless embeddings were supplied.
        if inputs_embeds is None:
            inputs_embeds = self.backbone.model.embed_tokens(input_ids)

        # Audio path: encode mels and splice the features into the sequence.
        if mel_batch_list is not None and token_type_ids_inputs is not None:
            continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)

            # Keep only as many encoder frames as there are audio tokens
            # (4 frames per audio token); drops the padding introduced by
            # fixed-size mel chunks.
            real_frames = (token_type_ids_inputs == TokenType.audio).sum()
            continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]

            inputs_embeds = self._audio_embedding_forward(
                token_type_ids_inputs,
                inputs_embeds,
                continuous_audio_features,
            )

        # NOTE(review): hidden states are always requested here, regardless
        # of the output_hidden_states value resolved above.
        outputs = self.backbone.model(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )

        return outputs
| |
|
| |
|
class EurekaAudioForCausalLM(EurekaAudioModel, GenerationMixin):
    """
    Eureka-Audio Model with a language modeling head for causal LM.

    This model supports both text-only generation and audio understanding tasks.

    Example:
        ```python
        >>> from transformers import AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained(
        ...     "cslys1999/Eureka-Audio-Instruct",
        ...     trust_remote_code=True
        ... )
        ```
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: EurekaAudioConfig, **kwargs):
        super().__init__(config, **kwargs)

    def get_output_embeddings(self):
        return self.backbone.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.backbone.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ):
        """Prepare inputs for generation step.

        Extends token_type_ids by one text-typed (0) position so it stays
        aligned with input_ids, which grows by one token per decode step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            **kwargs,
        )

        # FIX: token_type_ids is only present for multimodal calls; plain
        # text-only generation previously crashed here with a KeyError.
        token_type_ids = model_inputs.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = torch.cat([
                token_type_ids,
                torch.zeros((token_type_ids.shape[0], 1),
                            dtype=token_type_ids.dtype,
                            device=token_type_ids.device),
            ], dim=-1)
            model_inputs['token_type_ids'] = token_type_ids

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs,
        is_encoder_decoder: bool = False,
    ):
        """Update model kwargs for next generation step."""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
        )
        # Audio is consumed during prefill only; clear it so subsequent
        # decode steps do not re-encode / re-inject the features.
        model_kwargs['audio_input_ids'] = None
        model_kwargs['mel_batch_list'] = None
        return model_kwargs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mel_batch_list: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for causal language modeling.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            position_ids: Position IDs
            past_key_values: Past key values for caching
            inputs_embeds: Pre-computed input embeddings
            labels: Labels for computing the language modeling loss
            use_cache: Whether to use caching
            output_attentions: Whether to output attentions
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return a dict
            token_type_ids: Token type IDs (text=0, audio=3)
            mel_batch_list: Mel spectrogram batch [num_chunks, 128, 3000]

        Returns:
            CausalLMOutputWithPast with loss (if labels provided), logits, past_key_values,
            hidden_states, and attentions.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # token_type_ids may carry one extra trailing position (appended per
        # decode step by prepare_inputs_for_generation); trim it so it lines
        # up with input_ids.
        if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1:
            token_type_ids_inputs = token_type_ids[..., :-1]
        else:
            token_type_ids_inputs = token_type_ids

        # FIX: honor a caller-provided inputs_embeds instead of silently
        # recomputing it (the parameter was previously accepted but ignored);
        # this also matches EurekaAudioModel.forward.
        if inputs_embeds is None:
            inputs_embeds = self.backbone.model.embed_tokens(input_ids)

        # Audio path: encode mels and splice the features into the sequence.
        if mel_batch_list is not None and token_type_ids is not None:
            continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)

            # Count audio positions on the trimmed ids, consistent with the
            # base model. (The appended generation position is text-typed, so
            # the count is unchanged either way.)
            real_frames = (token_type_ids_inputs == TokenType.audio).sum()
            continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]

            inputs_embeds = self._audio_embedding_forward(
                token_type_ids_inputs,
                inputs_embeds,
                continuous_audio_features,
            )

        # NOTE(review): this invokes the full causal-LM backbone (which
        # already applies lm_head internally) and then re-applies lm_head on
        # the last hidden state below; functionally correct but redundant.
        outputs = self.backbone(
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
        )

        hidden_states = outputs.hidden_states[-1]
        logits = self.backbone.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
| | |
# Register the classes with the Auto* machinery so that
# AutoConfig / AutoModel / AutoModelForCausalLM can resolve them when the
# checkpoint is loaded with trust_remote_code=True.
EurekaAudioConfig.register_for_auto_class()
EurekaAudioModel.register_for_auto_class("AutoModel")
EurekaAudioForCausalLM.register_for_auto_class("AutoModelForCausalLM")
| |
|