import inspect from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass from functools import lru_cache, wraps from typing import ( Annotated, Any, Callable, ForwardRef, Literal, Optional, Type, TypeVar, Union, get_args, get_origin, overload, ) try: # Python 3.11+ from typing import NotRequired, Required # type: ignore except ImportError: try: # In case typing_extensions is installed from typing_extensions import NotRequired, Required # type: ignore except ImportError: # Fallback: create dummy types that will never match Required = type("Required", (), {}) # type: ignore NotRequired = type("NotRequired", (), {}) # type: ignore from .errors import ( StrictDataclassClassValidationError, StrictDataclassDefinitionError, StrictDataclassFieldValidationError, ) Validator_T = Callable[[Any], None] T = TypeVar("T") TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any]) _TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None) # The overload decorator helps type checkers understand the different return types @overload def strict(cls: Type[T]) -> Type[T]: ... @overload def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ... def strict( cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False ) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: """ Decorator to add strict validation to a dataclass. This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools recognize the class as a dataclass. Can be used with or without arguments: - `@strict` - `@strict(accept_kwargs=True)` Args: cls: The class to convert to a strict dataclass. accept_kwargs (`bool`, *optional*): If True, allows arbitrary keyword arguments in `__init__`. Defaults to False. Returns: The enhanced dataclass with strict validation on field assignment. Example: ```py >>> from dataclasses import dataclass >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field >>> @as_validated_field >>> def positive_int(value: int): ... if not value >= 0: ... raise ValueError(f"Value must be positive, got {value}") >>> @strict(accept_kwargs=True) ... @dataclass ... class User: ... name: str ... age: int = positive_int(default=10) # Initialize >>> User(name="John") User(name='John', age=10) # Extra kwargs are accepted >>> User(name="John", age=30, lastname="Doe") User(name='John', age=30, *lastname='Doe') # Invalid type => raises >>> User(name="John", age="30") huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': TypeError: Field 'age' expected int, got str (value: '30') # Invalid value => raises >>> User(name="John", age=-1) huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': ValueError: Value must be positive, got -1 ``` """ def wrap(cls: Type[T]) -> Type[T]: if not hasattr(cls, "__dataclass_fields__"): raise StrictDataclassDefinitionError( f"Class '{cls.__name__}' must be a dataclass before applying @strict." ) # List and store validators field_validators: dict[str, list[Validator_T]] = {} for f in fields(cls): # type: ignore [arg-type] validators = [] validators.append(_create_type_validator(f)) custom_validator = f.metadata.get("validator") if custom_validator is not None: if not isinstance(custom_validator, list): custom_validator = [custom_validator] for validator in custom_validator: if not _is_validator(validator): raise StrictDataclassDefinitionError( f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument." ) validators.extend(custom_validator) field_validators[f.name] = validators cls.__validators__ = field_validators # type: ignore # Override __setattr__ to validate fields on assignment original_setattr = cls.__setattr__ def __strict_setattr__(self: Any, name: str, value: Any) -> None: """Custom __setattr__ method for strict dataclasses.""" # Run all validators for validator in self.__validators__.get(name, []): try: validator(value) except (ValueError, TypeError) as e: raise StrictDataclassFieldValidationError(field=name, cause=e) from e # If validation passed, set the attribute original_setattr(self, name, value) cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign] if accept_kwargs: # (optional) Override __init__ to accept arbitrary keyword arguments original_init = cls.__init__ @wraps(original_init) def __init__(self, **kwargs: Any) -> None: # Extract only the fields that are part of the dataclass dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type] standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields} # Call the original __init__ with standard fields original_init(self, **standard_kwargs) # Add any additional kwargs as attributes for name, value in kwargs.items(): if name not in dataclass_fields: self.__setattr__(name, value) cls.__init__ = __init__ # type: ignore[method-assign] # (optional) Override __repr__ to include additional kwargs original_repr = cls.__repr__ @wraps(original_repr) def __repr__(self) -> str: # Call the original __repr__ to get the standard fields standard_repr = original_repr(self) # Get additional kwargs additional_kwargs = [ # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass f"*{k}={v!r}" for k, v in self.__dict__.items() if k not in cls.__dataclass_fields__ # type: ignore [attr-defined] ] additional_repr = ", ".join(additional_kwargs) # Combine both representations return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr cls.__repr__ = __repr__ # type: ignore [method-assign] # List all public methods starting with `validate_` => class validators. class_validators = [] for name in dir(cls): if not name.startswith("validate_"): continue method = getattr(cls, name) if not callable(method): continue if len(inspect.signature(method).parameters) != 1: raise StrictDataclassDefinitionError( f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument." " Class validators must take only 'self' as an argument. Methods starting with 'validate_'" " are considered to be class validators." ) class_validators.append(method) cls.__class_validators__ = class_validators # type: ignore [attr-defined] # Add `validate` method to the class, but first check if it already exists def validate(self: T) -> None: """Run class validators on the instance.""" for validator in cls.__class_validators__: # type: ignore [attr-defined] try: validator(self) except (ValueError, TypeError) as e: raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class # (in which case we just override it) validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined] if hasattr(cls, "validate"): if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined] raise StrictDataclassDefinitionError( f"Class '{cls.__name__}' already implements a method called 'validate'." " This method name is reserved when using the @strict decorator on a dataclass." " If you want to keep your own method, please rename it." ) cls.validate = validate # type: ignore # Run class validators after initialization initial_init = cls.__init__ @wraps(initial_init) def init_with_validate(self, *args, **kwargs) -> None: """Run class validators after initialization.""" initial_init(self, *args, **kwargs) # type: ignore [call-arg] cls.validate(self) # type: ignore [attr-defined] setattr(cls, "__init__", init_with_validate) return cls # Return wrapped class or the decorator itself return wrap(cls) if cls is not None else wrap def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None: """ Validate that a dictionary conforms to the types defined in a TypedDict class. Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator. Args: schema (`type[TypedDictType]`): The TypedDict class defining the expected structure and types. data (`dict`): The dictionary to validate. Raises: `StrictDataclassFieldValidationError`: If any field in the dictionary does not conform to the expected type. Example: ```py >>> from typing import Annotated, TypedDict >>> from huggingface_hub.dataclasses import validate_typed_dict >>> def positive_int(value: int): ... if not value >= 0: ... raise ValueError(f"Value must be positive, got {value}") >>> class User(TypedDict): ... name: str ... age: Annotated[int, positive_int] >>> # Valid data >>> validate_typed_dict(User, {"name": "John", "age": 30}) >>> # Invalid type for age >>> validate_typed_dict(User, {"name": "John", "age": "30"}) huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': TypeError: Field 'age' expected int, got str (value: '30') >>> # Invalid value for age >>> validate_typed_dict(User, {"name": "John", "age": -1}) huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age': ValueError: Value must be positive, got -1 ``` """ # Convert typed dict to dataclass strict_cls = _build_strict_cls_from_typed_dict(schema) # Validate the data by instantiating the strict dataclass strict_cls(**data) # will raise if validation fails @lru_cache def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type: # Extract type hints from the TypedDict class type_hints = _get_typed_dict_annotations(schema) # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired) if not getattr(schema, "__total__", True): for key, value in type_hints.items(): origin = get_origin(value) if origin is Annotated: base, *meta = get_args(value) if not _is_required_or_notrequired(base): base = NotRequired[base] type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore elif not _is_required_or_notrequired(value): type_hints[key] = NotRequired[value] # Convert type hints to dataclass fields fields = [] for key, value in type_hints.items(): if get_origin(value) is Annotated: base, *meta = get_args(value) fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]}))) else: fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE))) # Create a strict dataclass from the TypedDict fields return strict(make_dataclass(schema.__name__, fields)) def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]: """Extract type annotations from a TypedDict class.""" try: # Available in Python 3.14+ import annotationlib return annotationlib.get_annotations(schema) except ImportError: return { # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail). # ForwardRefs are not validated by @strict anyway. name: value if value is not None else type(None) for name, value in schema.__dict__.get("__annotations__", {}).items() } def validated_field( validator: Union[list[Validator_T], Validator_T], default: Union[Any, _MISSING_TYPE] = MISSING, default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[dict] = None, **kwargs: Any, ) -> Any: """ Create a dataclass field with a custom validator. Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator. Args: validator (`Callable` or `list[Callable]`): A method that takes a value as input and raises ValueError/TypeError if the value is invalid. Can be a list of validators to apply multiple checks. **kwargs: Additional arguments to pass to `dataclasses.field()`. Returns: A field with the validator attached in metadata """ if not isinstance(validator, list): validator = [validator] if metadata is None: metadata = {} metadata["validator"] = validator return field( # type: ignore default=default, # type: ignore [arg-type] default_factory=default_factory, # type: ignore [arg-type] init=init, repr=repr, hash=hash, compare=compare, metadata=metadata, **kwargs, ) def as_validated_field(validator: Validator_T): """ Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator). Args: validator (`Callable`): A method that takes a value as input and raises ValueError/TypeError if the value is invalid. """ def _inner( default: Union[Any, _MISSING_TYPE] = MISSING, default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[dict] = None, **kwargs: Any, ): return validated_field( validator, default=default, default_factory=default_factory, init=init, repr=repr, hash=hash, compare=compare, metadata=metadata, **kwargs, ) return _inner def type_validator(name: str, value: Any, expected_type: Any) -> None: """Validate that 'value' matches 'expected_type'.""" origin = get_origin(expected_type) args = get_args(expected_type) if expected_type is Any: return elif validator := _BASIC_TYPE_VALIDATORS.get(origin): validator(name, value, args) elif isinstance(expected_type, type): # simple types _validate_simple_type(name, value, expected_type) elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str): return elif origin is Required: if value is _TYPED_DICT_DEFAULT_VALUE: raise TypeError(f"Field '{name}' is required but missing.") type_validator(name, value, args[0]) elif origin is NotRequired: if value is _TYPED_DICT_DEFAULT_VALUE: return type_validator(name, value, args[0]) else: raise TypeError(f"Unsupported type for field '{name}': {expected_type}") def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate that value matches one of the types in a Union.""" errors = [] for t in args: try: type_validator(name, value, t) return # Valid if any type matches except TypeError as e: errors.append(str(e)) raise TypeError( f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}" ) def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate Literal type.""" if value not in args: raise TypeError(f"Field '{name}' expected one of {args}, got {value}") def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate list[T] type.""" if not isinstance(value, list): raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}") # Validate each item in the list item_type = args[0] for i, item in enumerate(value): try: type_validator(f"{name}[{i}]", item, item_type) except TypeError as e: raise TypeError(f"Invalid item at index {i} in list '{name}'") from e def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate dict[K, V] type.""" if not isinstance(value, dict): raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}") # Validate keys and values key_type, value_type = args for k, v in value.items(): try: type_validator(f"{name}.key", k, key_type) type_validator(f"{name}[{k!r}]", v, value_type) except TypeError as e: raise TypeError(f"Invalid key or value in dict '{name}'") from e def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate Tuple type.""" if not isinstance(value, tuple): raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}") # Handle variable-length tuples: tuple[T, ...] if len(args) == 2 and args[1] is Ellipsis: for i, item in enumerate(value): try: type_validator(f"{name}[{i}]", item, args[0]) except TypeError as e: raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e # Handle fixed-length tuples: tuple[T1, T2, ...] elif len(args) != len(value): raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}") else: for i, (item, expected) in enumerate(zip(value, args)): try: type_validator(f"{name}[{i}]", item, expected) except TypeError as e: raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None: """Validate set[T] type.""" if not isinstance(value, set): raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}") # Validate each item in the set item_type = args[0] for i, item in enumerate(value): try: type_validator(f"{name} item", item, item_type) except TypeError as e: raise TypeError(f"Invalid item in set '{name}'") from e def _validate_simple_type(name: str, value: Any, expected_type: type) -> None: """Validate simple type (int, str, etc.).""" if not isinstance(value, expected_type): raise TypeError( f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})" ) def _create_type_validator(field: Field) -> Validator_T: """Create a type validator function for a field.""" # Hacky: we cannot use a lambda here because of reference issues def validator(value: Any) -> None: type_validator(field.name, value, field.type) return validator def _is_validator(validator: Any) -> bool: """Check if a function is a validator. A validator is a Callable that can be called with a single positional argument. The validator can have more arguments with default values. Basically, returns True if `validator(value)` is possible. """ if not callable(validator): return False signature = inspect.signature(validator) parameters = list(signature.parameters.values()) if len(parameters) == 0: return False if parameters[0].kind not in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL, ): return False for parameter in parameters[1:]: if parameter.default == inspect.Parameter.empty: return False return True def _is_required_or_notrequired(type_hint: Any) -> bool: """Helper to check if a type is Required/NotRequired.""" return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired)) _BASIC_TYPE_VALIDATORS = { Union: _validate_union, Literal: _validate_literal, list: _validate_list, dict: _validate_dict, tuple: _validate_tuple, set: _validate_set, } __all__ = [ "strict", "validate_typed_dict", "validated_field", "Validator_T", "StrictDataclassClassValidationError", "StrictDataclassDefinitionError", "StrictDataclassFieldValidationError", ]