| """ EvoNorm in PyTorch | |
| Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967 | |
| @inproceedings{NEURIPS2020, | |
| author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc}, | |
| booktitle = {Advances in Neural Information Processing Systems}, | |
| editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin}, | |
| pages = {13539--13550}, | |
| publisher = {Curran Associates, Inc.}, | |
| title = {Evolving Normalization-Activation Layers}, | |
| url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf}, | |
| volume = {33}, | |
| year = {2020} | |
| } | |
| An attempt at getting decent performing EvoNorms running in PyTorch. | |
| While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm | |
| in terms of memory usage and throughput on GPUs. | |
| I'm testing these modules on TPU w/ PyTorch XLA. Promising start but | |
| currently working around some issues with builtin torch/tensor.var/std. Unlike | |
| GPU, similar train speeds for EvoNormS variants and BatchNorm. | |
| Hacked together by / Copyright 2020 Ross Wightman | |
| """ | |
from typing import Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer
from .trace_utils import _assert


def instance_std(x, eps: float = 1e-5):
    std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
    return std.expand(x.shape)


def instance_std_tpu(x, eps: float = 1e-5):
    std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
    return std.expand(x.shape)
# instance_std = instance_std_tpu


def instance_rms(x, eps: float = 1e-5):
    rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
    return rms.expand(x.shape)


def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
    xm = x.mean(dim=dim, keepdim=True)
    if diff_sqm:
        # difference of squared mean and mean squared, faster on TPU but can be less stable
        var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)
    else:
        var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)
    return var
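
# Note on the diff_sqm path in manual_var above: it relies on the identity
#   Var(x) = E[(x - E[x])^2] = E[x^2] - (E[x])^2
# Computing the right-hand side avoids materializing (x - xm), which is why it can be
# faster on TPU, but subtracting two nearly-equal quantities can lose precision; the
# clamp(0) guards against small negative results from that cancellation.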


def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
    B, C, H, W = x.shape
    x_dtype = x.dtype
    _assert(C % groups == 0, '')
    if flatten:
        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
        std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
    else:
        x = x.reshape(B, groups, C // groups, H, W)
        std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
    return std.expand(x.shape).reshape(B, C, H, W)


def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
    # This is a workaround for some stability / odd behaviour of .var and .std
    # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
    B, C, H, W = x.shape
    _assert(C % groups == 0, '')
    if flatten:
        x = x.reshape(B, groups, -1)  # FIXME simpler shape causing TPU / XLA issues
        var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
    else:
        x = x.reshape(B, groups, C // groups, H, W)
        var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
    return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
# group_std = group_std_tpu  # FIXME TPU temporary


def group_rms(x, groups: int = 32, eps: float = 1e-5):
    B, C, H, W = x.shape
    _assert(C % groups == 0, '')
    x_dtype = x.dtype
    x = x.reshape(B, groups, C // groups, H, W)
    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
    return rms.expand(x.shape).reshape(B, C, H, W)
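
# The modules below follow the paper's two families of layers:
#  * EvoNorm2dB* ("B" / batch variants) use a batch variance estimate and a running_var
#    buffer, so like BatchNorm they behave differently in train vs. eval mode.
#  * EvoNorm2dS* ("S" / sample variants) rely only on per-sample group statistics
#    (group_std / group_rms) and keep no running state.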


class EvoNorm2dB0(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.v is not None:
            nn.init.ones_(self.v)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.v is not None:
            if self.training:
                var = x.float().var(dim=(0, 2, 3), unbiased=False)
                # var = manual_var(x, dim=(0, 2, 3)).squeeze()
                n = x.numel() / x.shape[1]
                self.running_var.copy_(
                    self.running_var * (1 - self.momentum) +
                    var.detach() * self.momentum * (n / (n - 1)))
            else:
                var = self.running_var
            left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
            v = self.v.to(x_dtype).view(v_shape)
            right = x * v + instance_std(x, self.eps)
            x = x / left.max(right)
        return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)
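
# For reference, the forward pass above implements the paper's B0 expression:
#   y = x / max(sqrt(var_batch + eps), v * x + std_instance(x)) * weight + bias
# with var_batch taken from running_var at eval time, and a plain affine transform
# (weight, bias) when apply_act is False.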


class EvoNorm2dB1(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.apply_act:
            if self.training:
                var = x.float().var(dim=(0, 2, 3), unbiased=False)
                n = x.numel() / x.shape[1]
                self.running_var.copy_(
                    self.running_var * (1 - self.momentum) +
                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
            else:
                var = self.running_var
            var = var.to(x_dtype).view(v_shape)
            left = var.add(self.eps).sqrt_()
            right = (x + 1) * instance_rms(x, self.eps)
            x = x / left.max(right)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dB2(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.apply_act:
            if self.training:
                var = x.float().var(dim=(0, 2, 3), unbiased=False)
                n = x.numel() / x.shape[1]
                self.running_var.copy_(
                    self.running_var * (1 - self.momentum) +
                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
            else:
                var = self.running_var
            var = var.to(x_dtype).view(v_shape)
            left = var.add(self.eps).sqrt_()
            right = instance_rms(x, self.eps) - x
            x = x / left.max(right)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS0(nn.Module):
    def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
        super().__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        if group_size:
            assert num_features % group_size == 0
            self.groups = num_features // group_size
        else:
            self.groups = groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.v is not None:
            nn.init.ones_(self.v)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.v is not None:
            v = self.v.view(v_shape).to(x_dtype)
            x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
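
# For reference, the forward pass above implements the paper's S0 expression:
#   y = x * sigmoid(v * x) / group_std(x) * weight + bias
# i.e. a SiLU-style gate on x, normalized by a grouped standard deviation.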


class EvoNorm2dS0a(EvoNorm2dS0):
    def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_):
        super().__init__(
            num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        d = group_std(x, self.groups, self.eps)
        if self.v is not None:
            v = self.v.view(v_shape).to(x_dtype)
            x = x * (x * v).sigmoid()
        x = x / d
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS1(nn.Module):
    def __init__(
            self, num_features, groups=32, group_size=None,
            apply_act=True, act_layer=None, eps=1e-5, **_):
        super().__init__()
        act_layer = act_layer or nn.SiLU
        self.apply_act = apply_act  # apply activation (non-linearity)
        if act_layer is not None and apply_act:
            self.act = create_act_layer(act_layer)
        else:
            self.act = nn.Identity()
        if group_size:
            assert num_features % group_size == 0
            self.groups = num_features // group_size
        else:
            self.groups = groups
        self.eps = eps
        self.pre_act_norm = False
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.apply_act:
            x = self.act(x) / group_std(x, self.groups, self.eps)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS1a(EvoNorm2dS1):
    def __init__(
            self, num_features, groups=32, group_size=None,
            apply_act=True, act_layer=None, eps=1e-3, **_):
        super().__init__(
            num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        x = self.act(x) / group_std(x, self.groups, self.eps)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS2(nn.Module):
    def __init__(
            self, num_features, groups=32, group_size=None,
            apply_act=True, act_layer=None, eps=1e-5, **_):
        super().__init__()
        act_layer = act_layer or nn.SiLU
        self.apply_act = apply_act  # apply activation (non-linearity)
        if act_layer is not None and apply_act:
            self.act = create_act_layer(act_layer)
        else:
            self.act = nn.Identity()
        if group_size:
            assert num_features % group_size == 0
            self.groups = num_features // group_size
        else:
            self.groups = groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        if self.apply_act:
            x = self.act(x) / group_rms(x, self.groups, self.eps)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)


class EvoNorm2dS2a(EvoNorm2dS2):
    def __init__(
            self, num_features, groups=32, group_size=None,
            apply_act=True, act_layer=None, eps=1e-3, **_):
        super().__init__(
            num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

    def forward(self, x):
        _assert(x.dim() == 4, 'expected 4D input')
        x_dtype = x.dtype
        v_shape = (1, -1, 1, 1)
        x = self.act(x) / group_rms(x, self.groups, self.eps)
        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
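

# Minimal smoke-test sketch (illustrative addition, not part of the original module). Because of
# the relative imports above, this file is normally used as part of its package, so treat the
# block below as a usage sketch rather than a standalone script.
if __name__ == '__main__':
    x = torch.randn(2, 64, 8, 8)  # these layers expect 4D NCHW input
    for layer_cls in (EvoNorm2dB0, EvoNorm2dS0):
        layer = layer_cls(num_features=64)  # num_features must match the channel dim
        layer.train()
        y = layer(x)
        assert y.shape == x.shape  # EvoNorm layers preserve the input shape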