Source code for torchvision_customizer.hybrid.builder

"""Hybrid Builder: Build customized models from pre-trained backbones.

The HybridBuilder allows you to:
1. Load any torchvision pre-trained model
2. Apply patches (replace, wrap, inject blocks)
3. Modify the head for your task
4. Preserve maximum weights during customization

Example:
    >>> builder = HybridBuilder()
    >>> model = builder.from_torchvision(
    ...     "resnet50",
    ...     weights="IMAGENET1K_V2",
    ...     patches={
    ...         "layer3": {"wrap": {"type": "SEBlock", "params": {"reduction": 16}}},
    ...         "layer4": {"inject": {"type": "CBAMBlock"}},
    ...     },
    ...     num_classes=100,
    ... )
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import copy
import inspect
import logging
import torch
import torch.nn as nn

from torchvision_customizer.hybrid.extractor import (
    extract_tiers,
    get_backbone_info,
    BACKBONE_CONFIGS,
)
from torchvision_customizer.hybrid.weight_utils import (
    partial_load,
    WeightLoadingReport,
)
from torchvision_customizer import registry

logger = logging.getLogger(__name__)


class PatchSpec:
    """Specification for a patch operation."""
    
    def __init__(
        self,
        operation: str,  # 'replace', 'wrap', 'inject', 'inject_after'
        block_type: str,
        params: Dict[str, Any] = None,
        position: str = 'after',  # 'before', 'after', 'replace'
    ):
        self.operation = operation
        self.block_type = block_type
        self.params = params or {}
        self.position = position
    
    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> 'PatchSpec':
        """Create PatchSpec from dictionary."""
        for op in ['replace', 'wrap', 'inject', 'inject_after']:
            if op in d:
                spec = d[op]
                if isinstance(spec, str):
                    # Simple form: {"wrap": "SEBlock"}
                    return cls(op, spec)
                elif isinstance(spec, dict):
                    # Full form: {"wrap": {"type": "SEBlock", "params": {...}}}
                    return cls(
                        op,
                        spec.get('type', spec.get('block', '')),
                        spec.get('params', {}),
                        spec.get('position', 'after'),
                    )
        
        raise ValueError(f"Invalid patch specification: {d}")


[docs] class HybridModel(nn.Module): """A hybrid model composed from a pre-trained backbone with custom modifications.""" def __init__( self, stem: nn.Module, stages: nn.ModuleList, head: nn.Module, backbone_name: str, modifications: List[str] = None, ): super().__init__() self.stem = stem self.stages = stages self.head = head self.backbone_name = backbone_name self.modifications = modifications or []
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.stem(x) for stage in self.stages: x = stage(x) x = self.head(x) return x
[docs] def forward_features(self, x: torch.Tensor) -> torch.Tensor: """Extract features before the classification head.""" x = self.stem(x) for stage in self.stages: x = stage(x) return x
[docs] def get_stage_outputs(self, x: torch.Tensor) -> List[torch.Tensor]: """Get intermediate outputs from each stage (useful for FPN, etc.).""" outputs = [] x = self.stem(x) outputs.append(x) for stage in self.stages: x = stage(x) outputs.append(x) return outputs
[docs] def freeze_backbone(self, unfreeze_stages: List[int] = None) -> 'HybridModel': """Freeze backbone, optionally leaving some stages unfrozen. Args: unfreeze_stages: List of stage indices to keep trainable Returns: Self for chaining """ # Freeze stem for param in self.stem.parameters(): param.requires_grad = False # Freeze/unfreeze stages unfreeze_set = set(unfreeze_stages or []) for i, stage in enumerate(self.stages): freeze = i not in unfreeze_set for param in stage.parameters(): param.requires_grad = not freeze return self
[docs] def unfreeze_all(self) -> 'HybridModel': """Unfreeze all parameters.""" for param in self.parameters(): param.requires_grad = True return self
[docs] def count_parameters(self, trainable_only: bool = False) -> int: """Count model parameters.""" if trainable_only: return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters())
[docs] def explain(self) -> str: """Generate human-readable model description.""" lines = [ "+" + "=" * 60 + "+", "|" + f" HybridModel ({self.backbone_name}) ".center(60) + "|", "+" + "=" * 60 + "+", ] # Stem stem_params = sum(p.numel() for p in self.stem.parameters()) lines.append(f"| Stem: {stem_params:,} params".ljust(61) + "|") # Stages lines.append("|" + " Stages: ".ljust(60) + "|") for i, stage in enumerate(self.stages): stage_params = sum(p.numel() for p in stage.parameters()) lines.append(f"| Stage {i+1}: {stage_params:,} params".ljust(61) + "|") # Head head_params = sum(p.numel() for p in self.head.parameters()) lines.append(f"| Head: {head_params:,} params".ljust(61) + "|") lines.append("+" + "-" * 60 + "+") # Modifications if self.modifications: lines.append("| Modifications:".ljust(61) + "|") for mod in self.modifications[:5]: lines.append(f"| - {mod}".ljust(61) + "|") if len(self.modifications) > 5: lines.append(f"| ... and {len(self.modifications) - 5} more".ljust(61) + "|") lines.append("+" + "-" * 60 + "+") # Totals total = self.count_parameters() trainable = self.count_parameters(trainable_only=True) lines.append(f"| Total: {total:,} | Trainable: {trainable:,}".ljust(61) + "|") lines.append("+" + "=" * 60 + "+") return "\n".join(lines)
def __repr__(self) -> str: return (f"HybridModel(backbone='{self.backbone_name}', " f"stages={len(self.stages)}, " f"mods={len(self.modifications)})")
[docs] class HybridBuilder: """Builder for creating hybrid models from pre-trained backbones. The HybridBuilder provides a fluent API for: - Loading pre-trained torchvision models - Extracting and customizing model tiers - Applying patches (attention, custom blocks) - Replacing the classification head Example: >>> builder = HybridBuilder() >>> >>> # Quick build >>> model = builder.from_torchvision( ... "resnet50", ... weights="IMAGENET1K_V2", ... num_classes=100 ... ) >>> >>> # With patches >>> model = builder.from_torchvision( ... "efficientnet_b4", ... weights="IMAGENET1K_V1", ... patches={ ... "features.5": {"wrap": "SEBlock"}, ... }, ... num_classes=10, ... dropout=0.3, ... ) """ # Supported backbone names SUPPORTED_BACKBONES = list(BACKBONE_CONFIGS.keys()) def __init__(self): """Initialize HybridBuilder.""" self._base_model: Optional[nn.Module] = None self._backbone_name: Optional[str] = None self._tiers: Optional[Dict] = None self._modifications: List[str] = [] @property def supported_backbones(self) -> List[str]: """List of supported backbone names.""" return self.SUPPORTED_BACKBONES
[docs] def from_torchvision( self, backbone_name: str, weights: Optional[str] = "DEFAULT", patches: Optional[Dict[str, Dict]] = None, num_classes: int = 1000, dropout: float = 0.0, freeze_backbone: bool = False, unfreeze_stages: Optional[List[int]] = None, verbose: bool = True, ) -> HybridModel: """Create a hybrid model from a torchvision backbone. Args: backbone_name: Name of torchvision model (e.g., 'resnet50') weights: Weights to load ('DEFAULT', 'IMAGENET1K_V1', 'IMAGENET1K_V2', None) patches: Dictionary of patches to apply num_classes: Number of output classes dropout: Dropout rate for classifier freeze_backbone: Whether to freeze backbone weights unfreeze_stages: If freezing, which stages to keep trainable verbose: Print loading information Returns: HybridModel instance Example: >>> model = builder.from_torchvision( ... "resnet50", ... weights="IMAGENET1K_V2", ... patches={"layer3": {"wrap": "SEBlock"}}, ... num_classes=100, ... ) """ # Import torchvision try: from torchvision import models except ImportError: raise ImportError("torchvision is required for hybrid models") # Validate backbone name if not hasattr(models, backbone_name): raise ValueError( f"Unknown backbone '{backbone_name}'. " f"Supported: {', '.join(self.SUPPORTED_BACKBONES[:10])}..." ) # Load base model model_fn = getattr(models, backbone_name) if weights: # Try to get weights enum weights_enum = self._resolve_weights(backbone_name, weights) self._base_model = model_fn(weights=weights_enum) if verbose: print(f"Loaded {backbone_name} with {weights} weights") else: self._base_model = model_fn(weights=None) if verbose: print(f"Loaded {backbone_name} (random initialization)") self._backbone_name = backbone_name # Extract tiers self._tiers = extract_tiers(self._base_model, backbone_name) # Apply patches if patches: self._apply_patches(patches) # Build hybrid model model = self._build_hybrid_model(num_classes, dropout) # Freeze if requested if freeze_backbone: model.freeze_backbone(unfreeze_stages) return model
def _resolve_weights(self, backbone_name: str, weights: str) -> Any: """Resolve weight string to torchvision weights enum.""" from torchvision import models if weights == "DEFAULT": return "DEFAULT" # Try to find weights enum weights_name = f"{backbone_name.title().replace('_', '')}Weights" # Handle special cases weights_map = { "resnet18": "ResNet18_Weights", "resnet34": "ResNet34_Weights", "resnet50": "ResNet50_Weights", "resnet101": "ResNet101_Weights", "resnet152": "ResNet152_Weights", "efficientnet_b0": "EfficientNet_B0_Weights", "efficientnet_b1": "EfficientNet_B1_Weights", "efficientnet_b2": "EfficientNet_B2_Weights", "efficientnet_b3": "EfficientNet_B3_Weights", "efficientnet_b4": "EfficientNet_B4_Weights", "efficientnet_b5": "EfficientNet_B5_Weights", "efficientnet_b6": "EfficientNet_B6_Weights", "efficientnet_b7": "EfficientNet_B7_Weights", "vgg16": "VGG16_Weights", "vgg19": "VGG19_Weights", "mobilenet_v2": "MobileNet_V2_Weights", "mobilenet_v3_small": "MobileNet_V3_Small_Weights", "mobilenet_v3_large": "MobileNet_V3_Large_Weights", "convnext_tiny": "ConvNeXt_Tiny_Weights", "convnext_small": "ConvNeXt_Small_Weights", "convnext_base": "ConvNeXt_Base_Weights", "convnext_large": "ConvNeXt_Large_Weights", } if backbone_name in weights_map: weights_enum_name = weights_map[backbone_name] if hasattr(models, weights_enum_name): weights_enum = getattr(models, weights_enum_name) if hasattr(weights_enum, weights): return getattr(weights_enum, weights) # Fallback to string return weights def _apply_patches(self, patches: Dict[str, Dict]) -> None: """Apply patches to the extracted tiers.""" for target, patch_dict in patches.items(): patch_spec = PatchSpec.from_dict(patch_dict) self._apply_single_patch(target, patch_spec) def _apply_single_patch(self, target: str, patch: PatchSpec) -> None: """Apply a single patch to a target layer.""" # Find the target module target_module, parent, attr_name = self._find_module(target) if target_module is None: logger.warning(f"Target '{target}' not found, skipping patch") return # Get the block class from registry try: BlockClass = registry.get(patch.block_type) except ValueError: logger.warning(f"Block type '{patch.block_type}' not in registry, skipping") return # Apply based on operation if patch.operation == 'replace': # Replace the module entirely new_module = self._create_replacement(target_module, BlockClass, patch.params) setattr(parent, attr_name, new_module) self._modifications.append(f"Replaced {target} with {patch.block_type}") elif patch.operation == 'wrap': # Wrap the module with attention/block new_module = self._wrap_module(target_module, BlockClass, patch.params) setattr(parent, attr_name, new_module) self._modifications.append(f"Wrapped {target} with {patch.block_type}") elif patch.operation in ['inject', 'inject_after']: # Inject block after the target new_module = self._inject_after(target_module, BlockClass, patch.params) setattr(parent, attr_name, new_module) self._modifications.append(f"Injected {patch.block_type} after {target}") def _find_module(self, target: str) -> Tuple[Optional[nn.Module], Optional[nn.Module], str]: """Find a module by dot-separated path.""" parts = target.split('.') current = self._base_model parent = None attr_name = parts[-1] for i, part in enumerate(parts): if hasattr(current, part): parent = current current = getattr(current, part) attr_name = part elif part.isdigit() and isinstance(current, (nn.Sequential, nn.ModuleList)): idx = int(part) if idx < len(current): parent = current current = current[idx] attr_name = str(idx) else: return None, None, "" else: return None, None, "" return current, parent, attr_name def _create_replacement( self, original: nn.Module, BlockClass: Type[nn.Module], params: Dict[str, Any], ) -> nn.Module: """Create a replacement module with inferred channels.""" # Try to infer channels from original in_channels = self._get_channels(original, 'in') out_channels = self._get_channels(original, 'out') # Detect which parameters the block accepts sig = inspect.signature(BlockClass.__init__) param_names = set(sig.parameters.keys()) # Build params - only pass what the block accepts init_params = dict(params) if 'in_channels' in param_names and 'in_channels' not in init_params and in_channels: init_params['in_channels'] = in_channels if 'out_channels' in param_names and 'out_channels' not in init_params and out_channels: init_params['out_channels'] = out_channels if 'channels' in param_names and 'channels' not in init_params and out_channels: init_params['channels'] = out_channels if 'dim' in param_names and 'dim' not in init_params and out_channels: init_params['dim'] = out_channels # Filter out any params the block doesn't accept init_params = {k: v for k, v in init_params.items() if k in param_names or 'kwargs' in str(sig)} return BlockClass(**init_params) def _wrap_module( self, original: nn.Module, BlockClass: Type[nn.Module], params: Dict[str, Any], ) -> nn.Module: """Wrap a module with an attention/block layer.""" out_channels = self._get_channels(original, 'out') # Build params for wrapper - only pass params the block accepts init_params = dict(params) # Detect which channel parameter the block accepts sig = inspect.signature(BlockClass.__init__) param_names = set(sig.parameters.keys()) if out_channels: # Only pass the parameter that the block actually accepts if 'channels' in param_names and 'channels' not in init_params: init_params['channels'] = out_channels if 'in_channels' in param_names and 'in_channels' not in init_params: init_params['in_channels'] = out_channels if 'out_channels' in param_names and 'out_channels' not in init_params: init_params['out_channels'] = out_channels if 'dim' in param_names and 'dim' not in init_params: init_params['dim'] = out_channels # Filter out any params the block doesn't accept init_params = {k: v for k, v in init_params.items() if k in param_names or 'kwargs' in str(sig)} wrapper = BlockClass(**init_params) # Return sequential: original -> wrapper return nn.Sequential(original, wrapper) def _inject_after( self, original: nn.Module, BlockClass: Type[nn.Module], params: Dict[str, Any], ) -> nn.Module: """Inject a block after the original module.""" # Same as wrap for now return self._wrap_module(original, BlockClass, params) def _get_channels(self, module: nn.Module, which: str = 'out') -> Optional[int]: """Get input or output channels from a module.""" if which == 'out': if hasattr(module, 'out_channels'): return module.out_channels # Check last child children = list(module.children()) for child in reversed(children): ch = self._get_channels(child, 'out') if ch: return ch if isinstance(module, nn.Conv2d): return module.out_channels if isinstance(module, nn.BatchNorm2d): return module.num_features else: # 'in' if hasattr(module, 'in_channels'): return module.in_channels children = list(module.children()) for child in children: ch = self._get_channels(child, 'in') if ch: return ch if isinstance(module, nn.Conv2d): return module.in_channels return None def _build_hybrid_model(self, num_classes: int, dropout: float) -> HybridModel: """Build the final HybridModel.""" # Get tiers stem = self._tiers['stem'] stages = self._tiers['stages'] original_head = self._tiers['head'] # Determine feature dimension for new head feature_dim = self._infer_feature_dim() # Build new head head = self._build_head(feature_dim, num_classes, dropout, original_head) return HybridModel( stem=stem, stages=nn.ModuleList(stages), head=head, backbone_name=self._backbone_name, modifications=self._modifications.copy(), ) def _infer_feature_dim(self) -> int: """Infer feature dimension from the model.""" # Try to get from last stage stages = self._tiers['stages'] if stages: last_stage = stages[-1] ch = self._get_channels(last_stage, 'out') if ch: return ch # Try from original head original_head = self._tiers['head'] for module in original_head.modules(): if isinstance(module, nn.Linear): return module.in_features # Default fallback return 2048 def _build_head( self, feature_dim: int, num_classes: int, dropout: float, original_head: nn.Module, ) -> nn.Module: """Build classification head.""" # Check if original head has pooling has_pool = any( isinstance(m, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d, nn.AvgPool2d)) for m in original_head.modules() ) layers = [] if has_pool: layers.append(nn.AdaptiveAvgPool2d((1, 1))) layers.append(nn.Flatten(1)) if dropout > 0: layers.append(nn.Dropout(p=dropout)) layers.append(nn.Linear(feature_dim, num_classes)) return nn.Sequential(*layers)
[docs] def from_checkpoint( self, checkpoint_path: str, backbone_name: str, patches: Optional[Dict[str, Dict]] = None, num_classes: int = 1000, dropout: float = 0.0, strict: bool = False, ) -> HybridModel: """Create hybrid model from a saved checkpoint. Args: checkpoint_path: Path to checkpoint file backbone_name: Backbone architecture name patches: Patches to apply num_classes: Number of output classes dropout: Dropout rate strict: Whether to strictly match checkpoint keys Returns: HybridModel loaded from checkpoint """ # First create model structure model = self.from_torchvision( backbone_name, weights=None, patches=patches, num_classes=num_classes, dropout=dropout, verbose=False, ) # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location='cpu') # Handle different checkpoint formats if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif 'model' in checkpoint: state_dict = checkpoint['model'] else: state_dict = checkpoint # Clean up state dict keys (remove 'module.' prefix if present) cleaned_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): k = k[7:] cleaned_state_dict[k] = v # Load with partial matching partial_load(model, cleaned_state_dict, ignore_mismatch=not strict) return model
[docs] @classmethod def list_backbones(cls) -> List[str]: """List all supported backbone architectures.""" return cls.SUPPORTED_BACKBONES
# Convenience function def create_hybrid( backbone: str, num_classes: int = 1000, weights: str = "DEFAULT", patches: Dict[str, Dict] = None, **kwargs, ) -> HybridModel: """Convenience function to create a hybrid model. Args: backbone: Backbone name (e.g., 'resnet50') num_classes: Number of output classes weights: Pretrained weights to load patches: Patches to apply **kwargs: Additional arguments for HybridBuilder Returns: HybridModel instance Example: >>> model = create_hybrid( ... "resnet50", ... num_classes=100, ... patches={"layer3": {"wrap": "se"}} ... ) """ builder = HybridBuilder() return builder.from_torchvision( backbone, weights=weights, patches=patches, num_classes=num_classes, **kwargs, )