Source code for torchvision_customizer.registry.component_registry

"""Component Registry implementation.

Provides a centralized registry for all neural network components,
enabling discovery, inspection, and instantiation.
"""

from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch.nn as nn


class ComponentInfo:
    """Information about a registered component."""
    
    def __init__(
        self,
        name: str,
        cls: Type[nn.Module],
        category: str,
        description: str = "",
        aliases: List[str] = None,
    ):
        self.name = name
        self.cls = cls
        self.category = category
        self.description = description or cls.__doc__ or ""
        self.aliases = aliases or []
    
    def __repr__(self) -> str:
        return f"ComponentInfo(name='{self.name}', category='{self.category}')"
    
    def to_dict(self) -> dict:
        """Convert to dictionary for display."""
        import inspect
        sig = inspect.signature(self.cls.__init__)
        params = {
            name: str(param.annotation) if param.annotation != inspect.Parameter.empty else 'Any'
            for name, param in sig.parameters.items()
            if name != 'self'
        }
        
        return {
            'name': self.name,
            'category': self.category,
            'description': self.description.split('\n')[0] if self.description else "",
            'parameters': params,
            'aliases': self.aliases,
        }


[docs] class ComponentRegistry: """Central registry for all neural network components. The registry allows: - Component discovery via list() and categories() - Component inspection via info() - Component instantiation via get() - Custom component registration via register() Example: >>> registry = ComponentRegistry() >>> registry.list() ['conv', 'residual', 'se', ...] >>> Conv = registry.get('conv') >>> block = Conv(64, 128) """ def __init__(self): self._components: Dict[str, ComponentInfo] = {} self._aliases: Dict[str, str] = {} self._categories: Dict[str, List[str]] = {} self._initialized = False def _ensure_initialized(self): """Lazy initialization - register all built-in components.""" if self._initialized: return self._initialized = True self._register_builtin_components() def _register_builtin_components(self): """Register all built-in components from blocks and layers.""" # Import blocks from torchvision_customizer.blocks import ( ConvBlock, ResidualBlock, SEBlock, DepthwiseBlock, InceptionModule, Conv3DBlock, SuperConv2d, SuperLinear, DenseConnectionBlock, StandardBottleneck, WideBottleneck, GroupedBottleneck, ResidualStage, ResidualSequence, # v2.1: Advanced blocks CBAMBlock, ECABlock, DropPath, Mish, GeM, CoordConv, SqueezeExcitation, LayerScale, ConvBNAct, MBConv, FusedMBConv, ) # Import layers from torchvision_customizer.layers import ( ChannelAttention, SpatialAttention, ChannelSpatialAttention, MultiHeadAttention, AttentionBlock, ) # Register blocks self._register_component('conv', ConvBlock, 'block', aliases=['conv2d', 'convblock']) self._register_component('conv3d', Conv3DBlock, 'block') self._register_component('residual', ResidualBlock, 'block', aliases=['res', 'resblock']) self._register_component('se', SEBlock, 'block', aliases=['squeeze_excitation', 'seblock']) self._register_component('depthwise', DepthwiseBlock, 'block', aliases=['dw', 'depthwise_separable']) self._register_component('inception', InceptionModule, 'block', aliases=['inception_module']) self._register_component('dense', DenseConnectionBlock, 'block', aliases=['dense_block', 'densenet']) self._register_component('super_conv', SuperConv2d, 'block') self._register_component('super_linear', SuperLinear, 'block') # Register bottleneck variants self._register_component('bottleneck', StandardBottleneck, 'bottleneck', aliases=['standard_bottleneck']) self._register_component('wide_bottleneck', WideBottleneck, 'bottleneck') self._register_component('grouped_bottleneck', GroupedBottleneck, 'bottleneck') # Register architecture components self._register_component('residual_stage', ResidualStage, 'architecture') self._register_component('residual_sequence', ResidualSequence, 'architecture') # Register attention components self._register_component('channel_attention', ChannelAttention, 'attention', aliases=['ca', 'channel']) self._register_component('spatial_attention', SpatialAttention, 'attention', aliases=['sa', 'spatial']) self._register_component('cbam', ChannelSpatialAttention, 'attention', aliases=['channel_spatial']) self._register_component('multihead_attention', MultiHeadAttention, 'attention', aliases=['mha', 'multihead']) self._register_component('attention_block', AttentionBlock, 'attention') # v2.1: Register advanced blocks self._register_component('cbam_block', CBAMBlock, 'attention', aliases=['CBAMBlock', 'cbam_attention']) self._register_component('eca', ECABlock, 'attention', aliases=['eca_block', 'ECABlock', 'efficient_channel']) self._register_component('se_enhanced', SqueezeExcitation, 'attention', aliases=['squeeze_excite', 'se_v2']) # v2.1: Register regularization blocks self._register_component('drop_path', DropPath, 'regularization', aliases=['stochastic_depth', 'DropPath']) self._register_component('layer_scale', LayerScale, 'regularization') # v2.1: Register activation blocks self._register_component('mish', Mish, 'activation', aliases=['Mish']) # v2.1: Register pooling blocks self._register_component('gem', GeM, 'pooling', aliases=['generalized_mean', 'GeM']) # v2.1: Register conv blocks self._register_component('coord_conv', CoordConv, 'block', aliases=['CoordConv', 'coordinate_conv']) self._register_component('conv_bn_act', ConvBNAct, 'block', aliases=['ConvBNAct', 'convbnact']) self._register_component('mbconv', MBConv, 'block', aliases=['MBConv', 'inverted_residual']) self._register_component('fused_mbconv', FusedMBConv, 'block', aliases=['FusedMBConv', 'fused_inverted_residual']) def _register_component( self, name: str, cls: Type[nn.Module], category: str, description: str = "", aliases: List[str] = None, ): """Internal method to register a component.""" info = ComponentInfo(name, cls, category, description, aliases) self._components[name] = info # Register in category if category not in self._categories: self._categories[category] = [] self._categories[category].append(name) # Register aliases if aliases: for alias in aliases: self._aliases[alias] = name
[docs] def register(self, name: str, category: str = 'custom', aliases: List[str] = None): """Decorator to register a custom component. Example: >>> @registry.register('my_block', category='block') ... class MyBlock(nn.Module): ... def __init__(self, channels): ... super().__init__() ... self.conv = nn.Conv2d(channels, channels, 3, padding=1) """ def decorator(cls: Type[nn.Module]) -> Type[nn.Module]: self._ensure_initialized() self._register_component(name, cls, category, aliases=aliases) return cls return decorator
[docs] def get(self, name: str, **kwargs) -> Union[Type[nn.Module], nn.Module]: """Get a component class by name, or instantiate if kwargs provided. Args: name: Component name or alias **kwargs: If provided, instantiate the component with these args Returns: Component class if no kwargs, instantiated module if kwargs provided Example: >>> Conv = registry.get('conv') # Get class >>> block = registry.get('conv', in_channels=64, out_channels=128) # Get instance """ self._ensure_initialized() # Resolve alias if name in self._aliases: name = self._aliases[name] if name not in self._components: available = ', '.join(sorted(self._components.keys())) raise ValueError(f"Unknown component '{name}'. Available: {available}") cls = self._components[name].cls if kwargs: return cls(**kwargs) return cls
[docs] def list(self, category: str = None) -> List[str]: """List all registered component names. Args: category: Optional category filter Returns: List of component names """ self._ensure_initialized() if category: if category not in self._categories: return [] return sorted(self._categories[category]) return sorted(self._components.keys())
[docs] def categories(self) -> List[str]: """List all available categories.""" self._ensure_initialized() return sorted(self._categories.keys())
[docs] def info(self, name: str) -> dict: """Get detailed information about a component. Args: name: Component name or alias Returns: Dictionary with component information """ self._ensure_initialized() # Resolve alias if name in self._aliases: name = self._aliases[name] if name not in self._components: raise ValueError(f"Unknown component '{name}'") return self._components[name].to_dict()
def __repr__(self) -> str: self._ensure_initialized() return f"ComponentRegistry({len(self._components)} components in {len(self._categories)} categories)"
# Create default instance for module-level access _default_registry = ComponentRegistry() def register(name: str, category: str = 'custom', aliases: List[str] = None): """Decorator to register a component with the default registry.""" return _default_registry.register(name, category, aliases) def get(name: str, **kwargs): """Get component from default registry.""" return _default_registry.get(name, **kwargs) def list_components(category: str = None) -> List[str]: """List components from default registry.""" return _default_registry.list(category) def info(name: str) -> dict: """Get component info from default registry.""" return _default_registry.info(name) def categories() -> List[str]: """List categories from default registry.""" return _default_registry.categories()