# Source code for torchvision_customizer.utils.model_summary

"""Model summary and introspection utilities.

Provides comprehensive tools for analyzing PyTorch models including:
- Layer-by-layer output shape tracking
- Parameter distribution analysis
- Memory usage estimation
- Model structure visualization

Example:
    >>> from torchvision_customizer.utils import print_model_summary
    >>> from torchvision_customizer import CustomCNN
    >>> model = CustomCNN(input_shape=(3, 224, 224), num_classes=1000)
    >>> print_model_summary(model, input_shape=(1, 3, 224, 224))
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn


def calculate_output_shape(
    model: nn.Module,
    input_shape: Tuple[int, ...],
    device: str = "cpu",
) -> Dict[str, Any]:
    """Calculate the model's output shape by running a dummy forward pass.

    The model is moved to ``device`` and run in eval mode under
    ``torch.no_grad()`` on a random tensor of shape ``input_shape``. The
    train/eval state every sub-module had on entry is restored before
    returning, so probing a model in the middle of training does not
    silently leave it in eval mode.

    Args:
        model: PyTorch model to analyze.
        input_shape: Input tensor shape (including batch dimension).
        device: Device to run the model on ('cpu' or 'cuda').

    Returns:
        Dictionary containing:
            - 'output_shape': Shape of the output tensor (``None`` on failure)
            - 'output_size': Total elements in output (``None`` on failure)
            - 'success': Whether calculation succeeded
            - 'error': Exception message; present only on failure

    Raises:
        ValueError: If ``input_shape`` is not a tuple/list with at least
            three entries.
    """
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
        raise ValueError(f"input_shape must be at least 3D (B, C, H, W), got {input_shape}")

    # Snapshot per-module flags rather than only the root's: a model may
    # deliberately keep some sub-modules in eval mode while training.
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        model = model.to(device)
        model.eval()

        with torch.no_grad():
            dummy_input = torch.randn(*input_shape, device=device)
            output = model(dummy_input)

        return {
            "output_shape": tuple(output.shape),
            "output_size": output.numel(),
            "success": True,
        }
    except Exception as e:
        # Deliberate best-effort: any failure of the probe (bad device, shape
        # mismatch, non-tensor output, ...) is reported instead of raised.
        return {
            "output_shape": None,
            "output_size": None,
            "success": False,
            "error": str(e),
        }
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
def get_layer_summary(
    model: nn.Module,
    input_shape: Tuple[int, ...],
    device: str = "cpu",
) -> List[Dict[str, Any]]:
    """Get a layer-by-layer summary with output shapes.

    A forward hook is attached to every named sub-module (the root module is
    skipped), the model is run once in eval mode under ``torch.no_grad()`` on
    a random tensor, and each hook records one entry when its module fires.
    Entries therefore appear in execution order, and a module invoked more
    than once yields more than one entry. All hooks are removed and every
    module's original train/eval flag is restored before returning, even if
    the forward pass raises.

    Args:
        model: PyTorch model to analyze.
        input_shape: Input tensor shape (including batch dimension).
        device: Device to run the model on.

    Returns:
        List of dictionaries containing layer information:
            - 'name': Layer name/path as reported by ``named_modules()``
            - 'type': Layer type (class name)
            - 'output_shape': Output shape, or ``None`` if the layer's
              output is not a single tensor
            - 'parameters': Number of parameters, including those of any
              child modules
            - 'trainable': Whether any of those parameters requires grad

    Raises:
        ValueError: If ``input_shape`` is not a tuple/list with at least
            three entries.
    """
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
        raise ValueError(f"input_shape must be at least 3D (B, C, H, W), got {input_shape}")

    layers_info: List[Dict[str, Any]] = []
    hooks = []

    def get_hook(name: str) -> Callable[[nn.Module, Any, Any], None]:
        """Create a forward hook that records a summary entry under ``name``."""

        def hook(module: nn.Module, inputs: Any, output: Any) -> None:
            layers_info.append(
                {
                    "name": name,
                    "type": module.__class__.__name__,
                    "output_shape": tuple(output.shape) if isinstance(output, torch.Tensor) else None,
                    "parameters": sum(p.numel() for p in module.parameters()),
                    "trainable": any(p.requires_grad for p in module.parameters()),
                }
            )

        return hook

    # Snapshot per-module flags so the eval() below can be undone exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        # Registration lives inside the try so that a failure part-way
        # through still removes the hooks that were already attached.
        for name, module in model.named_modules():
            if name:  # Skip the root module
                hooks.append(module.register_forward_hook(get_hook(name)))

        model = model.to(device)
        model.eval()

        with torch.no_grad():
            dummy_input = torch.randn(*input_shape, device=device)
            model(dummy_input)
    finally:
        for handle in hooks:
            handle.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    return layers_info
def count_parameters_by_type(model: nn.Module) -> Dict[str, Dict[str, int]]:
    """Tally parameters per layer class.

    Every module yielded by ``model.modules()`` (the root and containers
    included) is visited once. Only the parameters a module owns directly
    are attributed to it, so a container's children are not counted again
    under the container's class name.

    Args:
        model: PyTorch model to analyze.

    Returns:
        Dictionary keyed by layer class name, in first-seen order, where each
        value holds:
            - 'total': Total parameters owned by layers of this type
            - 'trainable': Those of them with ``requires_grad`` set
            - 'count': Number of layers of this type
    """
    stats_by_type: Dict[str, Dict[str, int]] = {}

    for layer in model.modules():
        bucket = stats_by_type.setdefault(
            layer.__class__.__name__,
            {"total": 0, "trainable": 0, "count": 0},
        )
        bucket["count"] += 1

        # recurse=False: look only at parameters registered on this layer.
        for param in layer.parameters(recurse=False):
            size = param.numel()
            bucket["total"] += size
            if param.requires_grad:
                bucket["trainable"] += size

    return stats_by_type
def get_memory_usage(
    model: nn.Module,
    input_shape: Tuple[int, ...],
    device: str = "cpu",
) -> Dict[str, Union[float, str, None]]:
    """Estimate model memory usage.

    Parameter memory is the sum of ``nelement() * element_size()`` over
    ``model.parameters()``, so it reflects each parameter's actual dtype
    (float16/float32/float64, ...). Registered buffers are not included.

    Activation memory is a lower bound: only the dummy input tensor and the
    model's final output tensor are counted, not intermediate activations.

    The model is moved to ``device`` and run in eval mode under
    ``torch.no_grad()``; every module's original train/eval flag is restored
    before returning.

    Args:
        model: PyTorch model to analyze.
        input_shape: Input tensor shape (including batch dimension).
        device: Device to run the model on.

    Returns:
        Dictionary containing (sizes in MiB, rounded to 4 decimals):
            - 'parameter_memory_mb': Memory used by parameters
            - 'activation_memory_mb': Estimated activation memory, or
              ``None`` if the forward pass failed
            - 'total_memory_mb': Sum of the two, or ``None`` on failure
            - 'device': Device used for estimation
            - 'error': Exception message; present only on failure

    Raises:
        ValueError: If ``input_shape`` is not a tuple/list with at least
            three entries.
    """
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
        raise ValueError(f"input_shape must be at least 3D, got {input_shape}")

    bytes_per_mb = 1024**2

    # Use each tensor's real element size instead of assuming float32.
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    param_size = param_bytes / bytes_per_mb

    # Estimate activation memory through a forward pass.
    model = model.to(device)
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    try:
        with torch.no_grad():
            dummy_input = torch.randn(*input_shape, device=device)
            output = model(dummy_input)

        activation_bytes = (
            dummy_input.nelement() * dummy_input.element_size()
            + output.nelement() * output.element_size()
        )
        activation_size = activation_bytes / bytes_per_mb
        total_size = param_size + activation_size

        return {
            "parameter_memory_mb": round(param_size, 4),
            "activation_memory_mb": round(activation_size, 4),
            "total_memory_mb": round(total_size, 4),
            "device": device,
        }
    except Exception as e:
        # Best-effort: still report parameter memory when the probe fails.
        return {
            "parameter_memory_mb": round(param_size, 4),
            "activation_memory_mb": None,
            "total_memory_mb": None,
            "device": device,
            "error": str(e),
        }
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
def get_model_flops(
    model: nn.Module,
    input_shape: Tuple[int, ...],
) -> Dict[str, Union[int, float, bool]]:
    """Estimate model FLOPs (floating point operations).

    Note:
        This is a static estimate; the model is never executed. Every layer
        is costed against ``input_shape`` itself, not against the shape that
        actually reaches it, and Conv2d output size is approximated as
        ``input // stride`` (padding, dilation and kernel extent are
        ignored). Only ``nn.Conv2d``, ``nn.Linear`` and
        ``nn.BatchNorm1d/2d/3d`` contribute; all other modules count as zero.

    Args:
        model: PyTorch model to analyze.
        input_shape: Input tensor shape (including batch dimension).

    Returns:
        Dictionary containing:
            - 'total_flops': Total estimated FLOPs
            - 'total_flops_in_billions': Total FLOPs in billions, rounded to
              4 decimals
            - 'success': Always ``True`` when the function returns

    Raises:
        ValueError: If ``input_shape`` is not a tuple/list with at least
            three entries, or if the model contains an ``nn.Conv2d`` and
            ``input_shape`` has fewer than four entries.
    """
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
        raise ValueError(f"input_shape must be at least 3D, got {input_shape}")

    batch_size = input_shape[0]

    # Product of all dimensions after (batch, channels); used by BatchNorm.
    spatial_size = 1
    for dim in input_shape[2:]:
        spatial_size *= dim

    total_flops = 0

    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if len(input_shape) < 4:
                raise ValueError(
                    f"input_shape must be 4D (B, C, H, W) for Conv2d layers, got {input_shape}"
                )
            # Conv2d FLOPs:
            #   out_h * out_w * kernel_h * kernel_w * (in_channels / groups) * out_channels
            out_h = input_shape[2] // module.stride[0]
            out_w = input_shape[3] // module.stride[1]
            # in_channels is always a multiple of groups for a valid Conv2d,
            # so integer division is exact and avoids float rounding.
            ops_per_output = (
                module.kernel_size[0]
                * module.kernel_size[1]
                * (module.in_channels // module.groups)
            )
            # The output area enters the product exactly once.
            total_flops += batch_size * out_h * out_w * ops_per_output * module.out_channels

        elif isinstance(module, nn.Linear):
            # Linear FLOPs: batch_size * input_features * output_features
            total_flops += batch_size * module.in_features * module.out_features

        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            # BatchNorm FLOPs: roughly 5 operations per element (approximate).
            total_flops += batch_size * module.num_features * spatial_size * 5

    return {
        "total_flops": total_flops,
        "total_flops_in_billions": round(total_flops / 1e9, 4),
        "success": True,
    }