"""
Attention Mechanisms Module
Comprehensive attention mechanisms for neural networks:
- Channel Attention (Squeeze-and-Excitation style)
- Spatial Attention
- Multi-Head Attention
- Positional Encoding for transformers
Author: torchvision-customizer
License: MIT
"""
from typing import Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]
class ChannelAttention(nn.Module):
    """
    Channel Attention Mechanism (Squeeze-and-Excitation style).

    Recalibrates channel-wise feature responses by explicitly modeling
    interdependencies between channels: the input is globally average-pooled
    to one value per channel, passed through a two-layer bottleneck MLP and a
    sigmoid, and the resulting per-channel gates rescale the input.

    Attributes:
        channels (int): Number of input channels
        reduction (int): Reduction ratio for bottleneck
        activation (Callable): Functional activation applied between the
            two fully connected layers

    Examples:
        >>> attn = ChannelAttention(channels=64, reduction=16)
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = attn(x)  # Same shape as input
    """

    # Activation names accepted by ``__init__`` mapped to their functional
    # implementation. Unknown names are rejected instead of silently
    # falling back to GELU.
    _ACTIVATIONS = {
        'relu': F.relu,
        'gelu': F.gelu,
        'silu': F.silu,
        'elu': F.elu,
        'leaky_relu': F.leaky_relu,
    }

    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        activation: str = 'relu',
    ):
        """
        Initialize ChannelAttention.

        Args:
            channels: Number of input channels
            reduction: Reduction ratio for bottleneck layer
            activation: Activation function name; one of 'relu', 'gelu',
                'silu', 'elu', 'leaky_relu'

        Raises:
            ValueError: If channels or reduction is < 1, or if activation
                is not a supported name
        """
        super().__init__()
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"activation must be one of {sorted(self._ACTIVATIONS)}, "
                f"got {activation!r}"
            )

        self.channels = channels
        self.reduction = reduction
        # Never let the bottleneck collapse to zero features.
        reduced_channels = max(1, channels // reduction)

        # Squeeze and Excitation
        self.fc1 = nn.Linear(channels, reduced_channels, bias=True)
        self.activation = self._ACTIVATIONS[activation]
        self.fc2 = nn.Linear(reduced_channels, channels, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, C, H, W)

        Returns:
            Attention-scaled tensor of same shape

        Raises:
            ValueError: If x is not a 4-D tensor
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected 4-D input (B, C, H, W), got {x.dim()}-D tensor"
            )
        batch, channels = x.shape[:2]

        # Squeeze: Global average pooling -> (B, C)
        squeeze = F.adaptive_avg_pool2d(x, 1).view(batch, channels)

        # Excitation: FC layers -> per-channel gate in (0, 1)
        excitation = self.fc1(squeeze)
        excitation = self.activation(excitation)
        excitation = self.fc2(excitation)
        excitation = self.sigmoid(excitation).view(batch, channels, 1, 1)

        # Scale: gate broadcasts over the spatial dimensions
        return x * excitation
[docs]
class SpatialAttention(nn.Module):
    """
    Spatial Attention Mechanism.

    Generates attention maps along the spatial dimension.
    Focuses on "where" to pay attention: the per-pixel mean and max over
    channels are stacked, convolved down to a single channel, squashed with
    a sigmoid, and used to rescale the input.

    Attributes:
        kernel_size (int): Kernel size for convolution

    Examples:
        >>> attn = SpatialAttention(kernel_size=7)
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = attn(x)  # Same shape as input
    """

    def __init__(
        self,
        kernel_size: int = 7,
    ):
        """
        Initialize SpatialAttention.

        Args:
            kernel_size: Kernel size for convolution (must be odd and >= 1)

        Raises:
            ValueError: If kernel_size is even or not positive
        """
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        # Negative odd values pass the parity check above (-1 % 2 == 1 in
        # Python), so reject them here rather than inside nn.Conv2d.
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")

        self.kernel_size = kernel_size
        # "Same" padding so the attention map keeps the input's H x W.
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, C, H, W)

        Returns:
            Attention-scaled tensor of same shape
        """
        # Channel statistics, each of shape (B, 1, H, W)
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)

        # Concatenate and convolve: (B, 2, H, W) -> (B, 1, H, W)
        concat = torch.cat([avg_pool, max_pool], dim=1)
        attention = self.conv(concat)
        attention = self.sigmoid(attention)

        # Scale: the single-channel map broadcasts over all channels
        return x * attention
[docs]
class ChannelSpatialAttention(nn.Module):
    """
    Combined Channel and Spatial Attention.

    Chains a ChannelAttention stage and a SpatialAttention stage, in that
    order, so features are recalibrated per channel first and per location
    second.

    Examples:
        >>> attn = ChannelSpatialAttention(channels=64)
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = attn(x)
    """

    def __init__(
        self,
        channels: int,
        channel_reduction: int = 16,
        spatial_kernel: int = 7,
    ):
        """
        Build the two attention stages.

        Args:
            channels: Number of input channels
            channel_reduction: Reduction ratio handed to the channel stage
            spatial_kernel: Kernel size handed to the spatial stage
        """
        super().__init__()
        self.channel_attn = ChannelAttention(
            channels=channels,
            reduction=channel_reduction,
        )
        self.spatial_attn = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention, then spatial attention, to ``x``."""
        channel_refined = self.channel_attn(x)
        return self.spatial_attn(channel_refined)
