Source code for torchvision_customizer.utils.architecture_search

"""
Architecture Search Utilities

Neural Architecture Search (NAS) utilities for exploring and generating
model architectures programmatically:
- Grid search over architecture hyperparameters
- Random search for architecture discovery
- Architecture factory for quick model generation
- Architecture validation and scoring

Author: torchvision-customizer
License: MIT
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import random
import itertools
import torch
import torch.nn as nn
from dataclasses import dataclass, asdict
from enum import Enum


[docs] class ArchitecturePattern(Enum): """Predefined architecture patterns.""" SEQUENTIAL = 'sequential' RESIDUAL = 'residual' DENSE = 'dense' INCEPTION = 'inception' MIXED = 'mixed'
[docs] @dataclass class ArchitectureConfig: """ Architecture configuration dataclass. Attributes: input_shape: Tuple of (channels, height, width) num_classes: Number of output classes num_conv_blocks: Number of convolutional blocks channels: List of channel sizes or 'auto' kernel_sizes: List of kernel sizes or single value strides: List of strides or single value activations: List of activation names or single name dropout_rates: List of dropout rates or single value use_batchnorm: Whether to use batch normalization pattern: Architecture pattern use_attention: Whether to use attention mechanisms use_residual: Whether to use residual connections use_dense: Whether to use dense connections """ input_shape: Tuple[int, int, int] num_classes: int num_conv_blocks: int = 4 channels: Union[str, List[int]] = 'auto' kernel_sizes: Union[int, List[int]] = 3 strides: Union[int, List[int]] = 1 activations: Union[str, List[str]] = 'relu' dropout_rates: Union[float, List[float]] = 0.0 use_batchnorm: bool = True pattern: str = 'sequential' use_attention: bool = False use_residual: bool = False use_dense: bool = False
[docs] def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return asdict(self)
[docs] @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'ArchitectureConfig': """Create from dictionary.""" return cls(**data)
[docs] def validate(self) -> bool: """ Validate configuration. Returns: True if valid Raises: ValueError: If invalid """ if len(self.input_shape) != 3: raise ValueError(f"input_shape must have 3 elements, got {self.input_shape}") if self.num_classes < 1: raise ValueError(f"num_classes must be >= 1, got {self.num_classes}") if self.num_conv_blocks < 1: raise ValueError(f"num_conv_blocks must be >= 1, got {self.num_conv_blocks}") return True
[docs] class GridSearch: """ Grid search over architecture hyperparameters. Systematically explores all combinations of provided parameters. Examples: >>> search_space = { ... 'num_conv_blocks': [3, 4, 5], ... 'channels': [['auto'], [[64, 128, 256], [32, 64, 128, 256], [64, 128, 256, 512]]], ... 'activation': ['relu', 'gelu'], ... } >>> searcher = GridSearch(search_space) >>> configs = list(searcher.generate()) >>> len(configs) # 3 * 3 * 2 = 18 """ def __init__( self, search_space: Dict[str, List[Any]], base_config: Optional[Dict[str, Any]] = None, ): """ Initialize GridSearch. Args: search_space: Dictionary mapping parameter names to lists of values base_config: Base configuration to update """ self.search_space = search_space self.base_config = base_config or {} self.num_combinations = self._count_combinations() def _count_combinations(self) -> int: """Count total combinations.""" count = 1 for values in self.search_space.values(): count *= len(values) return count
[docs] def generate(self): """ Generate all configurations. Yields: ArchitectureConfig objects """ keys = self.search_space.keys() value_lists = self.search_space.values() for combination in itertools.product(*value_lists): config_dict = self.base_config.copy() config_dict.update(zip(keys, combination)) yield ArchitectureConfig(**config_dict)
[docs] def generate_with_index(self): """ Generate configurations with their index. Yields: Tuple of (index, ArchitectureConfig) """ for idx, config in enumerate(self.generate(), 1): yield idx, config
[docs] class RandomSearch: """ Random search over architecture hyperparameters. Randomly samples from provided parameter distributions. Examples: >>> search_space = { ... 'num_conv_blocks': [3, 4, 5, 6], ... 'activation': ['relu', 'gelu', 'leaky_relu'], ... 'dropout_rate': [0.0, 0.1, 0.2, 0.3], ... } >>> searcher = RandomSearch(search_space, num_samples=100) >>> configs = list(searcher.generate()) """ def __init__( self, search_space: Dict[str, List[Any]], num_samples: int = 10, base_config: Optional[Dict[str, Any]] = None, seed: Optional[int] = None, ): """ Initialize RandomSearch. Args: search_space: Dictionary mapping parameter names to lists of values num_samples: Number of random samples to generate base_config: Base configuration to update seed: Random seed for reproducibility """ self.search_space = search_space self.num_samples = num_samples self.base_config = base_config or {} if seed is not None: random.seed(seed) torch.manual_seed(seed)
[docs] def generate(self): """ Generate random configurations. Yields: ArchitectureConfig objects """ for _ in range(self.num_samples): config_dict = self.base_config.copy() for key, values in self.search_space.items(): config_dict[key] = random.choice(values) yield ArchitectureConfig(**config_dict)
[docs] def generate_with_index(self): """ Generate random configurations with their index. Yields: Tuple of (index, ArchitectureConfig) """ for idx, config in enumerate(self.generate(), 1): yield idx, config
[docs] class ArchitectureFactory: """ Factory for generating architectures from patterns and configurations. Simplifies creating models from predefined patterns. Examples: >>> factory = ArchitectureFactory() >>> config = ArchitectureConfig( ... input_shape=(3, 224, 224), ... num_classes=1000, ... pattern='residual' ... ) >>> architecture_dict = factory.create(config) """ def __init__(self): """Initialize ArchitectureFactory.""" self.patterns = {} self._register_default_patterns() def _register_default_patterns(self): """Register default architecture patterns.""" self.patterns['sequential'] = self._create_sequential self.patterns['residual'] = self._create_residual self.patterns['dense'] = self._create_dense self.patterns['inception'] = self._create_inception self.patterns['mixed'] = self._create_mixed
[docs] def register_pattern( self, name: str, builder: Callable[[ArchitectureConfig], Dict[str, Any]], ): """ Register custom architecture pattern. Args: name: Pattern name builder: Function that creates architecture from config """ self.patterns[name] = builder
[docs] def create(self, config: ArchitectureConfig) -> Dict[str, Any]: """ Create architecture from configuration. Args: config: Architecture configuration Returns: Dictionary with architecture information Raises: ValueError: If pattern is not recognized """ config.validate() pattern = config.pattern if pattern not in self.patterns: raise ValueError(f"Unknown pattern: {pattern}") builder = self.patterns[pattern] return builder(config)
def _create_sequential(self, config: ArchitectureConfig) -> Dict[str, Any]: """Create sequential architecture.""" return { 'type': 'sequential', 'num_blocks': config.num_conv_blocks, 'channels': config.channels, 'kernel_sizes': config.kernel_sizes, 'activations': config.activations, 'dropout_rates': config.dropout_rates, 'use_batchnorm': config.use_batchnorm, } def _create_residual(self, config: ArchitectureConfig) -> Dict[str, Any]: """Create residual architecture.""" return { 'type': 'residual', 'num_blocks': config.num_conv_blocks, 'channels': config.channels, 'kernel_sizes': config.kernel_sizes, 'activations': config.activations, 'use_residual': True, 'skip_pattern': 'residual', } def _create_dense(self, config: ArchitectureConfig) -> Dict[str, Any]: """Create dense architecture.""" return { 'type': 'dense', 'num_blocks': config.num_conv_blocks, 'growth_rate': 32, 'compression': 0.5, 'use_dense': True, } def _create_inception(self, config: ArchitectureConfig) -> Dict[str, Any]: """Create inception architecture.""" return { 'type': 'inception', 'num_blocks': config.num_conv_blocks, 'branch_ratios': [0.5, 0.25, 0.125, 0.125], } def _create_mixed(self, config: ArchitectureConfig) -> Dict[str, Any]: """Create mixed architecture.""" return { 'type': 'mixed', 'num_blocks': config.num_conv_blocks, 'patterns': ['sequential', 'residual', 'dense', 'sequential'], }
[docs] class ArchitectureScorer: """ Score architectures based on various metrics. Evaluates architectures by parameter count, FLOPs, memory usage, etc. Examples: >>> scorer = ArchitectureScorer() >>> config = ArchitectureConfig( ... input_shape=(3, 224, 224), ... num_classes=1000, ... num_conv_blocks=4 ... ) >>> score = scorer.score_config(config) """ def __init__( self, max_parameters: Optional[int] = None, max_memory_mb: Optional[float] = None, target_flops: Optional[float] = None, ): """ Initialize ArchitectureScorer. Args: max_parameters: Maximum allowed parameters max_memory_mb: Maximum allowed memory in MB target_flops: Target FLOPs (for optimization) """ self.max_parameters = max_parameters or 100_000_000 # 100M default self.max_memory_mb = max_memory_mb or 1000.0 # 1GB default self.target_flops = target_flops
[docs] def score_config(self, config: ArchitectureConfig) -> float: """ Score architecture configuration. Args: config: Architecture configuration Returns: Score (higher is better) """ config.validate() # Estimate parameters params = self._estimate_parameters(config) # Check constraints if params > self.max_parameters: return 0.0 # Calculate score (normalize parameters) score = 1.0 - (params / self.max_parameters) return score
def _estimate_parameters(self, config: ArchitectureConfig) -> int: """Estimate number of parameters.""" # Simplified estimation in_channels = config.input_shape[0] params = 0 if isinstance(config.channels, str): # Auto channel generation channels = [64 * (2 ** i) for i in range(config.num_conv_blocks)] else: channels = config.channels for ch in channels: kernel_size = config.kernel_sizes if isinstance(config.kernel_sizes, int) else config.kernel_sizes[0] params += in_channels * ch * kernel_size * kernel_size in_channels = ch # Classifier params += in_channels * config.num_classes return int(params)
[docs] class ArchitectureComparator: """ Compare multiple architectures. Examples: >>> configs = [ ... ArchitectureConfig(...), ... ArchitectureConfig(...), ... ] >>> comparator = ArchitectureComparator(configs) >>> best = comparator.get_best() """ def __init__( self, configs: List[ArchitectureConfig], scorer: Optional[ArchitectureScorer] = None, ): """ Initialize ArchitectureComparator. Args: configs: List of configurations to compare scorer: Architecture scorer (uses default if None) """ self.configs = configs self.scorer = scorer or ArchitectureScorer() self.scores = [self.scorer.score_config(c) for c in configs]
[docs] def get_best(self) -> Tuple[ArchitectureConfig, float]: """Get best configuration.""" best_idx = max(range(len(self.scores)), key=lambda i: self.scores[i]) return self.configs[best_idx], self.scores[best_idx]
[docs] def get_worst(self) -> Tuple[ArchitectureConfig, float]: """Get worst configuration.""" worst_idx = min(range(len(self.scores)), key=lambda i: self.scores[i]) return self.configs[worst_idx], self.scores[worst_idx]
[docs] def get_top_k(self, k: int = 3) -> List[Tuple[ArchitectureConfig, float]]: """Get top k configurations.""" sorted_indices = sorted(range(len(self.scores)), key=lambda i: self.scores[i], reverse=True) return [(self.configs[i], self.scores[i]) for i in sorted_indices[:k]]
[docs] def get_statistics(self) -> Dict[str, float]: """Get score statistics.""" return { 'mean': sum(self.scores) / len(self.scores), 'min': min(self.scores), 'max': max(self.scores), 'std': (sum((s - (sum(self.scores) / len(self.scores))) ** 2 for s in self.scores) / len(self.scores)) ** 0.5, }
# Utility functions
[docs] def expand_config_list( value: Union[Any, List[Any]], length: int, ) -> List[Any]: """ Expand single value or list to specified length. Args: value: Single value or list length: Target length Returns: List of specified length Raises: ValueError: If list length doesn't match target """ if isinstance(value, list): if len(value) != length: raise ValueError(f"List length {len(value)} doesn't match target {length}") return value else: return [value] * length
[docs] def sample_architecture( num_conv_blocks: int, min_channels: int = 32, max_channels: int = 512, min_kernel: int = 3, max_kernel: int = 7, ) -> ArchitectureConfig: """ Generate random architecture. Args: num_conv_blocks: Number of conv blocks min_channels: Minimum channels max_channels: Maximum channels min_kernel: Minimum kernel size max_kernel: Maximum kernel size Returns: Random ArchitectureConfig """ channels = [random.randint(min_channels, max_channels) for _ in range(num_conv_blocks)] kernel_sizes = [random.randint(min_kernel, max_kernel) for _ in range(num_conv_blocks)] activations = [random.choice(['relu', 'gelu', 'leaky_relu']) for _ in range(num_conv_blocks)] return ArchitectureConfig( input_shape=(3, 224, 224), num_classes=1000, num_conv_blocks=num_conv_blocks, channels=channels, kernel_sizes=kernel_sizes, activations=activations, )