Source code for torchvision_customizer.layers.normalizations

"""Normalization layer utilities for torchvision-customizer.

This module provides flexible normalization options for building neural networks,
supporting multiple normalization techniques including BatchNorm, GroupNorm,
LayerNorm, InstanceNorm, and custom variants.

Supported Normalization Types:
    - batch: BatchNorm2d (default)
    - group: GroupNorm with configurable groups
    - instance: InstanceNorm2d
    - layer: LayerNorm
    - none: Identity (no normalization)

Example:
    >>> from torchvision_customizer.layers import get_normalization
    >>> norm = get_normalization('batch', num_channels=64)
    >>> norm_group = get_normalization('group', num_channels=64, num_groups=32)
"""

import torch.nn as nn
from typing import Dict, Optional, Type, Any


# Lookup table from the lowercase normalization name accepted by this module
# to the torch.nn layer class that implements it.
NORMALIZATION_REGISTRY: Dict[str, Type[nn.Module]] = dict(
    batch=nn.BatchNorm2d,
    group=nn.GroupNorm,
    instance=nn.InstanceNorm2d,
    layer=nn.LayerNorm,
    none=nn.Identity,
)

# Constructor keyword defaults per normalization name. Only GroupNorm has a
# default of its own (the number of groups); every other entry is empty.
NORMALIZATION_DEFAULTS: Dict[str, Dict[str, Any]] = dict(
    batch={},
    group=dict(num_groups=32),
    instance={},
    layer={},
    none={},
)


def get_normalization(
    norm_type: str,
    num_channels: int,
    **kwargs
) -> nn.Module:
    """Create a normalization layer by type.

    Factory function that looks up ``norm_type`` in NORMALIZATION_REGISTRY,
    merges the registered defaults with the channel count and any caller
    overrides, and instantiates the layer.

    Args:
        norm_type: Type of normalization layer. Case-insensitive; surrounding
            whitespace is ignored. Supported values: 'batch', 'group',
            'instance', 'layer', 'none'.
        num_channels: Number of channels for the layer. Passed as
            ``num_channels`` to GroupNorm, ``num_features`` to BatchNorm2d and
            InstanceNorm2d, and ``normalized_shape=(num_channels,)`` to
            LayerNorm. Ignored for 'none'.
        **kwargs: Additional keyword arguments forwarded to the layer
            constructor; they override the defaults. For GroupNorm,
            'num_groups' can be specified. Ignored for 'none'.

    Returns:
        An instantiated nn.Module normalization layer.

    Raises:
        ValueError: If the normalization type is not supported (or if the
            layer constructor itself rejects the values, e.g. an explicit
            ``num_groups`` that does not divide ``num_channels``).
        TypeError: If invalid keyword arguments are provided.

    Note:
        When 'num_groups' is not given for 'group' and ``num_channels`` is not
        divisible by the registered default, the largest divisor of
        ``num_channels`` that does not exceed the default is used instead, so
        small or unusual channel counts still produce a valid GroupNorm.

    Examples:
        >>> # BatchNorm (default)
        >>> norm = get_normalization('batch', num_channels=64)

        >>> # GroupNorm with custom groups
        >>> norm = get_normalization('group', num_channels=64, num_groups=32)

        >>> # GroupNorm with a channel count the default does not divide
        >>> get_normalization('group', num_channels=48).num_groups
        24

        >>> # InstanceNorm
        >>> norm = get_normalization('instance', num_channels=64)

        >>> # No normalization
        >>> norm = get_normalization('none', num_channels=64)
    """
    normalized_type = norm_type.lower().strip()

    if normalized_type not in NORMALIZATION_REGISTRY:
        supported = ', '.join(sorted(NORMALIZATION_REGISTRY.keys()))
        raise ValueError(
            f"Unsupported normalization type: '{norm_type}'\n"
            f"Supported types: {supported}"
        )

    if normalized_type == 'none':
        # Identity needs neither the channel count nor any keyword arguments.
        return nn.Identity()

    norm_class = NORMALIZATION_REGISTRY[normalized_type]

    # Copy so the module-level defaults are never mutated.
    params = NORMALIZATION_DEFAULTS[normalized_type].copy()

    # Each layer class names its "size" argument differently.
    if normalized_type == 'group':
        params['num_channels'] = num_channels
        default_groups = params.get('num_groups')
        if (
            'num_groups' not in kwargs
            and isinstance(default_groups, int)
            and isinstance(num_channels, int)
            and default_groups > 0
            and num_channels > 0
            and num_channels % default_groups != 0
        ):
            # GroupNorm requires num_channels % num_groups == 0. The caller
            # did not choose a group count, so fall back to the largest valid
            # one not exceeding the default (1 always qualifies).
            params['num_groups'] = max(
                groups
                for groups in range(1, min(default_groups, num_channels) + 1)
                if num_channels % groups == 0
            )
    elif normalized_type in ('batch', 'instance'):
        params['num_features'] = num_channels
    elif normalized_type == 'layer':
        # LayerNorm expects normalized_shape, not num_channels.
        # NOTE(review): nn.LayerNorm normalizes over the trailing dimension(s),
        # so this presumably expects channels-last input -- confirm callers.
        params['normalized_shape'] = (num_channels,)

    # Caller-supplied keyword arguments take precedence over the defaults.
    params.update(kwargs)

    try:
        return norm_class(**params)
    except TypeError as e:
        raise TypeError(
            f"Invalid parameters for {normalized_type}: {str(e)}"
        ) from e
