"""YAML Recipe Schema and Validation.
Provides JSONSchema-based validation for YAML recipe files,
along with utilities for schema generation and error reporting.
v2.1 Features:
- JSONSchema validation with helpful error messages
- Recipe inheritance (extend other recipes)
- Macro expansion (@attention, @block, etc.)
- Hybrid backbone support in recipes
Example:
>>> from torchvision_customizer.recipe.schema import validate_recipe_config
>>> config = yaml.safe_load(open("my_recipe.yaml"))
>>> validate_recipe_config(config)
"""
from typing import Any, Dict, List, Optional, Union
import copy
import re
# JSONSchema for v2.1 recipes
RECIPE_SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "torchvision-customizer Recipe Schema v2.1",
"type": "object",
"properties": {
# Metadata
"name": {
"type": "string",
"description": "Recipe name for identification"
},
"version": {
"type": "string",
"pattern": r"^\d+\.\d+(\.\d+)?$",
"description": "Recipe version (semver)"
},
"description": {
"type": "string",
"description": "Recipe description"
},
"extends": {
"type": "string",
"description": "Parent recipe to extend"
},
# Input configuration
"input_shape": {
"type": "array",
"items": {"type": "integer"},
"minItems": 3,
"maxItems": 3,
"default": [3, 224, 224],
"description": "Input tensor shape (C, H, W)"
},
# Macros for reuse
"macros": {
"type": "object",
"additionalProperties": {"type": "string"},
"description": "Macro definitions for @macro substitution"
},
# v2.1: Hybrid backbone support
"backbone": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Torchvision backbone name (e.g., resnet50)"
},
"weights": {
"type": "string",
"description": "Pretrained weights (e.g., IMAGENET1K_V2)"
},
"patches": {
"type": "object",
"description": "Layer patches to apply"
}
},
"required": ["name"]
},
# Architecture components (used when not using backbone)
"stem": {
"oneOf": [
{"type": "string"},
{
"type": "object",
"properties": {
"type": {"type": "string"},
"channels": {"type": "integer", "minimum": 1},
"kernel_size": {"type": "integer", "minimum": 1},
"stride": {"type": "integer", "minimum": 1},
"activation": {"type": "string"},
"norm_type": {"type": "string"}
}
}
],
"description": "Stem configuration (string or object)"
},
"stages": {
"type": "array",
"items": {
"oneOf": [
{"type": "string"},
{
"type": "object",
"properties": {
"type": {"type": "string"},
"pattern": {"type": "string"},
"channels": {"type": "integer", "minimum": 1},
"blocks": {"type": "integer", "minimum": 1},
"downsample": {"type": "boolean"},
"activation": {"type": "string"},
"attention": {"type": "string"},
"drop_path": {"type": "number", "minimum": 0, "maximum": 1}
}
}
]
},
"description": "List of stage configurations"
},
"head": {
"oneOf": [
{"type": "string"},
{
"type": "object",
"properties": {
"type": {"type": "string"},
"num_classes": {"type": "integer", "minimum": 1},
"dropout": {"type": "number", "minimum": 0, "maximum": 1},
"hidden_dims": {
"type": "array",
"items": {"type": "integer"}
}
}
}
],
"description": "Head/classifier configuration"
},
# Training hints (optional)
"training": {
"type": "object",
"properties": {
"optimizer": {"type": "string"},
"learning_rate": {"type": "number"},
"weight_decay": {"type": "number"},
"epochs": {"type": "integer"},
"batch_size": {"type": "integer"}
}
}
},
# Either backbone OR stem/stages/head
"oneOf": [
{"required": ["backbone"]},
{"required": ["stem", "stages", "head"]}
]
}
[docs]
class ValidationError(Exception):
"""Recipe validation error with detailed information."""
def __init__(self, message: str, path: str = "", suggestion: str = ""):
self.message = message
self.path = path
self.suggestion = suggestion
super().__init__(self._format_message())
def _format_message(self) -> str:
msg = f"Validation Error: {self.message}"
if self.path:
msg += f"\n At: {self.path}"
if self.suggestion:
msg += f"\n Suggestion: {self.suggestion}"
return msg
[docs]
def validate_recipe_config(config: Dict[str, Any], strict: bool = True) -> List[str]:
"""Validate a recipe configuration against the schema.
Args:
config: Recipe configuration dictionary
strict: If True, raise on errors; if False, return warnings
Returns:
List of warning messages (empty if all OK)
Raises:
ValidationError: If strict and validation fails
"""
warnings = []
# Check required structure
has_backbone = 'backbone' in config
has_components = all(k in config for k in ['stem', 'stages', 'head'])
if not has_backbone and not has_components:
error = ValidationError(
"Recipe must have either 'backbone' OR 'stem', 'stages', and 'head'",
path="root",
suggestion="Add backbone: {name: resnet50} or define stem/stages/head"
)
if strict:
raise error
warnings.append(str(error))
# Validate backbone
if has_backbone:
backbone = config['backbone']
if isinstance(backbone, str):
# Simple string form: "resnet50(weights=IMAGENET1K_V2)"
pass
elif isinstance(backbone, dict):
if 'name' not in backbone:
error = ValidationError(
"Backbone must have 'name' field",
path="backbone",
suggestion="Add name: resnet50"
)
if strict:
raise error
warnings.append(str(error))
# Validate stages
if 'stages' in config:
stages = config['stages']
if not isinstance(stages, list):
error = ValidationError(
"stages must be a list",
path="stages",
suggestion="stages should be [stage1, stage2, ...]"
)
if strict:
raise error
warnings.append(str(error))
else:
for i, stage in enumerate(stages):
stage_warnings = _validate_stage(stage, i, strict)
warnings.extend(stage_warnings)
# Validate input_shape
if 'input_shape' in config:
shape = config['input_shape']
if not isinstance(shape, (list, tuple)) or len(shape) != 3:
error = ValidationError(
"input_shape must be [C, H, W]",
path="input_shape",
suggestion="Use input_shape: [3, 224, 224]"
)
if strict:
raise error
warnings.append(str(error))
# Validate macros
if 'macros' in config:
if not isinstance(config['macros'], dict):
error = ValidationError(
"macros must be a dictionary",
path="macros",
suggestion="macros: {attention: SEBlock, block: residual}"
)
if strict:
raise error
warnings.append(str(error))
return warnings
def _validate_stage(stage: Union[str, Dict], index: int, strict: bool) -> List[str]:
"""Validate a single stage configuration."""
warnings = []
path = f"stages[{index}]"
if isinstance(stage, str):
# String definition - will be parsed later
return warnings
if not isinstance(stage, dict):
error = ValidationError(
f"Stage must be string or dict, got {type(stage).__name__}",
path=path
)
if strict:
raise error
warnings.append(str(error))
return warnings
# Check channels
if 'channels' in stage:
ch = stage['channels']
if not isinstance(ch, int) or ch < 1:
error = ValidationError(
"channels must be positive integer",
path=f"{path}.channels"
)
if strict:
raise error
warnings.append(str(error))
# Check blocks
if 'blocks' in stage:
blocks = stage['blocks']
if not isinstance(blocks, int) or blocks < 1:
error = ValidationError(
"blocks must be positive integer",
path=f"{path}.blocks"
)
if strict:
raise error
warnings.append(str(error))
return warnings
[docs]
def expand_macros(config: Dict[str, Any]) -> Dict[str, Any]:
"""Expand macros in recipe configuration.
Replaces @macro patterns with their definitions.
Args:
config: Recipe configuration with macros
Returns:
Configuration with macros expanded
Example:
>>> config = {
... 'macros': {'attention': 'SEBlock'},
... 'stages': [{'attention': '@attention'}]
... }
>>> expanded = expand_macros(config)
>>> # expanded['stages'][0]['attention'] == 'SEBlock'
"""
if 'macros' not in config:
return config
macros = config['macros']
result = copy.deepcopy(config)
def replace_in_value(value):
if isinstance(value, str):
# Check for @macro pattern
if value.startswith('@'):
macro_name = value[1:]
if macro_name in macros:
return macros[macro_name]
return value
elif isinstance(value, dict):
return {k: replace_in_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [replace_in_value(item) for item in value]
return value
# Expand macros in all fields except 'macros' itself
for key, value in result.items():
if key != 'macros':
result[key] = replace_in_value(value)
return result
[docs]
def merge_recipes(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
"""Merge two recipe configurations (for inheritance).
Args:
base: Base recipe configuration
override: Recipe to merge on top
Returns:
Merged configuration
"""
result = copy.deepcopy(base)
for key, value in override.items():
if key == 'extends':
# Don't include extends in result
continue
elif key == 'stages' and key in result:
# For stages, replace or extend based on index
if isinstance(value, list):
result[key] = value # Replace stages
elif key == 'macros' and key in result:
# Merge macros
result[key] = {**result[key], **value}
elif isinstance(value, dict) and key in result and isinstance(result[key], dict):
# Deep merge dictionaries
result[key] = merge_recipes(result[key], value)
else:
result[key] = copy.deepcopy(value)
return result
def generate_schema_docs() -> str:
"""Generate human-readable documentation from schema."""
lines = [
"# Recipe Schema Documentation",
"",
"## Root Properties",
""
]
props = RECIPE_SCHEMA.get('properties', {})
for name, spec in props.items():
desc = spec.get('description', 'No description')
ptype = spec.get('type', spec.get('oneOf', 'mixed'))
lines.append(f"### `{name}`")
lines.append(f"- Type: `{ptype}`")
lines.append(f"- Description: {desc}")
if 'default' in spec:
lines.append(f"- Default: `{spec['default']}`")
lines.append("")
return "\n".join(lines)
# Common recipe templates
RECIPE_TEMPLATES = {
"resnet_base": {
"name": "ResNet Base",
"input_shape": [3, 224, 224],
"stem": {"type": "conv", "channels": 64, "kernel_size": 7, "stride": 2},
"stages": [
{"pattern": "residual", "channels": 64, "blocks": 2},
{"pattern": "residual", "channels": 128, "blocks": 2, "downsample": True},
{"pattern": "residual", "channels": 256, "blocks": 2, "downsample": True},
{"pattern": "residual", "channels": 512, "blocks": 2, "downsample": True},
],
"head": {"type": "linear", "num_classes": 1000}
},
"efficientnet_base": {
"name": "EfficientNet Base",
"input_shape": [3, 224, 224],
"macros": {"attention": "se", "activation": "swish"},
"stem": {"type": "conv", "channels": 32, "kernel_size": 3, "stride": 2, "activation": "@activation"},
"stages": [
{"pattern": "mbconv", "channels": 16, "blocks": 1, "attention": "@attention"},
{"pattern": "mbconv", "channels": 24, "blocks": 2, "downsample": True, "attention": "@attention"},
{"pattern": "mbconv", "channels": 40, "blocks": 2, "downsample": True, "attention": "@attention"},
{"pattern": "mbconv", "channels": 80, "blocks": 3, "downsample": True, "attention": "@attention"},
{"pattern": "mbconv", "channels": 112, "blocks": 3, "attention": "@attention"},
{"pattern": "mbconv", "channels": 192, "blocks": 4, "downsample": True, "attention": "@attention"},
{"pattern": "mbconv", "channels": 320, "blocks": 1, "attention": "@attention"},
],
"head": {"type": "linear", "num_classes": 1000, "dropout": 0.2}
},
"hybrid_resnet_se": {
"name": "Hybrid ResNet with SE",
"backbone": {
"name": "resnet50",
"weights": "IMAGENET1K_V2",
"patches": {
"layer3": {"wrap": "se"},
"layer4": {"wrap": "cbam"}
}
},
"head": {"num_classes": 100, "dropout": 0.3}
}
}
[docs]
def get_template(name: str) -> Dict[str, Any]:
"""Get a recipe template by name.
Args:
name: Template name
Returns:
Template configuration dictionary
"""
if name not in RECIPE_TEMPLATES:
raise ValueError(f"Unknown template '{name}'. Available: {list(RECIPE_TEMPLATES.keys())}")
return copy.deepcopy(RECIPE_TEMPLATES[name])