import math
from typing import List, Optional

import torch
from torch import nn

def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate `num_bands` frequency bands for coordinates normalized to pixel space."""
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
    else:
        # log-spaced bands from 2^0 to 2^(log2(max_freq) - 1)
        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
    return bands * torch.pi

def inv_freq_bands(
        num_bands: int,
        temperature: float = 100000.,
        step: int = 2,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate inverse-frequency bands 1 / temperature^(i / num_bands) for i in range(0, num_bands, step)."""
    inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
    return inv_freq
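
# Usage sketch (comments only, not part of the original module):
#
#   pixel_freq_bands(8)        # 8 linear bands in [pi, 112 * pi], for grids normalized to [-1, 1]
#   inv_freq_bands(8, step=1)  # 8 bands of 1 / 100000^(i / 8), for integer grid positions
#
# `pixel_freq_bands` scales frequencies with image resolution, while
# `inv_freq_bands` follows the transformer-style 1 / T^(i/d) schedule.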

def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a 2D sin-cos position embedding for a flattened feature grid.

    Args:
        feat_shape: spatial shape of the feature map, e.g. (H, W)
        dim: embedding dimension, must be divisible by 4
        temperature: frequency temperature for the inverse-frequency bands
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        dtype: output dtype
        device: output device

    Returns:
        Position embedding of shape (N, dim) where N is the number of grid positions.
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    pos_dim = dim // 4
    bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = torch.stack(torch.meshgrid(
        [torch.arange(s, device=device, dtype=dtype) for s in feat_shape],
        indexing='ij')).flatten(1).transpose(0, 1)
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?
    stack_dim = 2 if interleave_sin_cos else 1  # stack sin, cos, sin, cos instead of sin sin cos cos
    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
    return pos_emb
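
# Usage sketch (comments only, not part of the original module): one row per
# flattened grid position, `dim` channels per row.
#
#   pos_emb = build_sincos2d_pos_embed([7, 7], dim=64)  # shape (49, 64)
#
# With interleave_sin_cos=True the channel order per frequency is
# sin, cos, sin, cos, ... instead of all sins followed by all cosines.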

def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        linear_bands: bool = False,
        include_grid: bool = False,
        concat_out: bool = True,
        in_pixels: bool = True,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """Build sin/cos Fourier features over a spatial grid, optionally including the raw grid."""
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
        else:
            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        # normalized pixel coordinates in [-1, 1]
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape],
            indexing='ij'), dim=-1)
    else:
        # integer grid coordinates
        grid = torch.stack(torch.meshgrid(
            [torch.arange(s, device=device, dtype=dtype) for s in feat_shape],
            indexing='ij'), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin(), pos.cos()
    out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
    # FIXME torchscript doesn't like multiple return types, probably need to always cat?
    if concat_out:
        out = torch.cat(out, dim=-1)
    return out
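
# Usage sketch (comments only, not part of the original module): with the
# defaults, sin and cos features are concatenated along the last dim and the
# spatial dims are kept.
#
#   emb = build_fourier_pos_embed([7, 7], num_bands=64)
#   # emb.shape == (7, 7, 2, 128): H x W x num_spatial_dims x (2 * num_bands)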

class FourierEmbed(nn.Module):

    def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        # pixel_freq_bands takes (num_bands, max_freq), in that order
        self.register_buffer('bands', pixel_freq_bands(num_bands, float(max_res)), persistent=False)

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x
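
# Usage sketch (comments only, not part of the original module), assuming an
# NCHW input; with num_bands=64 and concat_grid=True each position gains
# 2 * 2 * 64 + 2 = 258 Fourier channels (sin + cos per spatial dim, plus grid):
#
#   fe = FourierEmbed(max_res=224, num_bands=64)
#   y = fe(torch.randn(2, 32, 7, 7))
#   # y.shape == (2, 49, 290) with keep_spatial=False (flattened tokens)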

def rot(x):
    # rotate channel pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_split(x: torch.Tensor, emb):
    # emb holds the two halves of a concatenated sin/cos table
    split = emb.shape[-1] // 2
    return x * emb[:, :split] + rot(x) * emb[:, split:]
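
# Sanity sketch (comments only, not part of the original module): `rot` applies
# a 90-degree rotation within each adjacent channel pair, so `apply_rot_embed`
# computes the standard RoPE rotation x * cos(theta) + rot(x) * sin(theta):
#
#   rot(torch.tensor([1., 2., 3., 4.]))  # tensor([-2., 1., -4., 3.])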

def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_freq: float = 224,
        linear_bands: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    """Build sin/cos tables for rotary embedding over a spatial grid.

    NOTE: shape arg should include spatial dim only
    """
    feat_shape = torch.Size(feat_shape)

    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
        concat_out=False, device=device, dtype=dtype)
    N = feat_shape.numel()
    # repeat each frequency for the (even, odd) channel pair rotated together by `rot`
    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
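
# Usage sketch (comments only, not part of the original module): one row per
# flattened position, `dim` columns, ready to broadcast against token tensors.
#
#   sin_emb, cos_emb = build_rotary_pos_embed([7, 7], dim=64)
#   # sin_emb.shape == cos_emb.shape == (49, 64)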

class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at implementing rotary embedding for spatial use.
    It has not been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_res=224, linear_bands: bool = False):
        super().__init__()
        self.dim = dim
        self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)

    def get_embed(self, shape: List[int]):
        return build_rotary_pos_embed(shape, self.bands)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
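
# Usage sketch (comments only, not part of the original module): in attention
# layers the tables are typically fetched once per resolution and applied to
# flattened q/k tokens, since `apply_rot_embed` broadcasts over leading dims.
#
#   rope = RotaryEmbedding(dim=64)
#   sin_emb, cos_emb = rope.get_embed([7, 7])
#   q = torch.randn(2, 8, 49, 64)             # (batch, heads, tokens, dim)
#   q = apply_rot_embed(q, sin_emb, cos_emb)  # same shape, positions encoded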