def is_normalization_supported(norm_type: str) -> bool:
    """Report whether ``norm_type`` names a registered normalization.

    The name is lowercased and stripped of surrounding whitespace before it
    is looked up in NORMALIZATION_REGISTRY.

    Args:
        norm_type: Type of normalization to check.

    Returns:
        True if the normalization type is supported, False otherwise.

    Example:
        >>> is_normalization_supported('batch')
        True
        >>> is_normalization_supported('unsupported')
        False
    """
    lookup_key = norm_type.lower().strip()
    if lookup_key in NORMALIZATION_REGISTRY:
        return True
    return False
def get_supported_normalizations() -> list[str]:
    """Return the names of every registered normalization type.

    Returns:
        A new, alphabetically sorted list of the keys of
        NORMALIZATION_REGISTRY.

    Example:
        >>> normalizations = get_supported_normalizations()
        >>> print(normalizations)
        ['batch', 'group', 'instance', 'layer', 'none']
    """
    names = list(NORMALIZATION_REGISTRY)
    names.sort()
    return names
class NormalizationFactory:
    """Class-based facade over the module-level normalization helpers.

    Every method is a static method that delegates to the corresponding
    module-level function or registry; the class defines no instance state,
    so the methods can be called on the class or on an instance.

    Example:
        >>> factory = NormalizationFactory()
        >>> norm = factory.create('batch', num_channels=64)
        >>> group_norm = factory.create('group', num_channels=64, num_groups=16)
    """

    @staticmethod
    def create(norm_type: str, num_channels: int, **kwargs) -> nn.Module:
        """Build a normalization layer via ``get_normalization``.

        Args:
            norm_type: Type of normalization layer.
            num_channels: Number of channels.
            **kwargs: Additional keyword arguments for the normalization.

        Returns:
            The created normalization layer.

        Raises:
            ValueError: If normalization type is not supported.
        """
        layer = get_normalization(norm_type, num_channels, **kwargs)
        return layer

    @staticmethod
    def is_supported(norm_type: str) -> bool:
        """Tell whether ``norm_type`` is a registered normalization.

        Args:
            norm_type: Type of normalization to check.

        Returns:
            True if supported, False otherwise.
        """
        supported_flag = is_normalization_supported(norm_type)
        return supported_flag

    @staticmethod
    def supported_normalizations() -> list[str]:
        """List the registered normalization type names.

        Returns:
            List of supported normalization type names, as produced by
            ``get_supported_normalizations``.
        """
        names = get_supported_normalizations()
        return names

    @staticmethod
    def get_defaults(norm_type: str) -> Dict[str, Any]:
        """Look up the default constructor parameters for a type.

        Args:
            norm_type: Type of normalization. Lowercased and stripped of
                surrounding whitespace before the lookup.

        Returns:
            A copy of the default-parameter dictionary, so callers may modify
            it without affecting NORMALIZATION_DEFAULTS.

        Raises:
            ValueError: If normalization type is not supported.
        """
        normalized_type = norm_type.lower().strip()

        # Happy path first: known type -> hand back a private copy.
        if normalized_type in NORMALIZATION_REGISTRY:
            defaults = NORMALIZATION_DEFAULTS[normalized_type]
            return defaults.copy()

        supported = ', '.join(sorted(NORMALIZATION_REGISTRY.keys()))
        raise ValueError(
            f"Unsupported normalization type: '{norm_type}'\n"
            f"Supported types: {supported}"
        )