# Source code for torchvision_customizer.compose.operators

"""Composable module base class with operator overloading.

Provides the foundation for the >> (compose), + (sequential),
* (repeat), and | (branch) operators.
"""

import copy
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


[docs] class ComposableModule(nn.Module): """Base class for composable neural network modules. Supports operator overloading for intuitive model composition: - >> : Compose (sequential flow) - + : Add (sequential combination) - * : Repeat (multiply instances) - | : Branch (parallel paths) Example: >>> block1 = ComposableModule(...) >>> block2 = ComposableModule(...) >>> model = block1 >> block2 # Sequential composition >>> model = block1 + block2 # Same as >> >>> model = block1 * 3 # Repeat 3 times >>> model = block1 | block2 # Parallel branches """ def __init__(self): super().__init__() self._out_channels: Optional[int] = None self._out_shape: Optional[Tuple[int, ...]] = None @property def out_channels(self) -> Optional[int]: """Output channels of this module (if applicable).""" return self._out_channels @property def out_shape(self) -> Optional[Tuple[int, ...]]: """Output shape of this module (if known).""" return self._out_shape def __rshift__(self, other: 'ComposableModule') -> 'ComposedModule': """Compose with >> operator: self >> other.""" return ComposedModule([self, other]) def __add__(self, other: 'ComposableModule') -> 'ComposedModule': """Sequential combination with + operator.""" return ComposedModule([self, other]) def __mul__(self, n: int) -> 'ComposedModule': """Repeat module n times with * operator.""" if not isinstance(n, int) or n < 1: raise ValueError(f"Repeat count must be positive integer, got {n}") # Note: Creates separate instances, not weight sharing return ComposedModule([self] * n) def __rmul__(self, n: int) -> 'ComposedModule': """Support n * module syntax.""" return self.__mul__(n) def __or__(self, other: 'ComposableModule') -> 'BranchedModule': """Parallel branching with | operator.""" return BranchedModule([self, other])
class ComposedModule(ComposableModule):
    """Sequential composition of multiple modules.

    Created automatically when using the >> or + operators. Nested
    ``ComposedModule`` inputs are flattened, so ``(a >> b) >> c`` stores the
    layers ``[a, b, c]`` in a single ``nn.ModuleList``.
    """

    def __init__(self, modules: List[nn.Module]):
        """
        Args:
            modules: Modules to apply in order. Any ``ComposedModule`` among
                them contributes its layers rather than itself.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        for module in modules:
            if isinstance(module, ComposedModule):
                # Flatten nested compositions into one flat layer list.
                self.layers.extend(module.layers)
            else:
                self.layers.append(module)
        # Track output channels from the last layer, when it reports them.
        if self.layers and hasattr(self.layers[-1], 'out_channels'):
            self._out_channels = self.layers[-1].out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

    def __rshift__(self, other: nn.Module) -> 'ComposedModule':
        """Allow chaining: (a >> b) >> c."""
        # __init__ already flattens ``other`` if it is a ComposedModule.
        return ComposedModule([*self.layers, other])

    def __add__(self, other: nn.Module) -> 'ComposedModule':
        """Allow chaining with +."""
        return self.__rshift__(other)

    def __repr__(self) -> str:
        layer_names = [type(layer).__name__ for layer in self.layers]
        return f"ComposedModule({' >> '.join(layer_names)})"

    def explain(self) -> str:
        """Generate human-readable model explanation.

        Returns:
            A boxed, multi-line summary with one row per layer (its repr,
            truncated to fit) followed by total and trainable parameter counts.
        """
        inner = 58  # text width between "| " and " |"
        border = "+" + "-" * 60 + "+"

        def row(text: str) -> str:
            # Collapse newlines/indentation so multi-line reprs (e.g. an
            # nn.Sequential) occupy one row instead of breaking the box.
            text = " ".join(text.split())
            if len(text) > inner:
                text = text[:inner - 3] + "..."
            return "| " + text.ljust(inner) + " |"

        lines = [border, "|" + "Composed Module".center(60) + "|", border]
        lines.extend(row(repr(layer)) for layer in self.layers)
        lines.append(border)

        # Parameter count
        total_params = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        lines.append(row(f"Parameters: {total_params:,} ({trainable:,} trainable)"))
        lines.append(border)
        return "\n".join(lines)
class BranchedModule(ComposableModule):
    """Parallel branching of multiple modules.

    Applies the same input to every branch and concatenates the outputs
    along ``concat_dim``. Created automatically when using the | operator.
    """

    def __init__(self, modules: List[nn.Module], concat_dim: int = 1):
        """
        Args:
            modules: Branch modules, each applied to the same input.
            concat_dim: Dimension passed to ``torch.cat`` to join outputs.
        """
        super().__init__()
        self.branches = nn.ModuleList()
        self.concat_dim = concat_dim
        for module in modules:
            # Flatten a nested branch only if it concatenates along the same
            # dim; otherwise its own concatenation is part of the semantics
            # and must be kept as a single sub-branch.
            if isinstance(module, BranchedModule) and module.concat_dim == concat_dim:
                self.branches.extend(module.branches)
            else:
                self.branches.append(module)
        # Combined output channels = sum over branches, but only when every
        # branch reports a channel count. A partial sum would be wrong, so it
        # is left as None (unknown) instead.
        # NOTE(review): the sum assumes concat_dim is the channel dimension.
        channels = [getattr(branch, 'out_channels', None) for branch in self.branches]
        if channels and all(channels):
            self._out_channels = sum(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every branch on ``x`` and concatenate along ``concat_dim``."""
        outputs = [branch(x) for branch in self.branches]
        return torch.cat(outputs, dim=self.concat_dim)

    def __or__(self, other: nn.Module) -> 'BranchedModule':
        """Allow chaining: (a | b) | c, keeping this module's concat_dim."""
        # __init__ flattens ``other`` when it is a BranchedModule on the same dim.
        return BranchedModule([*self.branches, other], concat_dim=self.concat_dim)

    def __repr__(self) -> str:
        branch_names = [type(b).__name__ for b in self.branches]
        return f"BranchedModule({' | '.join(branch_names)})"
