Source code for torchvision_customizer.layers.activations

"""Activation function utilities for torchvision-customizer.

This module provides a flexible interface for working with various activation
functions in PyTorch. It includes a registry-based system for managing different
activation types and a factory function for creating activation modules.

Supported Activation Functions:
    - relu: Rectified Linear Unit
    - leaky_relu: Leaky ReLU with configurable negative slope
    - prelu: Parametric ReLU with learnable negative slope
    - gelu: Gaussian Error Linear Unit
    - silu: Sigmoid Linear Unit (Swish)
    - sigmoid: Sigmoid activation
    - tanh: Hyperbolic tangent
    - elu: Exponential Linear Unit
    - selu: Scaled Exponential Linear Unit

Example:
    >>> import torch
    >>> from torchvision_customizer.layers import get_activation
    >>>
    >>> # Create a ReLU activation
    >>> relu = get_activation('relu')
    >>> x = torch.randn(2, 64, 32, 32)
    >>> output = relu(x)
    >>>
    >>> # Create a Leaky ReLU with custom slope
    >>> leaky_relu = get_activation('leaky_relu', negative_slope=0.2)
    >>> output = leaky_relu(x)
    >>>
    >>> # Case-insensitive
    >>> gelu = get_activation('GELU')
    >>> output = gelu(x)
"""

import torch.nn as nn
from typing import Dict, Optional, Type, Any


# Lookup table from the lowercase activation name to the torch.nn module
# class that implements it. Keys are the canonical names accepted by the
# helpers in this module.
ACTIVATION_REGISTRY: Dict[str, Type[nn.Module]] = {
    "relu": nn.ReLU,
    "leaky_relu": nn.LeakyReLU,
    "prelu": nn.PReLU,
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "elu": nn.ELU,
    "selu": nn.SELU,
}

# Baseline constructor keyword arguments, keyed by the same names as the
# activation registry. An empty dict means the activation is constructed
# with no keyword arguments unless the caller supplies some.
ACTIVATION_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "relu": {},
    "leaky_relu": {"negative_slope": 0.01},
    "prelu": {"num_parameters": 1},
    "gelu": {},
    "silu": {},
    "sigmoid": {},
    "tanh": {},
    "elu": {"alpha": 1.0},
    "selu": {},
}


def get_activation(name: str, **kwargs) -> nn.Module:
    """Build an activation module from its registered name.

    The name is lowercased and stripped of surrounding whitespace before it
    is looked up in ``ACTIVATION_REGISTRY``. The matching entry of
    ``ACTIVATION_DEFAULTS`` supplies the baseline constructor arguments, and
    any ``kwargs`` passed by the caller take precedence over them.

    Args:
        name: Registered activation name. Matching ignores letter case and
            surrounding whitespace.
        **kwargs: Constructor arguments forwarded to the activation class;
            they override the registered defaults.

    Returns:
        A newly constructed activation module.

    Raises:
        ValueError: If ``name`` does not match a registered activation.
        TypeError: If the activation class rejects the combined arguments.

    Examples:
        >>> relu = get_activation('relu')
        >>> leaky = get_activation('leaky_relu', negative_slope=0.2)
        >>> gelu = get_activation('GELU')
    """
    key = name.lower().strip()

    # Guard clause: reject unknown names up front and list the valid ones.
    module_cls = ACTIVATION_REGISTRY.get(key)
    if module_cls is None:
        known = ', '.join(sorted(ACTIVATION_REGISTRY))
        raise ValueError(
            f"Unsupported activation function: '{name}'\n"
            f"Supported activations: {known}"
        )

    # Later entries win, so caller-supplied kwargs override the defaults.
    params = {**ACTIVATION_DEFAULTS[key], **kwargs}

    try:
        module = module_cls(**params)
    except TypeError as exc:
        # Re-raise with the registered defaults attached to aid debugging.
        raise TypeError(
            f"Invalid parameters for {key}: {exc!s}\n"
            f"Default parameters: {ACTIVATION_DEFAULTS[key]}"
        ) from exc
    else:
        return module
def is_activation_supported(name: str) -> bool:
    """Report whether ``name`` refers to a registered activation.

    The lookup ignores letter case and surrounding whitespace.

    Args:
        name: Activation name to look up.

    Returns:
        True when the normalized name is a key of ``ACTIVATION_REGISTRY``,
        False otherwise.

    Example:
        >>> is_activation_supported('relu')
        True
        >>> is_activation_supported('unsupported_activation')
        False
    """
    candidate = name.lower().strip()
    return candidate in ACTIVATION_REGISTRY
def get_supported_activations() -> list[str]:
    """Return the names of all registered activations in sorted order.

    Returns:
        A new list holding every key of ``ACTIVATION_REGISTRY``, sorted
        ascending.

    Example:
        >>> get_supported_activations()
        ['elu', 'gelu', 'leaky_relu', 'prelu', 'relu', 'selu', 'sigmoid', 'silu', 'tanh']
    """
    names = list(ACTIVATION_REGISTRY)
    names.sort()
    return names
# Alias for backward compatibility get_supported_activations.__doc__ = """Get list of all supported activation functions. Returns: A sorted list of supported activation function names. """
class ActivationFactory:
    """Namespace of static helpers for creating and inspecting activations.

    Every method is a ``staticmethod``, so no instance state is involved;
    the methods can be called on the class or on an instance alike.

    Example:
        >>> factory = ActivationFactory()
        >>> relu = factory.create('relu')
        >>> gelu = factory.create('gelu')
        >>> supported = factory.supported_activations()
    """

    @staticmethod
    def create(name: str, **kwargs) -> nn.Module:
        """Construct an activation module by delegating to ``get_activation``.

        Args:
            name: Name of the activation function.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            The activation module produced by ``get_activation``.

        Raises:
            ValueError: If the activation name is not supported.
        """
        return get_activation(name, **kwargs)

    @staticmethod
    def is_supported(name: str) -> bool:
        """Tell whether ``name`` is supported, via ``is_activation_supported``.

        Args:
            name: Name of the activation function.

        Returns:
            True if supported, False otherwise.
        """
        return is_activation_supported(name)

    @staticmethod
    def supported_activations() -> list[str]:
        """List the supported names, via ``get_supported_activations``.

        Returns:
            List of supported activation names.
        """
        return get_supported_activations()

    @staticmethod
    def get_defaults(name: str) -> Dict[str, Any]:
        """Look up the default constructor arguments of an activation.

        The name is lowercased and stripped of surrounding whitespace before
        the lookup.

        Args:
            name: Name of the activation function.

        Returns:
            A shallow copy of the matching ``ACTIVATION_DEFAULTS`` entry, so
            adding or removing keys on the result does not touch the
            module-level table.

        Raises:
            ValueError: If the activation name is not supported.
        """
        key = name.lower().strip()
        if key in ACTIVATION_REGISTRY:
            return dict(ACTIVATION_DEFAULTS[key])

        known = ', '.join(sorted(ACTIVATION_REGISTRY))
        raise ValueError(
            f"Unsupported activation function: '{name}'\n"
            f"Supported activations: {known}"
        )