Source code for torchvision_customizer.recipe.builder

"""Recipe builder.

Builds PyTorch models from parsed recipes.
"""

from typing import Any, Dict, List, Optional, Union
import torch.nn as nn

from torchvision_customizer.recipe.definition import Recipe
from torchvision_customizer.recipe.parser import parse_recipe
from torchvision_customizer.compose import (
    Stem,
    Stage,
    Head,
    Compose,
    VisionModel,
)


[docs] def build_recipe(recipe: Recipe) -> VisionModel: """Build a model from a Recipe object. Args: recipe: Recipe definition Returns: Built VisionModel """ config = parse_recipe(recipe) # 1. Build Stem stem_conf = config['stem'] stem_type = stem_conf.pop('type', 'conv') # 'conv' maps to Stem if stem_type in ['conv', 'stem', 'standard']: # Ensure input channels from recipe if 'in_channels' not in stem_conf: stem_conf['in_channels'] = config['input_shape'][0] stem = Stem(**stem_conf) else: # Fallback for other stem types if implemented (e.g., SimpleStem) # For now, map everything to Stem or raise error if stem_type == 'simple': from torchvision_customizer.compose.stem import SimpleStem stem = SimpleStem(**stem_conf) else: stem = Stem(**stem_conf) # Default fallback # 2. Build Stages stages = [] current_in_channels = stem.out_channels for i, stage_conf in enumerate(config['stages']): stage_type = stage_conf.pop('type', 'residual') # If type is a pattern name, use it as pattern if 'pattern' not in stage_conf: stage_conf['pattern'] = stage_type # Ensure in_channels (Stage tracks it, but good to be explicit) if 'in_channels' not in stage_conf: stage_conf['in_channels'] = current_in_channels stage = Stage(**stage_conf) stages.append(stage) current_in_channels = stage.out_channels # 3. Build Head head_conf = config['head'] head_type = head_conf.pop('type', 'linear') # Pass input features from last stage if not specified if 'in_features' not in head_conf: head_conf['in_features'] = current_in_channels head = Head(**head_conf) return VisionModel(stem, stages, head)
[docs] def build_from_config(config: Dict[str, Any]) -> VisionModel: """Build a model from a raw configuration dictionary.""" recipe = Recipe.from_dict(config) return build_recipe(recipe)