| """ Normalization + Activation Layers | |
| Provides Norm+Act fns for standard PyTorch norm layers such as | |
| * BatchNorm | |
| * GroupNorm | |
| * LayerNorm | |
| This allows swapping with alternative layers that are natively both norm + act such as | |
| * EvoNorm (evo_norm.py) | |
| * FilterResponseNorm (filter_response_norm.py) | |
| * InplaceABN (inplace_abn.py) | |
| Hacked together by / Copyright 2022 Ross Wightman | |
| """ | |
| from typing import Union, List, Optional, Any | |
| import torch | |
| from torch import nn as nn | |
| from torch.nn import functional as F | |
| from .create_act import get_act_layer | |
| from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm | |
| from .trace_utils import _assert | |


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(
            self,
            num_features,
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            apply_act=True,
            act_layer=nn.ReLU,
            inplace=True,
            drop_layer=None,
            device=None,
            dtype=None
    ):
        try:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super(BatchNormAct2d, self).__init__(
                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
                **factory_kwargs
            )
        except TypeError:
            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
            super(BatchNormAct2d, self).__init__(
                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        x = F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
        x = self.drop(x)
        x = self.act(x)
        return x
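

# Illustrative usage sketch (added here, not part of the original module): BatchNormAct2d
# keeps the nn.BatchNorm2d parameter/buffer names, so checkpoints saved with a separate
# bn + act still load, while the activation runs in the same forward call. Assumes the
# 'silu' string resolves via get_act_layer, as it does for standard timm activations.
def _example_batch_norm_act2d():
    bn_act = BatchNormAct2d(64, act_layer='silu', inplace=False)
    x = torch.randn(2, 64, 8, 8)
    y = bn_act(x)  # batch norm -> (optional drop layer) -> SiLU, NCHW shape preserved
    assert y.shape == x.shape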


class SyncBatchNormAct(nn.SyncBatchNorm):
    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
    # but ONLY when used in conjunction with the timm conversion function below.
    # Do not create this module directly or use the PyTorch conversion function.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
        if hasattr(self, "drop"):
            x = self.drop(x)
        if hasattr(self, "act"):
            x = self.act(x)
        return x


def convert_sync_batchnorm(module, process_group=None):
    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if isinstance(module, BatchNormAct2d):
            # convert timm norm + act layer
            module_output = SyncBatchNormAct(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group=process_group,
            )
            # set act and drop attr from the original module
            module_output.act = module.act
            module_output.drop = module.drop
        else:
            # convert standard BatchNorm layers
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
    del module
    return module_output
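

# Illustrative usage sketch (added here, not part of the original module): unlike
# torch.nn.SyncBatchNorm.convert_sync_batchnorm, this helper also converts BatchNormAct2d
# layers, carrying their .act / .drop submodules over to SyncBatchNormAct. Only meaningful
# when torch.distributed is initialized; `model` below stands for any module tree.
def _example_convert_sync_batchnorm(model: nn.Module) -> nn.Module:
    return convert_sync_batchnorm(model)  # pass process_group=... to limit the sync scope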


def _num_groups(num_channels, num_groups, group_size):
    if group_size:
        assert num_channels % group_size == 0
        return num_channels // group_size
    return num_groups
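

# For example (illustrative): _num_groups(256, 32, None) returns 32 (num_groups used as-is),
# while _num_groups(256, 32, 16) returns 256 // 16 = 16 groups of 16 channels each;
# a non-zero group_size takes precedence over num_groups.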


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
            self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(GroupNormAct, self).__init__(
            _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x
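

# Illustrative usage sketch (added here, not part of the original module): note the flipped
# argument order (num_channels first) vs nn.GroupNorm, which lets GroupNormAct slot into
# code that passes channels as the first positional arg to any norm+act layer.
def _example_group_norm_act():
    gn_act = GroupNormAct(64, group_size=16)  # 64 // 16 = 4 groups, default ReLU act
    y = gn_act(torch.randn(2, 64, 8, 8))  # group norm -> ReLU, shape preserved
    assert y.shape == (2, 64, 8, 8)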


class LayerNormAct(nn.LayerNorm):
    def __init__(
            self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x
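

# Illustrative usage sketch (added here, not part of the original module): LayerNormAct
# normalizes over the trailing dimension(s), so it expects channels-last input such as
# (batch, tokens, channels) sequences.
def _example_layer_norm_act():
    ln_act = LayerNormAct(384)
    tokens = torch.randn(2, 196, 384)
    out = ln_act(tokens)  # layer norm over the last dim -> ReLU
    assert out.shape == tokens.shape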


class LayerNormAct2d(nn.LayerNorm):
    def __init__(
            self, num_channels, eps=1e-5, affine=True,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        if self._fast_norm:
            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        else:
            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        x = self.drop(x)
        x = self.act(x)
        return x
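

# Illustrative usage sketch (added here, not part of the original module): LayerNormAct2d
# takes NCHW input and handles the NHWC permute internally, so it can stand in for
# BatchNormAct2d / GroupNormAct in conv nets without changing the surrounding data layout.
def _example_layer_norm_act2d():
    ln_act = LayerNormAct2d(64, act_layer=nn.GELU, inplace=False)
    y = ln_act(torch.randn(2, 64, 8, 8))  # permute to NHWC, norm over C, permute back, GELU
    assert y.shape == (2, 64, 8, 8)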