Source code for torchvision_customizer.templates.base

"""Base Template class for architecture customization.

Provides the foundation for creating customizable model templates
that can be modified before building.
"""

from typing import Any, Callable, Dict, List, Optional, Type, Union
from copy import deepcopy
import torch.nn as nn


[docs] class Template: """Base class for architecture templates with customization. Templates allow you to start from a known architecture and customize it before building. Example: >>> template = Template.resnet(layers=50) >>> template.replace_activation('gelu') >>> template.add_attention('se', after='conv') >>> model = template.build(num_classes=100) """ # Registry of available templates _templates: Dict[str, Type['Template']] = {} def __init__(self, config: Dict[str, Any]): """Initialize template with configuration. Args: config: Architecture configuration dictionary """ self._config = deepcopy(config) self._modifications = []
[docs] @classmethod def register(cls, name: str): """Decorator to register a template class.""" def decorator(template_cls: Type['Template']) -> Type['Template']: cls._templates[name] = template_cls return template_cls return decorator
[docs] @classmethod def from_name(cls, name: str, **kwargs) -> 'Template': """Create a template by name. Args: name: Template name (e.g., 'resnet', 'vgg') **kwargs: Template-specific arguments Returns: Template instance """ if name not in cls._templates: available = ', '.join(sorted(cls._templates.keys())) raise ValueError(f"Unknown template '{name}'. Available: {available}") return cls._templates[name](**kwargs)
[docs] @classmethod def resnet(cls, layers: int = 50, **kwargs) -> 'Template': """Create a ResNet template.""" from torchvision_customizer.templates.resnet import ResNetTemplate return ResNetTemplate(layers=layers, **kwargs)
[docs] @classmethod def vgg(cls, layers: int = 16, **kwargs) -> 'Template': """Create a VGG template.""" from torchvision_customizer.templates.vgg import VGGTemplate return VGGTemplate(layers=layers, **kwargs)
[docs] @classmethod def mobilenet(cls, version: int = 2, **kwargs) -> 'Template': """Create a MobileNet template.""" from torchvision_customizer.templates.mobilenet import MobileNetTemplate return MobileNetTemplate(version=version, **kwargs)
[docs] @classmethod def densenet(cls, layers: int = 121, **kwargs) -> 'Template': """Create a DenseNet template.""" from torchvision_customizer.templates.densenet import DenseNetTemplate return DenseNetTemplate(layers=layers, **kwargs)
[docs] @classmethod def list_templates(cls) -> List[str]: """List available template names.""" return sorted(cls._templates.keys())
# Customization methods
[docs] def replace_activation(self, activation: str) -> 'Template': """Replace all activation functions. Args: activation: New activation function name Returns: Self for chaining """ self._config['activation'] = activation self._modifications.append(('activation', activation)) return self
[docs] def replace_norm(self, norm_type: str) -> 'Template': """Replace all normalization layers. Args: norm_type: New normalization type ('batch', 'layer', 'group', 'instance') Returns: Self for chaining """ self._config['norm_type'] = norm_type self._modifications.append(('norm', norm_type)) return self
[docs] def add_attention(self, attention_type: str = 'se', after: str = 'block') -> 'Template': """Add attention mechanism. Args: attention_type: Type of attention ('se', 'cbam', 'channel', 'spatial') after: Where to add ('block', 'conv', 'stage') Returns: Self for chaining """ if 'attention' not in self._config: self._config['attention'] = {} self._config['attention']['type'] = attention_type self._config['attention']['position'] = after self._modifications.append(('attention', f'{attention_type} after {after}')) return self
[docs] def modify_stage(self, stage_idx: int, **kwargs) -> 'Template': """Modify a specific stage. Args: stage_idx: Stage index (0-based) **kwargs: Stage modifications (blocks, channels, etc.) Returns: Self for chaining """ if 'stage_modifications' not in self._config: self._config['stage_modifications'] = {} self._config['stage_modifications'][stage_idx] = kwargs self._modifications.append(('stage', f'stage {stage_idx}: {kwargs}')) return self
[docs] def set_stem(self, **kwargs) -> 'Template': """Configure the stem. Args: **kwargs: Stem configuration (channels, kernel, stride, etc.) Returns: Self for chaining """ if 'stem' not in self._config: self._config['stem'] = {} self._config['stem'].update(kwargs) self._modifications.append(('stem', str(kwargs))) return self
[docs] def set_head(self, **kwargs) -> 'Template': """Configure the head. Args: **kwargs: Head configuration (hidden, dropout, etc.) Returns: Self for chaining """ if 'head' not in self._config: self._config['head'] = {} self._config['head'].update(kwargs) return self
[docs] def use_dropout(self, rate: float) -> 'Template': """Set dropout rate. Args: rate: Dropout probability Returns: Self for chaining """ self._config['dropout'] = rate return self
[docs] def scale_channels(self, factor: float) -> 'Template': """Scale all channel counts by a factor. Args: factor: Scaling factor (e.g., 0.5 for half, 2.0 for double) Returns: Self for chaining """ self._config['channel_scale'] = factor self._modifications.append(('scale', f'{factor}x channels')) return self
[docs] def get_config(self) -> Dict[str, Any]: """Get the current configuration.""" return deepcopy(self._config)
[docs] def describe(self) -> str: """Get a human-readable description of the template.""" lines = [f"Template: {self._config.get('name', 'Custom')}"] lines.append("-" * 40) for key, value in self._config.items(): if key not in ['name', 'stage_modifications']: lines.append(f" {key}: {value}") if self._modifications: lines.append("\nModifications:") for mod_type, mod_value in self._modifications: lines.append(f" • {mod_type}: {mod_value}") return "\n".join(lines)
[docs] def build(self, num_classes: int = 1000, **kwargs) -> nn.Module: """Build the model from template. Args: num_classes: Number of output classes **kwargs: Additional build arguments Returns: Built nn.Module """ # Override in subclasses raise NotImplementedError("Subclasses must implement build()")
def __repr__(self) -> str: name = self._config.get('name', 'Template') mods = len(self._modifications) return f"{name}(modifications={mods})"