# Source code for torchvision_customizer.compose.stem

"""Stem builder for network entry point.

The Stem is the initial feature extraction stage of a network,
typically consisting of a large convolution followed by pooling.
"""

from typing import Optional, Tuple, Union
import torch
import torch.nn as nn

from torchvision_customizer.compose.operators import ComposableModule
from torchvision_customizer.blocks import ConvBlock
from torchvision_customizer.layers import get_activation, get_normalization, get_pooling


class Stem(ComposableModule):
    """Network stem (entry point) builder.

    Creates the initial feature extraction layers, typically:
    - Large kernel convolution (e.g., 7x7 with stride 2)
    - Normalization (batch norm by default)
    - Activation
    - Optional pooling

    Args:
        channels: Output channels.
        in_channels: Input channels (default 3 for RGB).
        kernel: Kernel size (default 7).
        stride: Stride (default 2).
        padding: Padding. Defaults to ``kernel // 2``, which keeps the
            spatial size (before striding) for odd kernels.
        activation: Activation function name passed to ``get_activation``;
            a falsy value (e.g. ``None``) disables the activation.
        norm: Normalization type ('batch', 'layer', 'group', 'instance', None)
            passed to ``get_normalization``; a falsy value disables
            normalization and enables the convolution bias instead.
        pool: Pooling type ('max', 'avg', None).
        pool_kernel: Pooling kernel size.
        pool_stride: Pooling stride.
        pool_padding: Pooling padding.
        kernel_size: Alias for ``kernel``; takes precedence when given.

    Raises:
        ValueError: If ``pool`` is truthy but not 'max' or 'avg'.

    Example:
        >>> stem = Stem(channels=64, kernel=7, stride=2)
        >>> # Creates: Conv(3->64, 7x7, s2) -> BN -> ReLU -> MaxPool(3x3, s2)
        >>> stem = Stem(channels=64, kernel=3, stride=1, pool=None)
        >>> # Creates: Conv(3->64, 3x3, s1) -> BN -> ReLU
    """

    # Pooling layer classes selectable through the ``pool`` argument.
    _POOL_LAYERS = {'max': nn.MaxPool2d, 'avg': nn.AvgPool2d}

    def __init__(
        self,
        channels: int,
        in_channels: int = 3,
        kernel: int = 7,
        stride: int = 2,
        padding: Optional[int] = None,
        activation: Optional[str] = 'relu',
        norm: Optional[str] = 'batch',
        pool: Optional[str] = 'max',
        pool_kernel: int = 3,
        pool_stride: int = 2,
        pool_padding: int = 1,
        # Aliases for convenience
        kernel_size: Optional[int] = None,
    ):
        super().__init__()

        # Handle kernel_size alias
        if kernel_size is not None:
            kernel = kernel_size

        # Reject unknown pooling types instead of silently building a stem
        # without pooling while _config / __repr__ still report one.
        if pool and pool not in self._POOL_LAYERS:
            raise ValueError(
                f"Unsupported pool type {pool!r}; expected one of "
                f"{sorted(self._POOL_LAYERS)} or None"
            )

        # Auto-calculate padding for same output spatial dim (before stride)
        if padding is None:
            padding = kernel // 2

        self.in_channels = in_channels
        self._out_channels = channels

        # One truthiness test decides both whether a normalization layer is
        # added and whether the conv keeps its bias, so a falsy-but-not-None
        # ``norm`` cannot yield a conv with neither bias nor normalization.
        use_norm = bool(norm)

        # Main convolution
        layers = [
            nn.Conv2d(
                in_channels,
                channels,
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                bias=not use_norm,  # No bias if using normalization
            )
        ]

        # Normalization
        if use_norm:
            layers.append(get_normalization(norm, channels))

        # Activation
        if activation:
            layers.append(get_activation(activation))

        # Pooling (pool was validated above)
        if pool:
            pool_cls = self._POOL_LAYERS[pool]
            layers.append(pool_cls(
                kernel_size=pool_kernel,
                stride=pool_stride,
                padding=pool_padding,
            ))

        self.stem = nn.Sequential(*layers)

        # Store config for repr
        self._config = {
            'in_channels': in_channels,
            'channels': channels,
            'kernel': kernel,
            'stride': stride,
            'activation': activation,
            'norm': norm,
            'pool': pool,
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem layers (conv -> [norm] -> [activation] -> [pool])."""
        return self.stem(x)

    def __repr__(self) -> str:
        """Summarize the stem as a chain, e.g. 'Stem(Conv(3->64, 7x7, s2) -> BATCH -> RELU -> MAXPool)'."""
        cfg = self._config
        k = cfg['kernel']
        parts = [
            f"Conv({cfg['in_channels']}->{cfg['channels']}, "
            f"{k}x{k}, s{cfg['stride']})"
        ]
        if cfg['norm']:
            parts.append(cfg['norm'].upper())
        if cfg['activation']:
            parts.append(cfg['activation'].upper())
        if cfg['pool']:
            parts.append(f"{cfg['pool'].upper()}Pool")
        return f"Stem({' -> '.join(parts)})"
class SimpleStem(ComposableModule):
    """Simple stem for small images (e.g., CIFAR).

    Uses a single 3x3 convolution (stride 1, padding 1) without pooling, so
    the input's spatial size is preserved.

    Args:
        channels: Output channels.
        in_channels: Input channels (default 3).
        activation: Activation function name passed to ``get_activation``;
            a falsy value (e.g. ``None``) disables the activation.
        norm: Normalization type passed to ``get_normalization``; a falsy
            value disables normalization and enables the convolution bias.
    """

    def __init__(
        self,
        channels: int,
        in_channels: int = 3,
        activation: Optional[str] = 'relu',
        norm: Optional[str] = 'batch',
    ):
        super().__init__()
        self.in_channels = in_channels
        self._out_channels = channels

        # One truthiness test decides both the conv bias and whether a
        # normalization layer is added, so they can never disagree.
        use_norm = bool(norm)

        layers = [
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=not use_norm)
        ]
        if use_norm:
            layers.append(get_normalization(norm, channels))
        if activation:
            layers.append(get_activation(activation))

        self.stem = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> [norm] -> [activation]."""
        return self.stem(x)

    def __repr__(self) -> str:
        return f"SimpleStem({self._out_channels})"
class PatchStem(ComposableModule):
    """Vision-Transformer style patch embedding stem.

    Splits the input into non-overlapping ``patch_size x patch_size`` patches
    and projects each one to ``embed_dim`` channels with a single convolution
    whose kernel size and stride both equal the patch size.

    Args:
        embed_dim: Embedding dimension (number of output channels).
        patch_size: Side length of each square patch (default 16).
        in_channels: Input channels (default 3).
    """

    def __init__(self, embed_dim: int, patch_size: int = 16, in_channels: int = 3):
        super().__init__()
        self._out_channels = embed_dim
        self.patch_size = patch_size

        # kernel == stride, so every output location corresponds to exactly
        # one input patch.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches: (B, C, H, W) -> (B, embed_dim, H // P, W // P)."""
        return self.proj(x)

    def __repr__(self) -> str:
        dim, patch = self._out_channels, self.patch_size
        return f"PatchStem(dim={dim}, patch={patch})"