from typing import Optional, List, Tuple, Dict, Any

import torch
from transformers.cache_utils import DynamicCache


class DiffusionDynamicCache(DynamicCache):
    """DynamicCache with helpers that update or subset every layer at once.

    The helpers operate directly on the per-layer ``key_cache`` and
    ``value_cache`` lists provided by ``DynamicCache``.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None):
        """Create the cache, forwarding ``num_hidden_layers`` to the parent."""
        super().__init__(num_hidden_layers)

    def full_update(
        self,
        new_kv: Tuple,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Extend every layer's cached keys/values with ``new_kv``.

        Args:
            new_kv: Iterable of ``(key, value)`` pairs, one per layer, in
                layer order. Each tensor is concatenated to the cached one
                along ``dim=-2`` (presumably the sequence axis -- TODO
                confirm against callers).
            cache_kwargs: Accepted for interface compatibility; not used.

        Layers with no cached tensor yet are initialised from ``new_kv``
        instead of failing, so this can be called on a fresh cache.
        """
        for layer_idx, (key, val) in enumerate(new_kv):
            if layer_idx >= len(self.key_cache):
                # No slot for this layer yet (e.g. first call on an empty
                # cache): start the layer with the new states.
                self.key_cache.append(key)
                self.value_cache.append(val)
                continue

            cached_key = self.key_cache[layer_idx]
            if not torch.is_tensor(cached_key):
                # Slot holds a non-tensor placeholder (e.g. an empty list),
                # which torch.cat cannot consume; replace it outright.
                self.key_cache[layer_idx] = key
                self.value_cache[layer_idx] = val
                continue

            self.key_cache[layer_idx] = torch.cat([cached_key, key], dim=-2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], val], dim=-2
            )

    def select_partial(
        self,
        indices: torch.Tensor,
    ):
        """Keep only the entries at ``indices`` along axis 2 of every layer.

        Each cached key/value is indexed as ``t[:, :, indices, :]``, so the
        cached tensors are expected to be 4-dimensional.
        """
        layers = zip(self.key_cache, self.value_cache)
        # Assign element-wise so the cache lists keep their identity.
        for layer_idx, (cached_key, cached_val) in enumerate(layers):
            self.key_cache[layer_idx] = cached_key[:, :, indices, :]
            self.value_cache[layer_idx] = cached_val[:, :, indices, :]

    def batch_select_minibatch(self, indices: torch.Tensor):
        """Keep only the first ``indices`` entries along the leading dimension.

        ``indices`` is used as a slice bound (``t[:indices, ...]``), so it
        must be usable as an integer index (e.g. an ``int`` or a
        single-element integer tensor).

        NOTE(review): despite its name, ``indices`` is not treated as a list
        of positions to keep -- confirm callers pass a count.
        """
        # len(self) is provided by the parent cache class.
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:indices, ...]