Spaces:
Configuration error
Configuration error
| """ Norm Layer Factory | |
| Create norm modules by string (to mirror create_act and creat_norm-act fns) | |
| Copyright 2022 Ross Wightman | |
| """ | |
| import types | |
| import functools | |
| import torch.nn as nn | |
| from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d | |
| _NORM_MAP = dict( | |
| batchnorm=nn.BatchNorm2d, | |
| batchnorm2d=nn.BatchNorm2d, | |
| batchnorm1d=nn.BatchNorm1d, | |
| groupnorm=GroupNorm, | |
| groupnorm1=GroupNorm1, | |
| layernorm=LayerNorm, | |
| layernorm2d=LayerNorm2d, | |
| ) | |
| _NORM_TYPES = {m for n, m in _NORM_MAP.items()} | |
| def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): | |
| layer = get_norm_layer(layer_name, act_layer=act_layer) | |
| layer_instance = layer(num_features, apply_act=apply_act, **kwargs) | |
| return layer_instance | |
| def get_norm_layer(norm_layer): | |
| assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) | |
| norm_kwargs = {} | |
| # unbind partial fn, so args can be rebound later | |
| if isinstance(norm_layer, functools.partial): | |
| norm_kwargs.update(norm_layer.keywords) | |
| norm_layer = norm_layer.func | |
| if isinstance(norm_layer, str): | |
| layer_name = norm_layer.replace('_', '') | |
| norm_layer = _NORM_MAP.get(layer_name, None) | |
| elif norm_layer in _NORM_TYPES: | |
| norm_layer = norm_layer | |
| elif isinstance(norm_layer, types.FunctionType): | |
| # if function type, assume it is a lambda/fn that creates a norm layer | |
| norm_layer = norm_layer | |
| else: | |
| type_name = norm_layer.__name__.lower().replace('_', '') | |
| norm_layer = _NORM_MAP.get(type_name, None) | |
| assert norm_layer is not None, f"No equivalent norm layer for {type_name}" | |
| if norm_kwargs: | |
| norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args | |
| return norm_layer | |