Source code for torchvision_customizer.templates.densenet

"""DenseNet Template: Parametric DenseNet implementation.

Implements DenseNet variants based on the layers parameter:
- layers=121: DenseNet-121 (6, 12, 24, 16 blocks)
- layers=169: DenseNet-169 (6, 12, 32, 32 blocks)
- layers=201: DenseNet-201 (6, 12, 48, 32 blocks)
- layers=264: DenseNet-264 (6, 12, 64, 48 blocks)

All implementations are from scratch using the package's building blocks.
"""

from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision_customizer.templates.base import Template
from torchvision_customizer.layers import get_activation, get_normalization


# DenseNet configurations: layers -> block_counts
DENSENET_CONFIGS = {
    121: [6, 12, 24, 16],
    169: [6, 12, 32, 32],
    201: [6, 12, 48, 32],
    264: [6, 12, 64, 48],
}


class DenseLayer(nn.Module):
    """Single dense layer (BN -> ReLU -> Conv)."""
    
    def __init__(
        self,
        in_channels: int,
        growth_rate: int,
        bn_size: int = 4,
        activation: str = 'relu',
        dropout: float = 0.0,
    ):
        super().__init__()
        
        mid_channels = bn_size * growth_rate
        
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.act1 = get_activation(activation)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        
        self.norm2 = nn.BatchNorm2d(mid_channels)
        self.act2 = get_activation(activation)
        self.conv2 = nn.Conv2d(mid_channels, growth_rate, 3, padding=1, bias=False)
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv1(self.act1(self.norm1(x)))
        out = self.conv2(self.act2(self.norm2(out)))
        if self.dropout is not None:
            out = self.dropout(out)
        return torch.cat([x, out], dim=1)


class DenseBlock(nn.Module):
    """Dense block containing multiple dense layers."""
    
    def __init__(
        self,
        num_layers: int,
        in_channels: int,
        growth_rate: int,
        bn_size: int = 4,
        activation: str = 'relu',
        dropout: float = 0.0,
    ):
        super().__init__()
        
        layers = []
        for i in range(num_layers):
            layers.append(DenseLayer(
                in_channels + i * growth_rate,
                growth_rate,
                bn_size,
                activation,
                dropout,
            ))
        
        self.layers = nn.Sequential(*layers)
        self.out_channels = in_channels + num_layers * growth_rate
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class Transition(nn.Module):
    """Transition layer between dense blocks."""
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: str = 'relu',
    ):
        super().__init__()
        
        self.norm = nn.BatchNorm2d(in_channels)
        self.act = get_activation(activation)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.pool = nn.AvgPool2d(2)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(self.act(self.norm(x)))
        x = self.pool(x)
        return x


class DenseNetModel(nn.Module):
    """Complete DenseNet model built from blocks."""
    
    def __init__(
        self,
        layers: int,
        num_classes: int = 1000,
        growth_rate: int = 32,
        bn_size: int = 4,
        compression: float = 0.5,
        activation: str = 'relu',
        dropout: float = 0.0,
        num_init_features: int = 64,
    ):
        super().__init__()
        
        if layers not in DENSENET_CONFIGS:
            raise ValueError(f"Unsupported layers={layers}. Choose from {list(DENSENET_CONFIGS.keys())}")
        
        block_counts = DENSENET_CONFIGS[layers]
        self.layers_config = layers
        self.growth_rate = growth_rate
        
        # Initial convolution
        self.features = nn.Sequential()
        self.features.add_module('conv0', nn.Conv2d(3, num_init_features, 7, stride=2, padding=3, bias=False))
        self.features.add_module('norm0', nn.BatchNorm2d(num_init_features))
        self.features.add_module('relu0', get_activation(activation))
        self.features.add_module('pool0', nn.MaxPool2d(3, stride=2, padding=1))
        
        # Dense blocks + transitions
        num_features = num_init_features
        for i, num_layers in enumerate(block_counts):
            block = DenseBlock(
                num_layers=num_layers,
                in_channels=num_features,
                growth_rate=growth_rate,
                bn_size=bn_size,
                activation=activation,
                dropout=dropout,
            )
            self.features.add_module(f'denseblock{i+1}', block)
            num_features = block.out_channels
            
            # Add transition (except after last block)
            if i != len(block_counts) - 1:
                out_features = int(num_features * compression)
                trans = Transition(num_features, out_features, activation)
                self.features.add_module(f'transition{i+1}', trans)
                num_features = out_features
        
        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))
        self.features.add_module('relu_final', get_activation(activation))
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)
        self.num_features = num_features
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x
    
    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters())
    
    def explain(self) -> str:
        block_counts = DENSENET_CONFIGS[self.layers_config]
        
        lines = [
            "+" + "-" * 60 + "+",
            "|" + f" DenseNet-{self.layers_config} ".center(60, "-") + "|",
            "+" + "-" * 60 + "+",
            f"| Stem: Conv(7x7, s2) -> BN -> ReLU -> MaxPool(3x3, s2)".ljust(61) + "|",
            "|" + " " * 60 + "|",
        ]
        
        for i, count in enumerate(block_counts):
            lines.append(f"| DenseBlock {i+1}: {count} layers (growth={self.growth_rate})".ljust(61) + "|")
            if i < len(block_counts) - 1:
                lines.append(f"|   -> Transition (compress + avgpool)".ljust(61) + "|")
        
        lines.extend([
            "|" + " " * 60 + "|",
            f"| Head: BN -> ReLU -> AdaptiveAvgPool -> Linear".ljust(61) + "|",
            "+" + "-" * 60 + "+",
            f"| Parameters: {self.count_parameters():,}".ljust(61) + "|",
            "+" + "-" * 60 + "+",
        ])
        
        return "\n".join(lines)


[docs] @Template.register('densenet') class DenseNetTemplate(Template): """DenseNet architecture template.""" def __init__(self, layers: int = 121, **kwargs): if layers not in DENSENET_CONFIGS: raise ValueError(f"layers must be one of {list(DENSENET_CONFIGS.keys())}, got {layers}") config = { 'name': f'DenseNet-{layers}', 'layers': layers, 'growth_rate': kwargs.get('growth_rate', 32), 'bn_size': kwargs.get('bn_size', 4), 'compression': kwargs.get('compression', 0.5), 'activation': kwargs.get('activation', 'relu'), 'dropout': kwargs.get('dropout', 0.0), } super().__init__(config)
[docs] def build(self, num_classes: int = 1000, **kwargs) -> nn.Module: config = self.get_config() return DenseNetModel( layers=config['layers'], num_classes=num_classes, growth_rate=config.get('growth_rate', 32), bn_size=config.get('bn_size', 4), compression=config.get('compression', 0.5), activation=config.get('activation', 'relu'), dropout=config.get('dropout', 0.0), **kwargs, )
[docs] def densenet( layers: int = 121, num_classes: int = 1000, **kwargs, ) -> nn.Module: """Create a DenseNet model. Args: layers: Number of layers (121, 169, 201, 264) num_classes: Number of output classes **kwargs: Additional configuration (growth_rate, compression, etc.) Returns: DenseNet model Example: >>> model = densenet(layers=121, num_classes=10) >>> model = densenet(layers=201, growth_rate=48) """ return DenseNetModel(layers=layers, num_classes=num_classes, **kwargs)