[docs]
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism (scaled dot-product).

    Allows the model to attend to information from different representation
    subspaces at different positions. Acts as self-attention when only
    ``query`` is passed to ``forward``.

    Attributes:
        embed_dim (int): Total embedding dimension
        num_heads (int): Number of attention heads
        head_dim (int): Per-head dimension (embed_dim // num_heads)
        dropout (float): Dropout probability applied to attention weights
        batch_first (bool): Whether inputs/outputs are (B, L, E) rather
            than (L, B, E)

    Examples:
        >>> attn = MultiHeadAttention(embed_dim=256, num_heads=8)
        >>> x = torch.randn(2, 32, 256)  # (batch, seq_len, embed_dim)
        >>> out, weights = attn(x)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        dropout: float = 0.0,
        bias: bool = True,
        batch_first: bool = True,
    ):
        """
        Initialize MultiHeadAttention.

        Args:
            embed_dim: Total embedding dimension
            num_heads: Number of parallel attention heads
            dropout: Dropout probability
            bias: Whether to use bias
            batch_first: Whether batch dimension is first

        Raises:
            ValueError: If embed_dim or num_heads is < 1, or if embed_dim
                is not divisible by num_heads
        """
        super().__init__()
        # Checked up front so num_heads == 0 surfaces as a ValueError
        # instead of a ZeroDivisionError from the modulo below.
        if embed_dim < 1:
            raise ValueError(f"embed_dim must be >= 1, got {embed_dim}")
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.batch_first = batch_first

        # Linear transformations
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout_layer = nn.Dropout(dropout)

        # Scaling factor
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            query: Query tensor of shape (B, L, E) or (L, B, E)
            key: Key tensor (defaults to query if None)
            value: Value tensor (defaults to query if None)
            attention_mask: Attention mask; positions where the mask equals
                0 are excluded. It is applied directly to the
                (B, H, L, S) score tensor, so it must broadcast against
                that shape.

        Returns:
            Tuple of (output, attention_weights). ``output`` has the same
            layout as ``query``; ``attention_weights`` has shape
            (B, H, L, S) regardless of ``batch_first``.
        """
        if key is None:
            key = query
        if value is None:
            value = query

        # Work in batch-first layout internally. A sequence-first input must
        # be transposed (not merely re-viewed), otherwise batch and sequence
        # positions get interleaved when splitting heads.
        if not self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        batch_size, seq_len, _ = query.shape

        # Linear projections
        q = self.q_proj(query)  # (B, L, E)
        k = self.k_proj(key)    # (B, S, E)
        v = self.v_proj(value)  # (B, S, E)

        # Reshape for multi-head: (B, len, E) -> (B, H, len, D).
        # reshape (not view) because the inputs may be non-contiguous after
        # the transpose above.
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, L, S)

        # Apply mask
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        # Softmax over the key axis
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        # Apply attention to values
        context = torch.matmul(attn_weights, v)  # (B, H, L, D)

        # Merge heads back: (B, H, L, D) -> (B, L, E)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.embed_dim)

        # Final linear projection
        output = self.out_proj(context)

        # Restore the caller's layout.
        if not self.batch_first:
            output = output.transpose(0, 1)
        return output, attn_weights
[docs]
class PositionalEncoding(nn.Module):
    """
    Positional Encoding for Transformer models.

    Provides position information to the model using sinusoidal functions.
    The table is precomputed once for ``max_len`` positions and stored as a
    buffer named ``pe`` of shape (1, max_len, d_model). ``forward`` returns
    the encoding only; the caller adds it to the input.

    Examples:
        >>> pos_enc = PositionalEncoding(d_model=256, max_len=512)
        >>> x = torch.randn(2, 32, 256)
        >>> x = x + pos_enc(x)
    """

    def __init__(
        self,
        d_model: int,
        max_len: int = 5000,
        dropout: float = 0.1,
    ):
        """
        Initialize PositionalEncoding.

        Args:
            d_model: Model dimension
            max_len: Maximum sequence length
            dropout: Dropout probability
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encodings
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        # Even columns carry the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so the last frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, L, D); only its sequence length
                (dimension 1) is read

        Returns:
            Positional encoding for the sequence, shape (1, L, D), with
            dropout applied

        Raises:
            ValueError: If the sequence length exceeds the ``max_len`` the
                module was built with
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        # Slicing past the table would silently return a shorter encoding;
        # fail loudly instead.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length ({seq_len}) exceeds max_len ({max_len})"
            )
        return self.dropout(self.pe[:, :seq_len, :])
