Source code for torchvision_customizer.recipe.yaml_loader

"""Enhanced YAML Recipe Loader.

v2.1 features:
- Recipe inheritance (extends)
- Macro expansion (@macro)
- Hybrid backbone support
- JSONSchema validation

Example:
    >>> from torchvision_customizer.recipe import load_yaml_recipe
    >>> 
    >>> # Load and build model
    >>> model = load_yaml_recipe("my_recipe.yaml")
    >>> 
    >>> # Or get config only
    >>> from torchvision_customizer.recipe.yaml_loader import load_yaml_config
    >>> config = load_yaml_config("my_recipe.yaml")
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import os
import torch.nn as nn

try:
    import yaml
    HAS_YAML = True
except ImportError:
    HAS_YAML = False

from torchvision_customizer.recipe.schema import (
    validate_recipe_config,
    expand_macros,
    merge_recipes,
    get_template,
    RECIPE_TEMPLATES,
    ValidationError,
)


class RecipeLoadError(Exception):
    """Raised when a recipe file (or a base recipe it extends) cannot be loaded.

    Used in this module for missing files, invalid YAML, empty documents and
    unresolved ``extends`` targets.
    """


def load_yaml_config(
    path: Union[str, Path],
    validate: bool = True,
    expand: bool = True,
) -> Dict[str, Any]:
    """Load and process a YAML recipe configuration.

    Processing order: parse the file, resolve ``extends`` inheritance,
    expand macros (only when the config has a ``macros`` key), then validate.

    Args:
        path: Path to YAML file
        validate: Whether to validate against schema
        expand: Whether to expand macros

    Returns:
        Processed recipe configuration

    Raises:
        ImportError: If PyYAML is not installed
        RecipeLoadError: If the file is missing, unreadable, not valid YAML,
            empty, or its top-level value is not a mapping
        ValidationError: If validation fails
    """
    if not HAS_YAML:
        raise ImportError("PyYAML is required for YAML recipes. Install with: pip install pyyaml")

    path = Path(path)

    if not path.exists():
        raise RecipeLoadError(f"Recipe file not found: {path}")

    # Load YAML
    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        # Chain the cause so the parser's position info stays in the traceback.
        raise RecipeLoadError(f"Invalid YAML in {path}: {e}") from e
    except (OSError, UnicodeDecodeError) as e:
        # e.g. the path is a directory, permissions are missing, or the file
        # is not UTF-8; report it through the loader's own error type.
        raise RecipeLoadError(f"Cannot read recipe file {path}: {e}") from e

    if config is None:
        raise RecipeLoadError(f"Empty recipe file: {path}")

    # A bare scalar or a list is valid YAML but not a recipe. Without this
    # check, `'extends' in config` would do substring/element matching and
    # fail later with an unrelated error.
    if not isinstance(config, dict):
        raise RecipeLoadError(
            f"Recipe must be a YAML mapping, got {type(config).__name__}: {path}"
        )

    # Handle inheritance (relative base paths are resolved against this file's directory)
    if 'extends' in config:
        config = _resolve_inheritance(config, path.parent)

    # Expand macros
    if expand and 'macros' in config:
        config = expand_macros(config)

    # Validate
    if validate:
        warnings = validate_recipe_config(config, strict=True)
        for warning in warnings:
            print(f"Warning: {warning}")

    return config
def _resolve_inheritance(
    config: Dict[str, Any],
    base_dir: Path,
) -> Dict[str, Any]:
    """Resolve recipe inheritance chain.

    Args:
        config: Recipe configuration, possibly carrying an ``extends`` key.
        base_dir: Directory against which a file-based ``extends`` value is
            resolved.

    Returns:
        ``config`` unchanged when ``extends`` is absent or ``None``; otherwise
        the result of ``merge_recipes(base, config)``.

    Raises:
        RecipeLoadError: If ``extends`` is not a non-empty string/path or the
            base recipe file cannot be found.
    """
    extends = config.get('extends')
    if extends is None:
        return config

    # YAML can yield lists, numbers, etc. here; those would otherwise fail
    # with a raw TypeError when hashed or joined onto a path below.
    if not isinstance(extends, (str, os.PathLike)) or not str(extends).strip():
        raise RecipeLoadError(
            f"'extends' must be a template name or file path, got: {extends!r}"
        )

    # Check if it's a built-in template
    if extends in RECIPE_TEMPLATES:
        base = get_template(extends)
    else:
        # Load from file; a name without an extension is assumed to be .yaml
        base_path = base_dir / extends
        if not base_path.suffix:
            base_path = base_path.with_suffix('.yaml')

        if not base_path.exists():
            raise RecipeLoadError(f"Cannot find base recipe: {extends}")

        # The base is loaded unvalidated and unexpanded; validation and macro
        # expansion are left to whoever loads the merged result.
        base = load_yaml_config(base_path, validate=False, expand=False)

    # Merge base with override
    return merge_recipes(base, config)
def load_yaml_recipe(
    path: Union[str, Path],
    num_classes: Optional[int] = None,
    **kwargs,
) -> nn.Module:
    """Load a YAML recipe and build the model.

    Recipes with a ``backbone`` key are built as hybrid models; all others
    are built from their stem/stages/head description.

    Args:
        path: Path to YAML recipe file
        num_classes: Override number of classes
        **kwargs: Additional build arguments

    Returns:
        Built PyTorch model

    Example:
        >>> model = load_yaml_recipe("resnet_cifar.yaml", num_classes=10)
    """
    config = load_yaml_config(path)

    # Override num_classes if provided
    if num_classes is not None:
        if 'head' in config:
            if isinstance(config['head'], dict):
                config['head']['num_classes'] = num_classes
            else:
                # Non-dict head (e.g. string form) is replaced wholesale.
                config['head'] = {'num_classes': num_classes}
        if 'backbone' in config:
            # The hybrid builder reads the top-level 'num_classes' before the
            # head's value, so it must be set here too; otherwise a recipe
            # that already defines a top-level value ignores the override.
            config['num_classes'] = num_classes

    # Build model based on type
    if 'backbone' in config:
        return _build_hybrid_model(config, **kwargs)
    return _build_recipe_model(config, **kwargs)
def _build_hybrid_model(config: Dict[str, Any], **kwargs) -> nn.Module:
    """Build a hybrid model from config.

    Args:
        config: Recipe configuration with a ``backbone`` entry, given either
            as a mapping or as a string such as
            ``"resnet50(weights=IMAGENET1K_V2)"``.
        **kwargs: Forwarded to ``HybridBuilder.from_torchvision``.

    Returns:
        The model returned by ``HybridBuilder.from_torchvision``.

    Raises:
        ValueError: If the backbone entry is a string that cannot be parsed,
            or is neither a string nor a mapping.
    """
    import re

    backbone_config = config['backbone']
    if isinstance(backbone_config, str):
        # Parse string form: "resnet50(weights=IMAGENET1K_V2)"
        match = re.match(r'(\w+)(?:\((.+)\))?', backbone_config.strip())
        if match is None:
            # Previously the unparsed string fell through and was indexed
            # with ['name'], producing an unrelated TypeError.
            raise ValueError(
                f"Cannot parse backbone specification: {backbone_config!r}"
            )
        parsed = {'name': match.group(1)}
        params_str = match.group(2)
        if params_str:
            # Parse simple key=value pairs; values stay strings and
            # fragments without '=' are ignored.
            for pair in params_str.split(','):
                if '=' in pair:
                    k, v = pair.split('=', 1)
                    parsed[k.strip()] = v.strip()
        backbone_config = parsed
    elif not isinstance(backbone_config, dict):
        raise ValueError(
            "backbone must be a string or a mapping, got "
            f"{type(backbone_config).__name__}"
        )

    # 'head' may be absent, null or in string form; only a mapping can
    # supply num_classes/dropout, so anything else falls back to defaults.
    head_config = config.get('head')
    if not isinstance(head_config, dict):
        head_config = {}

    # Imported late so config errors above surface before the builder is needed.
    from torchvision_customizer.hybrid import HybridBuilder

    builder = HybridBuilder()
    return builder.from_torchvision(
        backbone_name=backbone_config['name'],
        weights=backbone_config.get('weights', 'DEFAULT'),
        patches=backbone_config.get('patches'),
        # A top-level num_classes takes precedence over the head's value.
        num_classes=config.get('num_classes', head_config.get('num_classes', 1000)),
        dropout=head_config.get('dropout', 0.0),
        **kwargs,
    )


def _build_recipe_model(config: Dict[str, Any], **kwargs) -> nn.Module:
    """Build a model from stem/stages/head config.

    Missing keys fall back to ``'conv(64)'`` (stem), ``[]`` (stages),
    ``'linear(1000)'`` (head) and ``[3, 224, 224]`` (input_shape).

    Args:
        config: Recipe configuration.
        **kwargs: Accepted for signature parity with the hybrid builder;
            not used here.

    Returns:
        The model produced by ``build_recipe``.
    """
    from torchvision_customizer.recipe import Recipe, build_recipe

    recipe = Recipe(
        stem=config.get('stem', 'conv(64)'),
        stages=config.get('stages', []),
        head=config.get('head', 'linear(1000)'),
        input_shape=tuple(config.get('input_shape', [3, 224, 224])),
    )

    return build_recipe(recipe)
def save_yaml_recipe(
    config: Dict[str, Any],
    path: Union[str, Path],
    include_metadata: bool = True,
) -> None:
    """Write a recipe configuration to a YAML file.

    Parent directories of ``path`` are created when missing. Keys are
    dumped in their existing order using block style.

    Args:
        config: Recipe configuration
        path: Output file path
        include_metadata: Whether to prepend the version/schema comment header

    Raises:
        ImportError: If PyYAML is not installed
    """
    if not HAS_YAML:
        raise ImportError("PyYAML is required. Install with: pip install pyyaml")

    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    header = (
        "# torchvision-customizer Recipe v2.1\n"
        "# Schema: https://github.com/codewithdark-git/torchvision-customizer#recipes\n\n"
        if include_metadata
        else ""
    )
    document = header + yaml.dump(config, default_flow_style=False, sort_keys=False)

    target.write_text(document, encoding='utf-8')
def list_templates() -> List[str]:
    """Return the names of all entries in ``RECIPE_TEMPLATES``."""
    template_names = list(RECIPE_TEMPLATES)
    return template_names
def create_recipe_from_template(
    template_name: str,
    output_path: Optional[Union[str, Path]] = None,
    **overrides,
) -> Dict[str, Any]:
    """Create a new recipe from a template.

    Overrides are applied key by key: when both the template value and the
    override are dicts, the override is merged into the template value
    (shallow ``dict.update``); otherwise the override replaces it.

    Args:
        template_name: Name of the template
        output_path: Optional path to save the recipe
        **overrides: Values to override in the template

    Returns:
        Recipe configuration
    """
    import copy

    # NOTE(review): it is not visible here whether get_template returns a
    # fresh copy. Copy defensively so the in-place updates below can never
    # leak into a shared template definition.
    config = copy.deepcopy(get_template(template_name))

    # Apply overrides
    for key, value in overrides.items():
        existing = config.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            existing.update(value)
        else:
            config[key] = value

    if output_path:
        save_yaml_recipe(config, output_path)

    return config
# Example YAML recipe content for documentation.
# Written verbatim to disk by create_example_recipe(). The nesting mirrors
# what _build_hybrid_model() reads: 'patches' lives under 'backbone', and
# 'num_classes'/'dropout' live under 'head'.
EXAMPLE_RECIPE_YAML = """# Example: Custom ResNet with Attention
# =====================================
# This recipe creates a ResNet-50 backbone with SE attention injected
# and a custom head for 100-class classification.

name: ResNet50-SE-Custom
version: "1.0.0"
description: ResNet50 with SE attention for custom classification

# Macros for reusable values
macros:
  attention: se
  activation: relu
  dropout: 0.3

# Use pretrained backbone
backbone:
  name: resnet50
  weights: IMAGENET1K_V2
  patches:
    layer3:
      wrap:
        type: "@attention"
        params:
          reduction: 16
    layer4:
      wrap:
        type: cbam_block

# Custom head
head:
  num_classes: 100
  dropout: "@dropout"

# Training hints (optional)
training:
  optimizer: adamw
  learning_rate: 0.001
  weight_decay: 0.01
  epochs: 100
  batch_size: 32
"""
def create_example_recipe(output_path: Union[str, Path]) -> None:
    """Write the bundled example recipe to ``output_path``.

    Missing parent directories are created, the content of
    ``EXAMPLE_RECIPE_YAML`` is written as UTF-8, and a confirmation line
    is printed.

    Args:
        output_path: Where to save the example
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(EXAMPLE_RECIPE_YAML, encoding='utf-8')

    print(f"Created example recipe at: {destination}")