"""Backbone Extractor: Extract and analyze torchvision model structures.
Provides utilities to decompose pre-trained models into their
constituent parts (stem, stages, head) for customization.
Example:
>>> from torchvision.models import resnet50
>>> from torchvision_customizer.hybrid import extract_tiers
>>>
>>> model = resnet50()
>>> tiers = extract_tiers(model)
>>> print(tiers.keys()) # ['stem', 'stages', 'head']
"""
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
# Backbone configurations for supported architectures
BACKBONE_CONFIGS = {
# ResNet family
"resnet18": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnet34": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnet50": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnet101": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnet152": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
# Wide ResNet
"wide_resnet50_2": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"wide_resnet101_2": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
# ResNeXt
"resnext50_32x4d": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnext101_32x8d": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
"resnext101_64x4d": {"stages": ["layer1", "layer2", "layer3", "layer4"], "stem": ["conv1", "bn1", "relu", "maxpool"], "head": ["avgpool", "fc"]},
# EfficientNet family
"efficientnet_b0": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b1": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b2": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b3": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b4": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b5": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b6": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_b7": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_v2_s": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_v2_m": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"efficientnet_v2_l": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
# VGG family
"vgg11": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg13": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg16": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg19": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg11_bn": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg13_bn": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg16_bn": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"vgg19_bn": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
# DenseNet family
"densenet121": {"stages": ["features"], "stem": [], "head": ["classifier"]},
"densenet169": {"stages": ["features"], "stem": [], "head": ["classifier"]},
"densenet201": {"stages": ["features"], "stem": [], "head": ["classifier"]},
"densenet161": {"stages": ["features"], "stem": [], "head": ["classifier"]},
# MobileNet family
"mobilenet_v2": {"stages": ["features"], "stem": [], "head": ["classifier"]},
"mobilenet_v3_small": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"mobilenet_v3_large": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
# ConvNeXt family
"convnext_tiny": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"convnext_small": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"convnext_base": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
"convnext_large": {"stages": ["features"], "stem": [], "head": ["avgpool", "classifier"]},
# Vision Transformer (partial support)
"vit_b_16": {"stages": ["encoder"], "stem": ["conv_proj", "class_token", "encoder.pos_embedding"], "head": ["heads"]},
"vit_b_32": {"stages": ["encoder"], "stem": ["conv_proj", "class_token", "encoder.pos_embedding"], "head": ["heads"]},
"vit_l_16": {"stages": ["encoder"], "stem": ["conv_proj", "class_token", "encoder.pos_embedding"], "head": ["heads"]},
"vit_l_32": {"stages": ["encoder"], "stem": ["conv_proj", "class_token", "encoder.pos_embedding"], "head": ["heads"]},
# Swin Transformer
"swin_t": {"stages": ["features"], "stem": [], "head": ["avgpool", "head"]},
"swin_s": {"stages": ["features"], "stem": [], "head": ["avgpool", "head"]},
"swin_b": {"stages": ["features"], "stem": [], "head": ["avgpool", "head"]},
}
class BackboneInfo:
"""Information about a backbone architecture."""
def __init__(
self,
name: str,
model: nn.Module,
stem_parts: List[str],
stage_parts: List[str],
head_parts: List[str],
):
self.name = name
self.model = model
self.stem_parts = stem_parts
self.stage_parts = stage_parts
self.head_parts = head_parts
# Analyze model structure
self._analyze_model()
def _analyze_model(self) -> None:
"""Analyze the model structure."""
self.total_params = sum(p.numel() for p in self.model.parameters())
self.trainable_params = sum(
p.numel() for p in self.model.parameters() if p.requires_grad
)
# Get output channels from stages
self.stage_channels = []
for stage_name in self.stage_parts:
if hasattr(self.model, stage_name):
stage = getattr(self.model, stage_name)
self.stage_channels.append(_get_out_channels(stage))
def __repr__(self) -> str:
return (f"BackboneInfo(name='{self.name}', "
f"params={self.total_params:,}, "
f"stages={len(self.stage_parts)})")
def summary(self) -> str:
"""Generate human-readable summary."""
lines = [
f"Backbone: {self.name}",
"-" * 40,
f"Stem: {', '.join(self.stem_parts) or 'N/A'}",
f"Stages: {', '.join(self.stage_parts)}",
f"Head: {', '.join(self.head_parts)}",
"-" * 40,
f"Parameters: {self.total_params:,}",
f"Stage Channels: {self.stage_channels}",
]
return "\n".join(lines)
def _get_out_channels(module: nn.Module) -> Optional[int]:
"""Try to determine output channels of a module."""
# Check direct attribute
if hasattr(module, 'out_channels'):
return module.out_channels
# Check last child
children = list(module.children())
if children:
last = children[-1]
if hasattr(last, 'out_channels'):
return last.out_channels
# Recurse
return _get_out_channels(last)
# Check for Conv2d
if isinstance(module, nn.Conv2d):
return module.out_channels
return None
[docs]
def get_backbone_info(
model: nn.Module,
backbone_name: Optional[str] = None,
) -> BackboneInfo:
"""Get structural information about a backbone model.
Args:
model: The backbone model
backbone_name: Optional name hint for lookup
Returns:
BackboneInfo with structural details
"""
# Try to determine backbone type
if backbone_name is None:
backbone_name = _infer_backbone_name(model)
if backbone_name and backbone_name in BACKBONE_CONFIGS:
config = BACKBONE_CONFIGS[backbone_name]
return BackboneInfo(
name=backbone_name,
model=model,
stem_parts=config.get("stem", []),
stage_parts=config.get("stages", []),
head_parts=config.get("head", []),
)
# Fallback: auto-detect structure
return _auto_detect_structure(model, backbone_name or "unknown")
def _infer_backbone_name(model: nn.Module) -> Optional[str]:
"""Try to infer backbone name from model class."""
class_name = type(model).__name__.lower()
# Check common patterns
for name in BACKBONE_CONFIGS:
if name.replace("_", "") in class_name.replace("_", ""):
return name
return None
def _auto_detect_structure(model: nn.Module, name: str) -> BackboneInfo:
"""Auto-detect model structure when not in config."""
stem_parts = []
stage_parts = []
head_parts = []
for child_name, child in model.named_children():
child_name_lower = child_name.lower()
# Classify by name
if any(s in child_name_lower for s in ['stem', 'conv1', 'bn1']):
stem_parts.append(child_name)
elif any(s in child_name_lower for s in ['layer', 'stage', 'features', 'blocks', 'encoder']):
stage_parts.append(child_name)
elif any(s in child_name_lower for s in ['fc', 'classifier', 'head', 'avgpool', 'pool']):
head_parts.append(child_name)
else:
# Check by position and type
if isinstance(child, (nn.Linear, nn.AdaptiveAvgPool2d)):
head_parts.append(child_name)
else:
stage_parts.append(child_name)
return BackboneInfo(
name=name,
model=model,
stem_parts=stem_parts,
stage_parts=stage_parts,
head_parts=head_parts,
)
def get_stage_channels(model: nn.Module, backbone_name: Optional[str] = None) -> List[int]:
"""Get output channels for each stage.
Args:
model: The backbone model
backbone_name: Optional name hint
Returns:
List of output channels per stage
"""
info = get_backbone_info(model, backbone_name)
return info.stage_channels
def freeze_backbone(
model: nn.Module,
freeze_stem: bool = True,
freeze_stages: Union[bool, List[int]] = True,
freeze_head: bool = False,
backbone_name: Optional[str] = None,
) -> nn.Module:
"""Freeze parts of the backbone for fine-tuning.
Args:
model: The model to freeze
freeze_stem: Whether to freeze stem
freeze_stages: True to freeze all, or list of stage indices
freeze_head: Whether to freeze head
backbone_name: Optional name hint
Returns:
The model with frozen parameters
Example:
>>> model = resnet50(weights='IMAGENET1K_V2')
>>> # Freeze all but last stage and head
>>> freeze_backbone(model, freeze_stages=[0, 1, 2])
"""
tiers = extract_tiers(model, backbone_name)
# Freeze stem
if freeze_stem:
for param in tiers['stem'].parameters():
param.requires_grad = False
# Freeze stages
if freeze_stages is True:
for stage in tiers['stages']:
for param in stage.parameters():
param.requires_grad = False
elif isinstance(freeze_stages, list):
for i, stage in enumerate(tiers['stages']):
if i in freeze_stages:
for param in stage.parameters():
param.requires_grad = False
# Freeze head
if freeze_head:
for param in tiers['head'].parameters():
param.requires_grad = False
return model
def unfreeze_model(model: nn.Module) -> nn.Module:
"""Unfreeze all parameters in the model.
Args:
model: The model to unfreeze
Returns:
The model with all parameters unfrozen
"""
for param in model.parameters():
param.requires_grad = True
return model