# Source code for torchvision_customizer.blocks.se_block

"""Squeeze-and-Excitation (SE) block for channel attention.

Implements the SE block from "Squeeze-and-Excitation Networks"
(https://arxiv.org/abs/1709.01507). Uses channel-wise attention to
recalibrate feature responses adaptively.

Example:
    >>> import torch
    >>> from torchvision_customizer.blocks import SEBlock
    >>> se_block = SEBlock(channels=64, reduction=16)
    >>> x = torch.randn(2, 64, 32, 32)
    >>> output = se_block(x)
"""

import torch
import torch.nn as nn
from torchvision_customizer.layers import get_activation


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention.

    Pools each channel of the input down to a single value, passes the
    resulting per-channel vector through a two-layer bottleneck ending in
    a sigmoid, and multiplies the input by the resulting per-channel
    weights. The output has the same shape as the input.

    Args:
        channels: Number of input channels.
        reduction: Reduction ratio for the bottleneck. Default is 16.
        activation: Activation function name, forwarded to
            ``get_activation``. Default is 'relu'.

    Example:
        >>> se_block = SEBlock(channels=64, reduction=16)
        >>> x = torch.randn(2, 64, 32, 32)
        >>> output = se_block(x)
        >>> print(output.shape)
        torch.Size([2, 64, 32, 32])
    """

    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        activation: str = 'relu',
    ) -> None:
        """Initialize SEBlock.

        Args:
            channels: Number of input channels.
            reduction: Reduction ratio for the bottleneck layer. The
                bottleneck width is ``channels // reduction``, clamped to
                at least 1.
            activation: Activation function name, resolved through
                ``get_activation`` (accepted names are defined there).

        Raises:
            ValueError: If channels or reduction is not positive.
        """
        super().__init__()

        if channels <= 0:
            raise ValueError(f"channels must be positive, got {channels}")
        if reduction <= 0:
            raise ValueError(f"reduction must be positive, got {reduction}")

        self.channels = channels
        self.reduction = reduction

        # Computed once and shared by both linear layers; the clamp keeps
        # the bottleneck at one feature when reduction exceeds channels.
        bottleneck = max(1, channels // reduction)

        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Layer order is kept as-is so existing state_dict keys
        # (excitation.0 / excitation.2) remain loadable.
        self.excitation = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=True),
            get_activation(activation),
            nn.Linear(bottleneck, channels, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SE block.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Output tensor of shape (B, C, H, W) with same shape as input.

        Raises:
            ValueError: If ``x`` is not a 4D tensor.
        """
        # Fail with a readable message instead of the bare tuple-unpacking
        # error that a non-4D input would otherwise trigger below.
        if x.dim() != 4:
            raise ValueError(
                f"SEBlock expects a 4D input (B, C, H, W), got {x.dim()}D"
            )

        batch, channels, _, _ = x.size()

        # Squeeze: global average pooling -> (B, C)
        pooled = self.squeeze(x).view(batch, channels)

        # Excitation: FC layers with sigmoid -> (B, C, 1, 1) for broadcasting
        weights = self.excitation(pooled).view(batch, channels, 1, 1)

        # Scale: multiply with input
        return x * weights