Source code for torchvision_customizer.blocks.depthwise_block

"""Depthwise separable convolution block.

Implements depthwise separable convolutions for efficient feature extraction.
Separates spatial and channel-wise convolutions for parameter efficiency.

Example:
    >>> from torchvision_customizer.blocks import DepthwiseBlock
    >>> block = DepthwiseBlock(in_channels=64, out_channels=128)
    >>> output = block(x)
"""

import torch
import torch.nn as nn
from torchvision_customizer.layers import get_activation, get_normalization


[docs] class DepthwiseBlock(nn.Module): """Depthwise separable convolution block. Implements depthwise separable convolutions combining: 1. Depthwise convolution (per-channel spatial convolution) 2. Pointwise convolution (1x1 convolution for channel mixing) This decomposition reduces parameters while maintaining expressiveness. Args: in_channels: Number of input channels. out_channels: Number of output channels. kernel_size: Kernel size for depthwise convolution. Default is 3. stride: Stride for convolution. Default is 1. padding: Padding for convolution. Default is 1. activation: Activation function. Default is 'relu'. norm_type: Normalization type. Default is 'batch'. add_residual: Whether to add residual connection. Default is False. Example: >>> # Basic depthwise block >>> block = DepthwiseBlock(in_channels=64, out_channels=128) >>> x = torch.randn(2, 64, 32, 32) >>> output = block(x) >>> print(output.shape) torch.Size([2, 128, 32, 32]) >>> # With residual connection (same channels) >>> block = DepthwiseBlock( ... in_channels=64, out_channels=64, ... kernel_size=5, add_residual=True ... ) """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, activation: str = 'relu', norm_type: str = 'batch', add_residual: bool = False, ) -> None: """Initialize DepthwiseBlock.""" super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.add_residual = add_residual and (in_channels == out_channels) self.activation_fn = get_activation(activation) # Depthwise convolution (groups=in_channels) self.depthwise = nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels, bias=False, ) self.bn1 = get_normalization(norm_type, in_channels) # Pointwise convolution (1x1) self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) self.bn2 = get_normalization(norm_type, out_channels)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply depthwise separable convolution. Args: x: Input tensor of shape (B, C_in, H, W). Returns: Output tensor of shape (B, C_out, H', W'). """ identity = x # Depthwise convolution out = self.depthwise(x) out = self.bn1(out) out = self.activation_fn(out) # Pointwise convolution out = self.pointwise(out) out = self.bn2(out) # Residual connection if self.add_residual: out = out + identity out = self.activation_fn(out) return out