| import torch |
| import inspect |
| import importlib |
|
|
| from typing import Callable, Optional, Union, Any, List |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| from transformers.cache_utils import Cache |
| from transformers.processing_utils import Unpack |
|
|
| from .sep_cache_utils import SepCache |
|
|
|
|
|
|
def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Return the trailing slice of `input_ids` aligned with the key sequence.

    `key_states` is laid out as (..., seq_len, head_dim). During autoregressive
    decoding the caller may pass the full prompt ids while only the newest
    tokens' key states are produced; in that case keep only the most recent
    `seq_len` token ids so both tensors line up one-to-one.
    """
    kv_len = key_states.shape[-2]
    ids_len = input_ids.shape[-1]
    if ids_len == kv_len:
        # Already aligned — hand back the original tensor untouched.
        return input_ids
    # There must be at least as many token ids as key positions to slice from.
    assert ids_len >= kv_len
    return input_ids[..., -kv_len:]
|
|
def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    """Llama attention forward pass patched to route the KV cache through `SepCache`.

    Args:
        hidden_states: Input activations of shape (batch, seq_len, hidden_size).
        position_embeddings: `(cos, sin)` rotary tables for the current positions.
        attention_mask: Optional attention mask; it is clipped to the usable cache
            length after the `SepCache` update, since SepCache may evict entries.
        past_key_value: Must be a `SepCache` instance; its `APPLY_PE_SHIFT` /
            `APPLY_PES_INSIDE` flags select where rotary embeddings are applied.
        cache_position: Cache position indices forwarded to `SepCache.update`.
        **kwargs: Must carry `input_ids` (directly or nested in `sepllm_kwargs`)
            and `position_ids` for `SepCache.update`; everything else is passed
            through to the attention implementation.

    Returns:
        `(attn_output, attn_weights)` from the selected attention implementation.
    """
    input_shape = hidden_states.shape[:-1]

    # Newer HF attention modules expose `head_dim`; some variants call it `head_size`.
    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size
    else:
        # The original fell through to an opaque NameError below; fail loudly instead.
        raise AttributeError("The attention module must define either `head_dim` or `head_size`.")

    hidden_shape = (*input_shape, -1, head_dim)

    # Project to per-head Q/K/V laid out as (batch, num_heads, seq_len, head_dim).
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    assert isinstance(past_key_value, SepCache), f"`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE

    # Resolve the rotary/attention helpers from the module that defines this
    # attention class, so the patch tracks the installed transformers version.
    module = importlib.import_module(self.__module__)
    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS

    cos, sin = position_embeddings
    if not APPLY_PE_SHIFT:
        # Standard RoPE path: rotate Q/K here. Under PE shift, SepCache handles
        # the (shifted) rotary application itself.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            # BUGFIX: the original referenced `cos_q`/`sin_q` that were never
            # defined (and `cos`/`sin` were only bound on the non-shift path),
            # raising NameError here. Default the query tables to the key tables.
            # NOTE(review): confirm against `SepCache.update()` whether separately
            # shifted query tables are expected on this path.
            cos_q, sin_q = cos, sin
            cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos_q, "sin_q": sin_q, "cache_position": cache_position, "partial_rotation_size": None}
        else:
            cache_kwargs = {}

        # `kwargs` is a parameter of this function, so the original probing of
        # `locals()` (with a `flash_attn_kwargs` fallback) was unreachable dead
        # code. `input_ids` comes either directly from `kwargs` or nested inside
        # `sepllm_kwargs`.
        if "input_ids" in kwargs:
            input_ids = kwargs.get("input_ids", None)
        else:
            sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
            assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given."
            input_ids = sepllm_kwargs.get("input_ids", None)
        assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."

        position_ids = kwargs.get("position_ids")
        q_len = hidden_states.shape[1]

        # Align `input_ids` with the freshly projected key states (autoregressive
        # decoding can pass the full prompt ids but only the new tokens' states).
        input_ids = truncate_input_ids_4_autoregression(input_ids=input_ids, key_states=key_states)

        if APPLY_PE_SHIFT:
            # On the PE-shift path SepCache also returns (re-rotated) query states.
            key_states, value_states, query_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                query_states=query_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs,
            )
        else:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs,
            )

        # SepCache may keep fewer entries than the raw sequence length; clip the
        # mask to the usable cache length so shapes agree with key/value states.
        seq_len = past_key_value.get_usable_length(self.layer_idx)
        if attention_mask is not None:
            attention_mask = attention_mask[..., :seq_len]

    # Pick the configured attention kernel (eager / sdpa / flash-attn / ...).
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Validates model kwargs for generation. Generate argument typos will also be caught here."""
    # Reject a `Cache` instance when the model does not support the Cache API.
    past = model_kwargs.get("past_key_values", None)
    if isinstance(past, Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    # Encoder-decoder models consume `decoder_input_ids` internally; exempt it.
    if self.config.is_encoder_decoder:
        model_kwargs.pop("decoder_input_ids", None)

    # Start from the kwargs accepted by `prepare_inputs_for_generation`.
    accepted_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)

    # A catch-all parameter means `forward`'s own signature is reachable too.
    if not accepted_args.isdisjoint({"kwargs", "model_kwargs"}):
        accepted_args |= set(inspect.signature(self.forward).parameters)

    if self.config.is_encoder_decoder:
        backbone = getattr(self, self.base_model_prefix, None)

        # The encoder may live on the model itself or on its backbone.
        encoder_module = getattr(self, "encoder", None)
        if encoder_module is None and backbone is not None:
            encoder_module = getattr(backbone, "encoder", None)
        if encoder_module is not None:
            accepted_args |= set(inspect.signature(encoder_module.forward).parameters)

        # Likewise for the decoder; its kwargs arrive with a `decoder_` prefix.
        decoder_module = getattr(self, "decoder", None)
        if decoder_module is None and backbone is not None:
            decoder_module = getattr(backbone, "decoder", None)
        if decoder_module is not None:
            accepted_args |= {f"decoder_{name}" for name in inspect.signature(decoder_module.forward).parameters}

    # Anything non-None that no signature accepts is unused — except
    # SepLLM-specific kwargs (any key containing "sep"), which are whitelisted.
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if (value is not None) and (key not in accepted_args) and ("sep" not in str(key).lower())
    ]

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
|
|
|
|
|
|