[docs]
class AttentionBlock(nn.Module):
    """
    Complete attention block combining multiple attention mechanisms.

    Can be used as a building block in neural networks. Enabled stages run
    in a fixed order: channel attention, spatial attention, then the
    cross-channel gate. With every stage disabled the block is an identity.

    Examples:
        >>> block = AttentionBlock(
        ...     channels=64,
        ...     use_channel=True,
        ...     use_spatial=True,
        ...     use_cross=False
        ... )
        >>> x = torch.randn(2, 64, 32, 32)
        >>> out = block(x)
    """

    def __init__(
        self,
        channels: int,
        use_channel: bool = True,
        use_spatial: bool = True,
        use_cross: bool = False,
        channel_reduction: int = 16,
        spatial_kernel: int = 7,
    ):
        """
        Initialize AttentionBlock.

        Args:
            channels: Number of input channels
            use_channel: Whether to use channel attention
            use_spatial: Whether to use spatial attention
            use_cross: Whether to use cross-channel spatial attention
            channel_reduction: Channel attention reduction ratio
            spatial_kernel: Spatial attention kernel size
        """
        super().__init__()
        self.use_channel = use_channel
        self.use_spatial = use_spatial
        self.use_cross = use_cross

        if use_channel:
            self.channel_attn = ChannelAttention(channels, channel_reduction)
        if use_spatial:
            self.spatial_attn = SpatialAttention(spatial_kernel)
        if use_cross:
            # Cross-channel spatial attention: a 1x1-conv bottleneck that
            # yields a per-pixel, per-channel gate. Keep at least one hidden
            # channel so channels == 1 does not build a 0-channel conv.
            hidden_channels = max(1, channels // 2)
            self.cross_attn = nn.Sequential(
                nn.Conv2d(channels, hidden_channels, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_channels, channels, 1),
                nn.Sigmoid(),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with selected attention mechanisms.

        Args:
            x: Input tensor of shape (B, C, H, W)

        Returns:
            Tensor of the same shape after the enabled stages
        """
        if self.use_channel:
            x = self.channel_attn(x)
        if self.use_spatial:
            x = self.spatial_attn(x)
        if self.use_cross:
            x = x * self.cross_attn(x)
        return x
# Utility functions
[docs]
def create_attention_map(
    x: torch.Tensor,
    attention_type: str = 'channel',
) -> torch.Tensor:
    """
    Build a simple mean-based attention map from a feature tensor.

    Args:
        x: Input tensor
        attention_type: 'channel' averages over dimensions 2 and 3;
            'spatial' averages over dimension 1

    Returns:
        The averaged tensor, with reduced dimensions kept at size 1

    Raises:
        ValueError: If attention_type is neither 'channel' nor 'spatial'
    """
    # Pick the axes to average over, then reduce once.
    if attention_type == 'spatial':
        reduce_dims = 1
    elif attention_type == 'channel':
        reduce_dims = (2, 3)
    else:
        raise ValueError(f"Unknown attention_type: {attention_type}")
    return x.mean(dim=reduce_dims, keepdim=True)
[docs]
def apply_attention(
    x: torch.Tensor,
    attention: torch.Tensor,
) -> torch.Tensor:
    """
    Weight an input by an attention map.

    Args:
        x: Input tensor
        attention: Attention map, multiplied element-wise with ``x``

    Returns:
        The product ``x * attention``
    """
    weighted = x * attention
    return weighted
[docs]
def normalize_attention(
    attention: torch.Tensor,
) -> torch.Tensor:
    """
    Normalize attention to sum to 1.

    A softmax is taken over all non-batch elements of each sample, so every
    slice along dimension 0 sums to 1.

    Args:
        attention: Attention map whose first dimension is the batch

    Returns:
        Normalized attention with the same shape as the input
    """
    batch_size = attention.size(0)
    # reshape rather than view: view raises on non-contiguous inputs
    # (e.g. transposed or permuted maps); reshape copies only when needed.
    flat = attention.reshape(batch_size, -1)
    return flat.softmax(dim=1).reshape(attention.shape)