Source code for torchvision_customizer.compose.head

"""Head builder for network output.

The Head is the final classification/regression part of a network,
typically consisting of global pooling, flattening, and linear layers.
"""

from typing import List, Optional, Union
import torch
import torch.nn as nn

from torchvision_customizer.compose.operators import ComposableModule
from torchvision_customizer.layers import get_activation






class SegmentationHead(ComposableModule):
    """Segmentation head for dense prediction tasks.

    The head is a small convolutional stack: a 3x3 convolution to
    ``hidden_channels``, batch normalisation, ReLU, and a 1x1 convolution
    to ``num_classes``. When ``upsample_factor`` is greater than 1 a
    bilinear ``nn.Upsample`` is appended.

    If ``in_channels`` is not supplied, the stack is created lazily on the
    first ``forward`` call, using ``x.shape[1]`` as the input channel count.

    NOTE(review): parameters of a lazily created head do not exist until
    the first forward pass, so anything that collects parameters earlier
    (e.g. an optimizer) will presumably not see them -- confirm with callers.

    Args:
        num_classes: Number of segmentation classes
        in_channels: Input channels (``None`` defers construction to the
            first forward pass)
        hidden_channels: Hidden layer channels
        upsample_factor: Upsampling factor to original resolution
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: Optional[int] = None,
        hidden_channels: int = 256,
        upsample_factor: int = 1,
    ):
        super().__init__()
        self.num_classes = num_classes
        self._out_channels = num_classes
        self._in_channels = in_channels
        self._hidden_channels = hidden_channels
        self._upsample_factor = upsample_factor

        # Build eagerly only when the input width is already known.
        self._head = None
        if in_channels is not None:
            self._head = self._build_head(in_channels)

    def _build_head(self, in_channels: int) -> nn.Sequential:
        """Create the conv/BN/ReLU/conv stack for ``in_channels`` inputs."""
        layers = [
            nn.Conv2d(in_channels, self._hidden_channels, 3, padding=1),
            nn.BatchNorm2d(self._hidden_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self._hidden_channels, self.num_classes, 1),
        ]

        if self._upsample_factor > 1:
            layers.append(
                nn.Upsample(
                    scale_factor=self._upsample_factor,
                    mode='bilinear',
                    align_corners=False,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the head to ``x``, building it first if necessary.

        Args:
            x: Input tensor; ``x.shape[1]`` is taken as the channel count
                when the head has to be built lazily.

        Returns:
            The output of the convolutional stack applied to ``x``.
        """
        if self._head is None:
            head = self._build_head(x.shape[1]).to(x.device)
            # A freshly constructed module always starts in training mode.
            # Mirror this module's current mode so that, if eval() was
            # called before the first forward pass, BatchNorm uses its
            # running statistics instead of batch statistics.
            # getattr keeps the previous behaviour should the base class
            # not expose a ``training`` flag.
            head.train(getattr(self, "training", True))
            self._head = head
        return self._head(x)

    def __repr__(self) -> str:
        return f"SegmentationHead(classes={self.num_classes})"
class DetectionHead(ComposableModule):
    """Single-convolution prediction head for anchor-based detection.

    For every spatial position the head emits
    ``num_anchors * (num_classes + 4)`` channels: per anchor,
    ``num_classes`` class scores followed by 4 box-regression values
    (x, y, w, h).

    When ``in_channels`` is omitted, the convolution is created on the
    first ``forward`` call from ``x.shape[1]``.

    Args:
        num_classes: Number of detection classes
        num_anchors: Number of anchor boxes per position
        in_channels: Input channels (``None`` defers construction to the
            first forward pass)
    """

    def __init__(
        self,
        num_classes: int,
        num_anchors: int = 9,
        in_channels: Optional[int] = None,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self._in_channels = in_channels

        # Channels predicted for one anchor: class scores + bbox (x, y, w, h).
        self._out_per_anchor = num_classes + 4

        # Eager construction is only possible when the input width is known.
        self._head = None if in_channels is None else self._build_head(in_channels)

    def _build_head(self, in_channels: int) -> nn.Module:
        """Create the 3x3 prediction convolution for ``in_channels`` inputs."""
        total_out_channels = self.num_anchors * self._out_per_anchor
        return nn.Conv2d(in_channels, total_out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the prediction convolution, creating it first if needed.

        Args:
            x: Input tensor; ``x.shape[1]`` is used as the channel count
                when the head has to be built lazily.

        Returns:
            Tensor of shape (B, num_anchors * (num_classes + 4), H, W).
        """
        head = self._head
        if head is None:
            # First call without a known input width: build on the fly and
            # place the new layer on the same device as the input.
            head = self._build_head(x.shape[1]).to(x.device)
            self._head = head
        return head(x)

    def __repr__(self) -> str:
        return f"DetectionHead(classes={self.num_classes}, anchors={self.num_anchors})"