# coding=utf-8
# Copyright 2026 ERNIE Team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Eureka-Audio model."""
import os
import logging
from copy import deepcopy
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
GenerationMixin,
AutoConfig,
AutoModelForCausalLM,
)
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperEncoder as TransformersWhisperEncoder
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging as transformers_logging
from .configuration_eureka_audio import EurekaAudioConfig
logger = transformers_logging.get_logger(__name__)
class TokenType:
"""Token type identifiers for multimodal inputs."""
text = 0
audio = 3
class WhisperEncoder(nn.Module):
"""
Whisper-based audio encoder for extracting audio features.
Args:
config: Whisper configuration dictionary
"""
def __init__(self, config: dict):
super().__init__()
whisper_config = WhisperConfig(**config)
whisper_config._attn_implementation = 'flash_attention_2'
self.speech_encoder = TransformersWhisperEncoder(whisper_config)
def forward(
self,
mel_batch: torch.Tensor = None,
) -> torch.Tensor:
"""
Encode mel spectrogram to audio features.
Args:
mel_batch: Precomputed mel spectrogram [B, 128, 3000]
Returns:
Audio features [1, T', D] where T' = B * 1500 and D = d_model
"""
if mel_batch is None:
raise ValueError("mel_batch must be provided")
encoder_out = self.speech_encoder(mel_batch, return_dict=True).last_hidden_state
# Concatenate all chunks into a single sequence along the time axis
final_audio_embedding = torch.cat([x for x in encoder_out], dim=0).unsqueeze(0)
return final_audio_embedding
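# Shape sketch for the encoder output (illustrative; assumes Whisper's standard
# 30 s windows, i.e. 3000 mel frames per chunk and 1500 encoder frames after
# the 2x convolutional downsampling):
#
#   mel_batch:          [B, 128, 3000]    # B mel-spectrogram chunks
#   last_hidden_state:  [B, 1500, D]      # D = whisper d_model
#   returned tensor:    [1, B * 1500, D]  # chunks concatenated along time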
class AudioNanoExpert(nn.Module):
"""
Mixture of Experts adaptor for audio features.
This module transforms audio encoder outputs to match the LLM hidden dimension
using a sparse mixture of experts architecture.
Args:
config: EurekaAudioConfig containing nano_expert settings
"""
def __init__(self, config: EurekaAudioConfig):
super().__init__()
cfg = config.audio_config["nano_expert"]
self.input_dim = cfg["input_dim"]
self.expert_dim = cfg["expert_dim"]
self.num_experts = cfg["num_experts"]
self.k = cfg["k"]
self.num_shared = cfg.get("num_shared_experts", 2)
# Expert output dimension should match backbone hidden_size (2048)
# The out_dim in config (1280) is actually the expert intermediate dim
self.backbone_hidden_size = config.llm_config.get("hidden_size", 2048)
self.output_dim = self.backbone_hidden_size
self.proj_hidden = cfg.get("proj_hidden", 2560)
# Output projection: Linear(2048->2560) -> SiLU -> Linear(2560->2048) -> RMSNorm
self.proj = nn.Sequential(
nn.Linear(self.output_dim, self.proj_hidden),
nn.SiLU(),
nn.Linear(self.proj_hidden, self.backbone_hidden_size),
nn.RMSNorm(self.backbone_hidden_size)
)
assert self.k > 0 and self.num_experts > self.num_shared
# Gating network for routing
self.w_gating = nn.Linear(self.input_dim, self.num_experts - self.num_shared)
# Expert networks: RMSNorm(5120) -> Linear(5120->1280) -> SiLU -> Linear(1280->2048) -> RMSNorm(2048)
self.experts = nn.ModuleList([
nn.Sequential(
nn.RMSNorm(self.input_dim),
nn.Linear(self.input_dim, self.expert_dim),
nn.SiLU(),
nn.Linear(self.expert_dim, self.output_dim),
nn.RMSNorm(self.output_dim)
) for _ in range(self.num_experts)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through MoE.
Args:
x: Input features [*, input_dim]
Returns:
Transformed features matching LLM hidden dimension
"""
flat_x = x.reshape(-1, x.shape[-1])
N = flat_x.shape[0]
# Compute gating scores
logits = self.w_gating(flat_x)
topk_vals, topk_idx = torch.topk(logits, self.k, dim=1)
topk_scores = F.softmax(topk_vals, dim=1)
topk_idx_shifted = topk_idx + self.num_shared
# Build routing weights
W_flat = torch.zeros(N, self.num_experts, device=flat_x.device, dtype=topk_scores.dtype)
W_flat.scatter_(1, topk_idx_shifted, topk_scores)
# Dispatch to experts
dispatched = (W_flat.t().unsqueeze(-1) * flat_x.unsqueeze(0))
expert_out = torch.stack(
[self.experts[e](dispatched[e]) for e in range(self.num_experts)],
dim=0
)
# Combine routed expert outputs
routed_out = (W_flat.unsqueeze(-1) * expert_out.permute(1, 0, 2)).sum(dim=1)
# Add shared expert outputs
shared_out = sum(self.experts[e](flat_x) for e in range(self.num_shared))
out = routed_out + shared_out
out = out.view(-1, self.output_dim)
out = self.proj(out)
return out
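# Routing sketch (illustrative, with hypothetical sizes num_experts=8,
# num_shared=2, k=2, i.e. 6 routed experts):
#
#   logits = self.w_gating(flat_x)           # [N, 6], one score per routed expert
#   vals, idx = torch.topk(logits, 2, dim=1)
#   scores = F.softmax(vals, dim=1)           # normalize over the selected k only
#   idx = idx + 2                             # shift indices past the shared experts
#
# Experts 0..num_shared-1 are "shared": every token passes through them and
# their outputs are summed in unconditionally; only the remaining experts are
# gated. The final self.proj maps the combined output back to the backbone
# hidden size.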
class EurekaAudioModel(PreTrainedModel):
"""
Base Eureka-Audio model outputting raw hidden-states.
This model inherits from [`PreTrainedModel`]. Check the superclass documentation
for the generic methods the library implements for all its models.
Args:
config ([`EurekaAudioConfig`]): Model configuration class with all the parameters of the model.
"""
config_class = EurekaAudioConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoder", "AudioNanoExpert"]
def __init__(self, config: EurekaAudioConfig, **kwargs):
super().__init__(config, **kwargs)
self.config = config
# Build LLM backbone
self.backbone = self._build_llm_backbone()
# Build audio encoder
self.audio_encoder = self._build_audio_encoder()
# Build audio adaptor
self.audio_moe_adaptor = AudioNanoExpert(deepcopy(config))
def _build_llm_backbone(self) -> nn.Module:
"""Build LLM backbone from config."""
llm_config = self.config.llm_config
# Create config directly from dict
config_obj = AutoConfig.for_model(**llm_config)
# Create model with bfloat16 dtype to support flash_attention_2
backbone = AutoModelForCausalLM.from_config(
config_obj,
attn_implementation="flash_attention_2",
).to(torch.bfloat16)
return backbone
def _build_audio_encoder(self) -> nn.Module:
"""Build Whisper audio encoder."""
audio_encoder_config = self.config.audio_encoder_config
audio_encoder = WhisperEncoder(config=audio_encoder_config)
return audio_encoder.to(torch.bfloat16)
def get_input_embeddings(self):
return self.backbone.model.embed_tokens
def set_input_embeddings(self, value):
self.backbone.model.embed_tokens = value
def _audio_embedding_forward(
self,
token_type_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
continuous_audio_features: torch.Tensor,
) -> torch.Tensor:
"""
Inject audio features into input embeddings.
Args:
token_type_ids: Token type IDs indicating audio positions
inputs_embeds: Text embeddings from backbone
continuous_audio_features: Audio features from Whisper encoder
Returns:
Modified embeddings with audio features injected
"""
understand_mask = token_type_ids == TokenType.audio
b, s, d = continuous_audio_features.shape
assert s % 4 == 0, "continuous_audio_features frames must be divisible by 4"
# Downsample: 4 encoder frames -> 1 audio token
continuous_audio_features = continuous_audio_features.view(b, s // 4, d * 4)
if continuous_audio_features.size(0) == 1:
continuous_audio_features = continuous_audio_features.squeeze(0)
# Transform through MoE adaptor
exp_feat = self.audio_moe_adaptor(
continuous_audio_features.to(inputs_embeds.dtype)
)
inputs_embeds[understand_mask] = exp_feat.to(inputs_embeds.dtype)
return inputs_embeds
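# Frame-stacking sketch (illustrative): four consecutive encoder frames are
# stacked along the feature dimension to form one audio "token" before the
# MoE adaptor, so the number of audio-typed positions in token_type_ids is
# expected to equal s // 4.
#
#   [1, s, d]  -> view ->  [1, s // 4, 4 * d]
#   e.g. [1, 1500, 1280]  ->  [1, 375, 5120]
#
# The adaptor output then overwrites inputs_embeds at every position where
# token_type_ids == TokenType.audio.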
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_type_ids: Optional[torch.Tensor] = None,
mel_batch_list: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Forward pass of the base model.
Args:
input_ids: Input token IDs
attention_mask: Attention mask
position_ids: Position IDs
past_key_values: Past key values for caching
inputs_embeds: Pre-computed input embeddings
use_cache: Whether to use caching
output_attentions: Whether to output attentions
output_hidden_states: Whether to output hidden states
return_dict: Whether to return a dict
token_type_ids: Token type IDs (text=0, audio=3)
mel_batch_list: Mel spectrogram batch [B, 128, 3000]
Returns:
Model outputs with hidden states
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Handle token_type_ids shape
if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1:
token_type_ids_inputs = token_type_ids[..., :-1]
else:
token_type_ids_inputs = token_type_ids
# Get text embeddings
if inputs_embeds is None:
inputs_embeds = self.backbone.model.embed_tokens(input_ids)
# Process audio features (only when mel_batch_list is provided)
if mel_batch_list is not None and token_type_ids_inputs is not None:
continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)
# Trim to actual audio frame count
real_frames = (token_type_ids_inputs == TokenType.audio).sum()
continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]
# Inject audio into embeddings
inputs_embeds = self._audio_embedding_forward(
token_type_ids_inputs,
inputs_embeds,
continuous_audio_features,
)
# Forward through backbone
outputs = self.backbone.model(
position_ids=position_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
use_cache=use_cache,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=True,
)
return outputs
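# Minimal call sketch for the base model (hypothetical tensors; assumes the
# repository's preprocessing has already produced input_ids, token_type_ids
# and the mel chunks):
#
#   outputs = model(
#       input_ids=input_ids,             # [1, seq_len]
#       token_type_ids=token_type_ids,   # [1, seq_len], 0 = text, 3 = audio
#       attention_mask=attention_mask,   # [1, seq_len]
#       mel_batch_list=mel_batch,        # [num_chunks, 128, 3000]
#   )
#   hidden = outputs.hidden_states[-1]   # [1, seq_len, hidden_size]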
class EurekaAudioForCausalLM(EurekaAudioModel, GenerationMixin):
"""
Eureka-Audio Model with a language modeling head for causal LM.
This model supports both text-only generation and audio understanding tasks.
Example:
```python
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained(
... "cslys1999/Eureka-Audio-Instruct",
... trust_remote_code=True
... )
```
"""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: EurekaAudioConfig, **kwargs):
super().__init__(config, **kwargs)
def get_output_embeddings(self):
return self.backbone.lm_head
def set_output_embeddings(self, new_embeddings):
self.backbone.lm_head = new_embeddings
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
**kwargs,
):
"""Prepare inputs for generation step."""
model_inputs = super().prepare_inputs_for_generation(
input_ids,
**kwargs,
)
# Extend token_type_ids with one text-type slot for the newly generated token.
# Read it from model_inputs (already updated by the parent), not from kwargs.
token_type_ids = model_inputs.get('token_type_ids')
if token_type_ids is not None:
    token_type_ids = torch.cat([
        token_type_ids,
        torch.zeros((token_type_ids.shape[0], 1),
                    dtype=token_type_ids.dtype,
                    device=token_type_ids.device),
    ], dim=-1)
    model_inputs['token_type_ids'] = token_type_ids
return model_inputs
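# Illustrative trace: the cached sequence grows by one token per decoding step,
# so token_type_ids is padded with a single 0 (text) each call; generated
# tokens are always treated as text.
#
#   step 0: token_type_ids = [0, 0, 3, 3, 3, 0]      (prompt with an audio span)
#   step 1: token_type_ids = [0, 0, 3, 3, 3, 0, 0]   (one text slot appended)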
def _update_model_kwargs_for_generation(
self,
outputs,
model_kwargs,
is_encoder_decoder: bool = False,
):
"""Update model kwargs for next generation step."""
model_kwargs = super()._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=is_encoder_decoder,
)
# Clear audio_input_ids and mel_batch_list after first forward pass
model_kwargs['audio_input_ids'] = None
model_kwargs['mel_batch_list'] = None
return model_kwargs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_type_ids: Optional[torch.Tensor] = None,
mel_batch_list: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass for causal language modeling.
Args:
input_ids: Input token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
position_ids: Position IDs
past_key_values: Past key values for caching
inputs_embeds: Pre-computed input embeddings
labels: Labels for computing the language modeling loss
use_cache: Whether to use caching
output_attentions: Whether to output attentions
output_hidden_states: Whether to output hidden states
return_dict: Whether to return a dict
token_type_ids: Token type IDs (text=0, audio=3)
mel_batch_list: Mel spectrogram batch [num_chunks, 128, 3000]
Returns:
CausalLMOutputWithPast with loss (if labels provided), logits, past_key_values,
hidden_states, and attentions.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Handle token_type_ids shape
# When token_type_ids.shape[-1] == input_ids.shape[-1] + 1, slice it
# Otherwise use it as is (for compatibility with different calling patterns)
if token_type_ids is not None and token_type_ids.shape[-1] == input_ids.shape[-1] + 1:
token_type_ids_inputs = token_type_ids[..., :-1]
else:
token_type_ids_inputs = token_type_ids
# Get text embeddings (unless pre-computed embeddings were passed in)
if inputs_embeds is None:
    inputs_embeds = self.backbone.model.embed_tokens(input_ids)
# Process audio features (only on first forward pass when mel_batch_list is provided)
if mel_batch_list is not None and token_type_ids is not None:
continuous_audio_features = self.audio_encoder(mel_batch=mel_batch_list)
# Use full token_type_ids for real_frames calculation
real_frames = (token_type_ids == TokenType.audio).sum()
continuous_audio_features = continuous_audio_features[:, :real_frames * 4, :]
inputs_embeds = self._audio_embedding_forward(
token_type_ids_inputs,
inputs_embeds,
continuous_audio_features,
)
# Forward through backbone
outputs = self.backbone(
position_ids=position_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
use_cache=use_cache,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=True,
)
hidden_states = outputs.hidden_states[-1]
logits = self.backbone.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift for next token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Register the config and model classes for Auto* loading with trust_remote_code
EurekaAudioConfig.register_for_auto_class()
EurekaAudioModel.register_for_auto_class("AutoModel")
EurekaAudioForCausalLM.register_for_auto_class("AutoModelForCausalLM")
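# End-to-end usage sketch (kept as comments so the module stays import-safe).
# The prompt construction, token_type_ids layout and mel-chunk extraction are
# schematic assumptions and depend on the repository's tokenizer/processor;
# only the keyword names match this file's forward signature.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "cslys1999/Eureka-Audio-Instruct",
#       trust_remote_code=True,
#       torch_dtype=torch.bfloat16,
#   ).cuda()
#   tokenizer = AutoTokenizer.from_pretrained(
#       "cslys1999/Eureka-Audio-Instruct", trust_remote_code=True
#   )
#
#   out_ids = model.generate(
#       input_ids=input_ids,             # [1, seq_len] prompt ids
#       token_type_ids=token_type_ids,   # [1, seq_len], audio positions = 3
#       mel_batch_list=mel_batch,        # [num_chunks, 128, 3000]
#       max_new_tokens=256,
#   )
#   print(tokenizer.decode(out_ids[0], skip_special_tokens=True))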