Source code for torchvision_customizer.blocks.residual_block

"""Residual block with skip connections.

Implements residual learning as in ResNet. Supports both basic and
bottleneck architectures with optional channel dimension matching.

Example:
    >>> from torchvision_customizer.blocks import ResidualBlock
    >>> res_block = ResidualBlock(in_channels=64, out_channels=64)
    >>> output = res_block(x)
"""

import torch
import torch.nn as nn
from torchvision_customizer.layers import get_activation, get_normalization
from torchvision_customizer.blocks.se_block import SEBlock


[docs] class ResidualBlock(nn.Module): """Residual block with skip connections. Implements residual learning as in ResNet. Supports both basic and bottleneck architectures with optional SE blocks for channel attention and flexible normalization options. Args: in_channels: Number of input channels. out_channels: Number of output channels. stride: Stride for convolution. Default is 1. downsample: Whether to downsample. Default is False. activation: Activation function. Default is 'relu'. norm_type: Normalization type. Default is 'batch'. use_se: Whether to use SE block. Default is False. se_reduction: SE reduction ratio. Default is 16. bottleneck: Whether to use bottleneck architecture. Default is False. Example: >>> # Basic residual block >>> res_block = ResidualBlock(in_channels=64, out_channels=64) >>> x = torch.randn(2, 64, 32, 32) >>> output = res_block(x) >>> print(output.shape) torch.Size([2, 64, 32, 32]) >>> # With SE block and downsample >>> res_block = ResidualBlock( ... in_channels=64, out_channels=128, ... stride=2, downsample=True, use_se=True ... ) """ def __init__( self, in_channels: int, out_channels: int, stride: int = 1, downsample: bool = False, activation: str = 'relu', norm_type: str = 'batch', use_se: bool = False, se_reduction: int = 16, bottleneck: bool = False, ) -> None: """Initialize ResidualBlock.""" super().__init__() self.stride = stride self.downsample_layer = None self.activation_fn = get_activation(activation) self.use_se = use_se if bottleneck: # 1x1 -> 3x3 -> 1x1 pattern bottleneck_channels = max(1, out_channels // 4) self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False) self.bn1 = get_normalization(norm_type, bottleneck_channels) self.conv2 = nn.Conv2d( bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1, bias=False, ) self.bn2 = get_normalization(norm_type, bottleneck_channels) self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False) self.bn3 = get_normalization(norm_type, out_channels) else: # 3x3 -> 3x3 pattern self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False, ) self.bn1 = get_normalization(norm_type, out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False) self.bn2 = get_normalization(norm_type, out_channels) # Dimension matching if stride != 1 or in_channels != out_channels or downsample: self.downsample_layer = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), get_normalization(norm_type, out_channels), ) # SE block if use_se: self.se_block = SEBlock(out_channels, reduction=se_reduction, activation=activation)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply residual block. Args: x: Input tensor of shape (B, C, H, W). Returns: Output tensor of shape (B, C', H', W'). """ identity = x # Main path out = self.conv1(x) out = self.bn1(out) out = self.activation_fn(out) out = self.conv2(out) out = self.bn2(out) if hasattr(self, 'conv3'): # Bottleneck out = self.activation_fn(out) out = self.conv3(out) out = self.bn3(out) # Apply SE block if enabled if self.use_se: out = self.se_block(out) # Skip connection if self.downsample_layer is not None: identity = self.downsample_layer(x) out = out + identity out = self.activation_fn(out) return out