# Source code for torchvision_customizer.blocks.advanced_architecture

"""
Advanced Architecture Features Module

This module provides sophisticated neural network components for advanced architectures:
- Residual connections with per-layer configuration
- Multiple skip connection patterns
- Dense connections (DenseNet-style)
- Mixed architecture combinations

Author: torchvision-customizer
License: MIT
"""

from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision_customizer.blocks.conv_block import ConvBlock


class ResidualConnector(nn.Module):
    """
    Configurable residual connection module.

    Computes ``activation(x + skip_path(skip))`` where ``skip_path`` depends on
    ``skip_type``. Can be used to add residual connections to any two layers.

    Attributes:
        skip_type (str): Type of skip connection ('identity', 'projection', 'bottle')
        stride (int): Stride applied on the skip path ('projection'/'bottle' only)
        in_channels (int): Channels of the skip input
        out_channels (int): Channels produced by the skip path (must match ``x``)
        use_batchnorm (bool): Whether the skip path uses batch normalization

    Note:
        'identity' ignores ``in_channels``, ``out_channels`` and ``stride``;
        ``skip`` must then already have the same shape as ``x``.

    Examples:
        >>> # Identity skip (same dimensions)
        >>> rc = ResidualConnector('identity', in_channels=64, out_channels=64)
        >>> x = torch.randn(2, 64, 32, 32)
        >>> skip = torch.randn(2, 64, 32, 32)
        >>> out = rc(x, skip)

        >>> # Projection skip (different dimensions)
        >>> rc = ResidualConnector('projection', in_channels=64, out_channels=128, stride=2)
        >>> x = torch.randn(2, 128, 16, 16)
        >>> skip = torch.randn(2, 64, 32, 32)
        >>> out = rc(x, skip)
    """

    def __init__(
        self,
        skip_type: str = 'identity',
        in_channels: int = 64,
        out_channels: int = 64,
        stride: int = 1,
        activation: str = 'relu',
        use_batchnorm: bool = True,
    ):
        """
        Initialize ResidualConnector.

        Args:
            skip_type: Type of skip connection ('identity', 'projection', 'bottle')
            in_channels: Number of input channels for skip path
            out_channels: Number of output channels
            stride: Stride for skip path if needed
            activation: Activation function name (resolved via get_activation)
            use_batchnorm: Whether to use batch normalization on the skip path

        Raises:
            ValueError: If skip_type is not recognized
        """
        super().__init__()

        valid_types = ['identity', 'projection', 'bottle']
        if skip_type not in valid_types:
            raise ValueError(f"skip_type must be one of {valid_types}, got {skip_type}")

        self.skip_type = skip_type
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.use_batchnorm = use_batchnorm

        # Local import: resolved only when a connector is actually constructed.
        from torchvision_customizer.layers.activations import get_activation
        self.activation = get_activation(activation)

        # Build skip path
        if skip_type == 'identity':
            self.skip = nn.Identity()
        elif skip_type == 'projection':
            self.skip = self._build_projection()
        elif skip_type == 'bottle':
            self.skip = self._build_bottle()

    def _build_projection(self) -> nn.Module:
        """Build a strided 1x1 conv (+ optional BatchNorm) from in to out channels."""
        layers = [
            nn.Conv2d(
                self.in_channels,
                self.out_channels,
                kernel_size=1,
                stride=self.stride,
                # The bias is redundant when BatchNorm follows.
                bias=not self.use_batchnorm
            )
        ]

        if self.use_batchnorm:
            layers.append(nn.BatchNorm2d(self.out_channels))

        return nn.Sequential(*layers)

    def _build_bottle(self) -> nn.Module:
        """
        Build bottleneck skip connection.

        The layers are: 1x1 conv down to ``max(1, out_channels // 4)`` channels,
        optional BatchNorm, ReLU, then a strided 1x1 conv up to ``out_channels``,
        followed by an optional BatchNorm. Both norms honour ``use_batchnorm``.
        """
        hidden_channels = max(1, self.out_channels // 4)
        layers = [
            nn.Conv2d(
                self.in_channels,
                hidden_channels,
                kernel_size=1,
                bias=not self.use_batchnorm
            ),
            # Identity keeps Sequential indices (and state_dict keys) stable
            # whether or not batch norm is enabled.
            nn.BatchNorm2d(hidden_channels) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                hidden_channels,
                self.out_channels,
                kernel_size=1,
                stride=self.stride,
                bias=not self.use_batchnorm
            )
        ]

        if self.use_batchnorm:
            layers.append(nn.BatchNorm2d(self.out_channels))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Main path output tensor
            skip: Skip connection input tensor; after ``self.skip`` it must be
                broadcast-compatible with ``x``

        Returns:
            ``activation(x + self.skip(skip))``
        """
        skip_out = self.skip(skip)
        out = x + skip_out
        return self.activation(out)


class SkipConnectionBuilder(nn.Module):
    """
    Builder for different skip connection patterns.

    Combines a list of feature tensors according to one of these patterns:
    - residual: element-wise sum of all features
    - dense: learned 1x1 projections of the earlier features, concatenated
      (channel dim) with the last feature
    - hierarchical: sum of every 2nd feature (indices 0, 2, 4, ...) plus the
      last feature, each counted exactly once
    - dense_residual: running sum, re-adding ``features[0]`` at even indices

    Examples:
        >>> builder = SkipConnectionBuilder('residual', num_blocks=5, in_channels=64)
        >>> features = [torch.randn(2, 64, 32, 32) for _ in range(5)]
        >>> output = builder(features)
    """

    def __init__(
        self,
        pattern: str = 'residual',
        num_blocks: int = 4,
        in_channels: int = 64,
        activation: str = 'relu',
        use_batchnorm: bool = True,
    ):
        """
        Initialize SkipConnectionBuilder.

        Args:
            pattern: Type of skip pattern ('residual', 'dense', 'hierarchical', 'dense_residual')
            num_blocks: Number of blocks to connect (upper bound on the number
                of features for the 'dense' pattern)
            in_channels: Channels of every input feature
            activation: Activation function name (stored, not used by the patterns)
            use_batchnorm: Whether the dense projections use batch normalization

        Raises:
            ValueError: If pattern is not recognized
        """
        super().__init__()

        valid_patterns = ['residual', 'dense', 'hierarchical', 'dense_residual']
        if pattern not in valid_patterns:
            raise ValueError(f"pattern must be one of {valid_patterns}, got {pattern}")

        self.pattern = pattern
        self.num_blocks = num_blocks
        self.in_channels = in_channels
        self.activation = activation
        self.use_batchnorm = use_batchnorm

        # Create projections for connecting features. 'dense_residual' builds
        # them too, although its combination rule does not use them.
        if pattern in ['dense', 'dense_residual']:
            self._create_dense_projections()

    def _create_dense_projections(self):
        """Create 1x1 conv projections: ``projections[i]['to_j']`` for every i < j."""
        self.projections = nn.ModuleList()
        for i in range(self.num_blocks):
            # Each layer needs projections to all subsequent layers
            proj_dict = nn.ModuleDict()
            for j in range(i + 1, self.num_blocks):
                proj = nn.Sequential(
                    nn.Conv2d(self.in_channels, self.in_channels, 1, bias=False),
                    nn.BatchNorm2d(self.in_channels) if self.use_batchnorm else nn.Identity()
                )
                proj_dict[f'to_{j}'] = proj
            self.projections.append(proj_dict)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Apply skip connection pattern.

        Args:
            features: Non-empty list of feature tensors from blocks

        Returns:
            Combined feature tensor

        Raises:
            ValueError: If ``features`` is empty, or (dense) if it holds more
                than ``num_blocks`` tensors
        """
        if len(features) == 0:
            raise ValueError("features list cannot be empty")
        if self.pattern == 'residual':
            return self._residual_pattern(features)
        elif self.pattern == 'dense':
            return self._dense_pattern(features)
        elif self.pattern == 'hierarchical':
            return self._hierarchical_pattern(features)
        elif self.pattern == 'dense_residual':
            return self._dense_residual_pattern(features)

    def _residual_pattern(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Residual: element-wise sum of all features."""
        output = features[0]
        for feat in features[1:]:
            output = output + feat
        return output

    def _dense_pattern(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Dense: concatenate projected earlier features with the last feature.

        Each ``features[j]`` (j < last) goes through ``projections[j]['to_{last}']``.
        The results are concatenated in order along dim 1, followed by the
        unprojected last feature, giving ``len(features) * in_channels``
        channels. Only the final combination is computed; intermediate ones
        never contribute to the output.
        """
        last = len(features) - 1
        if last == 0:
            return features[0]
        if last >= self.num_blocks:
            raise ValueError(
                f"dense pattern supports at most num_blocks={self.num_blocks} "
                f"features, got {len(features)}"
            )
        key = f'to_{last}'
        projected = [self.projections[j][key](features[j]) for j in range(last)]
        return torch.cat(projected + [features[last]], dim=1)

    def _hierarchical_pattern(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Hierarchical: sum of features at even indices plus the last feature, each once."""
        output = features[0]
        for feat in features[2::2]:
            output = output + feat
        last = len(features) - 1
        # An even last index (including a single feature) was already summed above.
        if last % 2 == 1:
            output = output + features[last]
        return output

    def _dense_residual_pattern(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Dense-Residual: running sum, re-adding ``features[0]`` at every even index."""
        # Residual within blocks, dense across stages
        output = features[0]
        for i in range(1, len(features)):
            # Local residual
            local_sum = output + features[i]
            # Global connection (every other layer)
            if i % 2 == 0:
                output = local_sum + features[0]
            else:
                output = local_sum
        return output


class DenseConnectionBlock(nn.Module):
    """
    DenseNet-style dense connection block.

    Every layer receives the channel-wise concatenation of the block input and
    the outputs of all previous layers, and produces ``growth_rate`` new
    channels. Features are concatenated rather than added.

    Attributes:
        num_layers (int): Number of dense layers
        in_channels (int): Input channels
        growth_rate (int): Number of new channels per layer
        kernel_size (int): Kernel size of each layer's main convolution
        bottleneck_ratio (int): Width multiplier of the 1x1 bottleneck
            (bottleneck channels = bottleneck_ratio * growth_rate)
        dropout_rate (float): Dropout2d probability (0 disables dropout)
        out_channels (int): in_channels + num_layers * growth_rate

    Examples:
        >>> block = DenseConnectionBlock(
        ...     num_layers=4,
        ...     in_channels=64,
        ...     growth_rate=32,
        ...     kernel_size=3
        ... )
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = block(x)
        >>> # Output channels = 64 + (4 * 32) = 192
    """

    def __init__(
        self,
        num_layers: int = 4,
        in_channels: int = 64,
        growth_rate: int = 32,
        kernel_size: int = 3,
        bottleneck_ratio: int = 4,
        activation: str = 'relu',
        use_batchnorm: bool = True,
        dropout_rate: float = 0.0,
    ):
        """
        Initialize DenseConnectionBlock.

        Args:
            num_layers: Number of dense layers
            in_channels: Input channels
            growth_rate: Number of new channels per layer
            kernel_size: Convolution kernel size (positive and odd, so that
                ``padding=kernel_size // 2`` preserves the spatial size)
            bottleneck_ratio: Bottleneck width multiplier (see class docstring)
            activation: Activation function name. NOTE(review): currently
                unused; every layer uses ReLU.
            use_batchnorm: Whether to use batch normalization
            dropout_rate: Dropout probability in [0, 1]

        Raises:
            ValueError: If num_layers, in_channels, growth_rate or
                bottleneck_ratio is < 1, if kernel_size is not a positive odd
                integer, or if dropout_rate is outside [0, 1]
        """
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if growth_rate < 1:
            raise ValueError(f"growth_rate must be >= 1, got {growth_rate}")
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if bottleneck_ratio < 1:
            raise ValueError(f"bottleneck_ratio must be >= 1, got {bottleneck_ratio}")
        # An even kernel with padding=k//2 grows the spatial size by one,
        # which would make the concatenation in forward() fail.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        if not 0.0 <= dropout_rate <= 1.0:
            raise ValueError(f"dropout_rate must be in [0, 1], got {dropout_rate}")

        self.num_layers = num_layers
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size
        self.bottleneck_ratio = bottleneck_ratio
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList()
        current_channels = in_channels
        for _ in range(num_layers):
            layer = self._build_dense_layer(
                current_channels,
                growth_rate,
                kernel_size,
                bottleneck_ratio,
                activation,
                use_batchnorm,
                dropout_rate
            )
            self.layers.append(layer)
            # Each layer sees all previous outputs, so its input widens by growth_rate.
            current_channels += growth_rate

        self.out_channels = current_channels

    def _build_dense_layer(
        self,
        in_channels: int,
        growth_rate: int,
        kernel_size: int,
        bottleneck_ratio: int,
        activation: str,
        use_batchnorm: bool,
        dropout_rate: float,
    ) -> nn.Module:
        """
        Build a single pre-activation dense layer.

        The layers are: [BN] -> ReLU -> 1x1 conv (bottleneck) -> [BN] -> ReLU ->
        kxk conv to ``growth_rate`` channels -> [Dropout2d]. ``activation`` is
        accepted but not used.
        """
        bottleneck_channels = bottleneck_ratio * growth_rate
        return nn.Sequential(
            nn.BatchNorm2d(in_channels) if use_batchnorm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                bottleneck_channels,
                kernel_size=1,
                bias=not use_batchnorm
            ),
            nn.BatchNorm2d(bottleneck_channels) if use_batchnorm else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                bottleneck_channels,
                growth_rate,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=not use_batchnorm
            ),
            nn.Dropout2d(p=dropout_rate) if dropout_rate > 0 else nn.Identity(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with dense connections.

        Args:
            x: Input tensor with ``in_channels`` channels on dim 1

        Returns:
            Concatenation (dim 1) of ``x`` and every layer's output, with
            ``out_channels`` channels
        """
        features = [x]
        for layer in self.layers:
            # torch.cat returns a new tensor, so the in-place ReLU never touches x.
            new_features = layer(torch.cat(features, dim=1))
            features.append(new_features)
        return torch.cat(features, dim=1)
class MixedArchitectureBlock(nn.Module):
    """
    Block combining multiple architecture patterns.

    Stacks ``num_layers`` ConvBlocks. Each layer's output is combined with its
    input according to that layer's entry in ``mixed_patterns``:

    - 'residual': output + input (not applied to the first layer)
    - 'dense': concat([input, output], dim=1) (not applied to the first
      layer); widens the running channel count by ``in_channels``
    - any other name (e.g. 'standard'): the plain ConvBlock output

    The running channel count is tracked so that every ConvBlock is built for
    the width it actually receives. A final 1x1 convolution maps the result to
    ``out_channels`` when the widths differ; otherwise an identity is used.
    NOTE(review): the residual and dense combinations assume ConvBlock
    preserves spatial size with its default settings. Confirm against ConvBlock.

    Examples:
        >>> block = MixedArchitectureBlock(
        ...     num_layers=3,
        ...     in_channels=64,
        ...     mixed_patterns=['residual', 'dense', 'residual']
        ... )
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = block(x)
    """

    def __init__(
        self,
        num_layers: int = 4,
        in_channels: int = 64,
        out_channels: int = 64,
        mixed_patterns: Optional[List[str]] = None,
        kernel_size: int = 3,
        activation: str = 'relu',
        use_batchnorm: bool = True,
        dropout_rate: float = 0.0,
    ):
        """
        Initialize MixedArchitectureBlock.

        Args:
            num_layers: Number of layers
            in_channels: Input channels
            out_channels: Output channels of the whole block
            mixed_patterns: One pattern name per layer, or None to cycle
                ['standard', 'residual', 'dense', 'standard']
            kernel_size: Convolution kernel size
            activation: Activation function name
            use_batchnorm: Whether to use batch normalization
            dropout_rate: Dropout probability

        Raises:
            ValueError: If ``len(mixed_patterns) != num_layers``
        """
        super().__init__()

        self.num_layers = num_layers
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mixed_patterns is None:
            # Default pattern cycling
            cycle = ['standard', 'residual', 'dense', 'standard']
            mixed_patterns = [cycle[i % len(cycle)] for i in range(num_layers)]
        else:
            # Copy so later mutation of the caller's list cannot affect us.
            mixed_patterns = list(mixed_patterns)

        if len(mixed_patterns) != num_layers:
            raise ValueError(
                f"mixed_patterns length ({len(mixed_patterns)}) "
                f"must match num_layers ({num_layers})"
            )

        self.patterns = mixed_patterns

        self.conv_blocks = nn.ModuleList()
        channels = in_channels
        for i, pattern in enumerate(mixed_patterns):
            widens = pattern == 'dense' and i > 0
            # Dense layers add in_channels new maps; the others keep the width
            # so that the residual add lines up.
            block_out = in_channels if widens else channels
            self.conv_blocks.append(
                ConvBlock(
                    in_channels=channels,
                    out_channels=block_out,
                    kernel_size=kernel_size,
                    activation=activation,
                    use_batchnorm=use_batchnorm,
                    dropout_rate=dropout_rate,
                )
            )
            channels = channels + block_out if widens else block_out

        # Identity adds no parameters, keeping state_dicts unchanged whenever
        # the final width already equals out_channels.
        self.output_projection = (
            nn.Identity()
            if channels == out_channels
            else nn.Conv2d(channels, out_channels, kernel_size=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with mixed patterns; output has ``out_channels`` channels."""
        out = x
        for i, (block, pattern) in enumerate(zip(self.conv_blocks, self.patterns)):
            current = block(out)
            if pattern == 'residual' and i > 0:
                current = current + out
            elif pattern == 'dense' and i > 0:
                current = torch.cat([out, current], dim=1)
            out = current
        return self.output_projection(out)


# Utility functions
def create_skip_connections(
    features: List[torch.Tensor],
    pattern: str = 'residual',
) -> torch.Tensor:
    """
    Create skip connections between features.

    Args:
        features: Non-empty list of feature tensors
        pattern: 'residual' (element-wise sum) or 'concatenate' (torch.cat on dim 1)

    Returns:
        Combined feature tensor

    Raises:
        ValueError: If the pattern is unknown or ``features`` is empty
    """
    if pattern not in ('residual', 'concatenate'):
        raise ValueError(f"Unknown pattern: {pattern}")
    if len(features) == 0:
        raise ValueError("features list cannot be empty")

    if pattern == 'residual':
        result = features[0]
        for feat in features[1:]:
            result = result + feat
        return result
    return torch.cat(features, dim=1)


def validate_architecture_compatibility(
    features: List[torch.Tensor],
    skip_pattern: str,
) -> bool:
    """
    Validate that features are compatible with skip pattern.

    'residual' requires identical shapes. 'concatenate' (channel dim 1)
    requires tensors of at least 2 dims that match in every dimension except
    dim 1. Other pattern names are not checked.

    Args:
        features: List of feature tensors
        skip_pattern: Skip pattern type

    Returns:
        True if compatible

    Raises:
        ValueError: If incompatible
    """
    if len(features) == 0:
        raise ValueError("features list cannot be empty")

    if skip_pattern == 'residual':
        # All features must have same spatial and channel dimensions
        ref_shape = features[0].shape
        for feat in features[1:]:
            if feat.shape != ref_shape:
                raise ValueError(
                    f"All features must have same shape for residual pattern. "
                    f"Got {feat.shape} vs {ref_shape}"
                )
    elif skip_pattern == 'concatenate':
        ref = features[0]
        for feat in features:
            if feat.dim() < 2:
                raise ValueError(
                    f"concatenate pattern requires tensors with at least 2 dims, "
                    f"got shape {feat.shape}"
                )
        # torch.cat(..., dim=1) needs every dimension except dim 1 to agree.
        ref_rest = (ref.shape[0],) + tuple(ref.shape[2:])
        for feat in features[1:]:
            rest = (feat.shape[0],) + tuple(feat.shape[2:])
            if feat.dim() != ref.dim() or rest != ref_rest:
                raise ValueError(
                    f"All features must match except in dim 1 for concatenate pattern. "
                    f"Got {feat.shape} vs {ref.shape}"
                )

    return True