def compose(*modules: nn.Module) -> ComposedModule:
    """Chain the given modules into one sequential ``ComposedModule``.

    Explicit, function-style counterpart of the >> operator.

    Args:
        *modules: Modules to run one after another, in argument order.

    Returns:
        A ``ComposedModule`` wrapping the modules.

    Example:
        >>> model = compose(conv1, conv2, conv3, head)
    """
    sequence = [*modules]
    return ComposedModule(sequence)
[docs] def repeat(module: nn.Module, n: int) -> ComposedModule: """Repeat a module n times. Note: Creates n separate instances with independent weights. For weight sharing, use a loop in forward. Example: >>> stage = repeat(ResidualBlock(64), 3) """ return module * n
def branch(*modules: nn.Module, concat_dim: int = 1) -> BranchedModule:
    """Build parallel branches whose outputs are concatenated.

    Explicit, function-style counterpart of the | operator.

    Args:
        *modules: Branch modules, each fed the same input.
        concat_dim: Dimension along which branch outputs are concatenated.

    Returns:
        A ``BranchedModule`` over the given branches.

    Example:
        >>> inception = branch(conv1x1, conv3x3, conv5x5, pool)
    """
    paths = [*modules]
    return BranchedModule(paths, concat_dim=concat_dim)
class ResidualWrapper(ComposableModule):
    """Wrapper to add residual connection around any module.

    ``forward(x)`` returns ``module(x) + identity``. Here ``identity`` is
    ``x`` itself, or ``projection(x)`` when a projection module is given.
    Both terms must be compatible for ``+`` (same or broadcastable shapes).
    """

    def __init__(self, module: nn.Module, projection: Optional[nn.Module] = None):
        """
        Args:
            module: Main-path module whose output receives the skip term.
            projection: Optional module applied to the input on the skip
                path; ``None`` adds the input unchanged.
        """
        super().__init__()
        self.module = module
        self.projection = projection
        # Report the wrapped module's channel count when it exposes one.
        if hasattr(module, 'out_channels'):
            self._out_channels = module.out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``module(x)`` plus the (optionally projected) input."""
        identity = x if self.projection is None else self.projection(x)
        return self.module(x) + identity

    def __repr__(self) -> str:
        return f"Residual({self.module})"
def residual(module: nn.Module, projection: Optional[nn.Module] = None) -> ResidualWrapper:
    """Wrap a module with residual connection.

    Args:
        module: Main-path module.
        projection: Optional module applied to the input on the skip path
            before the addition; ``None`` adds the input unchanged.

    Returns:
        A ``ResidualWrapper`` computing ``module(x) + identity``.

    Example:
        >>> block = residual(ConvBlock(64, 64))
    """
    return ResidualWrapper(module, projection)