diff --git a/docs/index.md b/docs/index.md index 640cd3a2..e4e0d26a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,15 +90,15 @@ print(Settings().model_dump()) 2. The environment variable name is overridden using `alias`. In this case, the environment variable `my_api_key` will be used for both validation and serialization instead of `api_key`. - Check the [`Field` documentation](fields.md#field-aliases) for more information. + Check the [`Field` documentation](fields.md#field-aliases) for more information. -3. The `AliasChoices` class allows to have multiple environment variable names for a single field. +3. The [`AliasChoices`][pydantic.AliasChoices] class allows to have multiple environment variable names for a single field. The first environment variable that is found will be used. - Check the [`AliasChoices`](alias.md#aliaspath-and-aliaschoices) for more information. + Check the [documentation on alias choices](alias.md#aliaspath-and-aliaschoices) for more information. -4. The `ImportString` class allows to import an object from a string. - In this case, the environment variable `special_function` will be read and the function `math.cos` will be imported. +4. The [`ImportString`][pydantic.types.ImportString] class allows to import an object from a string. + In this case, the environment variable `special_function` will be read and the function [`math.cos`][] will be imported. 5. The `env_prefix` config setting allows to set a prefix for all environment variables. @@ -136,7 +136,7 @@ print(Settings1()) #> foo='test' ``` -Check the [Validation of default values](validators.md#validation-of-default-values) for more information. +Check the [validation of default values](fields.md#validate-default-values) for more information. ## Environment variable names @@ -371,6 +371,100 @@ print(Settings().model_dump()) #> {'numbers': [1, 2, 3]} ``` +### Disabling JSON parsing + +pydantic-settings by default parses complex types from environment variables as JSON strings. If you want to disable +this behavior for a field and parse the value in your own validator, you can annotate the field with +[`NoDecode`](../api/pydantic_settings.md#pydantic_settings.NoDecode): + +```py +import os +from typing import List + +from pydantic import field_validator +from typing_extensions import Annotated + +from pydantic_settings import BaseSettings, NoDecode + + +class Settings(BaseSettings): + numbers: Annotated[List[int], NoDecode] # (1)! + + @field_validator('numbers', mode='before') + @classmethod + def decode_numbers(cls, v: str) -> List[int]: + return [int(x) for x in v.split(',')] + + +os.environ['numbers'] = '1,2,3' +print(Settings().model_dump()) +#> {'numbers': [1, 2, 3]} +``` + +1. The `NoDecode` annotation disables JSON parsing for the `numbers` field. The `decode_numbers` field validator + will be called to parse the value. + +You can also disable JSON parsing for all fields by setting the `enable_decoding` config setting to `False`: + +```py +import os +from typing import List + +from pydantic import field_validator + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(enable_decoding=False) + + numbers: List[int] + + @field_validator('numbers', mode='before') + @classmethod + def decode_numbers(cls, v: str) -> List[int]: + return [int(x) for x in v.split(',')] + + +os.environ['numbers'] = '1,2,3' +print(Settings().model_dump()) +#> {'numbers': [1, 2, 3]} +``` + +You can force JSON parsing for a field by annotating it with [`ForceDecode`](../api/pydantic_settings.md#pydantic_settings.ForceDecode). +This will bypass the `enable_decoding` config setting: + +```py +import os +from typing import List + +from pydantic import field_validator +from typing_extensions import Annotated + +from pydantic_settings import BaseSettings, ForceDecode, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(enable_decoding=False) + + numbers: Annotated[List[int], ForceDecode] + numbers1: List[int] # (1)! + + @field_validator('numbers1', mode='before') + @classmethod + def decode_numbers1(cls, v: str) -> List[int]: + return [int(x) for x in v.split(',')] + + +os.environ['numbers'] = '["1","2","3"]' +os.environ['numbers1'] = '1,2,3' +print(Settings().model_dump()) +#> {'numbers': [1, 2, 3], 'numbers1': [1, 2, 3]} +``` + +1. The `numbers1` field is not annotated with `ForceDecode`, so it will not be parsed as JSON. + and we have to provide a custom validator to parse the value. + ## Nested model default partial updates By default, Pydantic settings does not allow partial updates to nested model default objects. This behavior can be @@ -957,17 +1051,51 @@ assert cmd.model_dump() == { For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will internally use the following `BaseSettings` configuration defaults: -* `alias_generator=AliasGenerator(lambda s: s.replace('_', '-'))` * `nested_model_default_partial_update=True` * `case_sensitive=True` * `cli_hide_none_type=True` * `cli_avoid_json=True` * `cli_enforce_required=True` * `cli_implicit_flags=True` +* `cli_kebab_case=True` + +### Mutually Exclusive Groups + +CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class. !!! note - The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set - in these cases. + A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models. + +```py +from typing import Optional + +from pydantic import BaseModel + +from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError + + +class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = None + diameter: Optional[float] = None + perimeter: Optional[float] = None + + +class Settings(BaseModel): + circle: Circle + + +try: + CliApp.run( + Settings, + cli_args=['--circle.radius=1', '--circle.diameter=2'], + cli_exit_on_error=False, + ) +except SettingsError as e: + print(e) + """ + error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius + """ +``` ### Customizing the CLI Experience @@ -1093,6 +1221,37 @@ print(Settings().model_dump()) #> {'good_arg': 'hello world'} ``` +#### CLI Kebab Case for Arguments + +Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`. + +```py +import sys + +from pydantic import Field + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings, cli_parse_args=True, cli_kebab_case=True): + my_option: str = Field(description='will show as kebab case on CLI') + + +try: + sys.argv = ['example.py', '--help'] + Settings() +except SystemExit as e: + print(e) + #> 0 +""" +usage: example.py [-h] [--my-option str] + +options: + -h, --help show this help message and exit + --my-option str will show as kebab case on CLI (required) +""" +``` + #### Change Whether CLI Should Exit on Error Change whether the CLI internal parser will exit on error or raise a `SettingsError` exception by using diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index fd42d3ef..0a02868c 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -4,14 +4,17 @@ AzureKeyVaultSettingsSource, CliExplicitFlag, CliImplicitFlag, + CliMutuallyExclusiveGroup, CliPositionalArg, CliSettingsSource, CliSubCommand, CliSuppress, DotEnvSettingsSource, EnvSettingsSource, + ForceDecode, InitSettingsSource, JsonConfigSettingsSource, + NoDecode, PydanticBaseSettingsSource, PyprojectTomlConfigSettingsSource, SecretsSettingsSource, @@ -34,8 +37,11 @@ 'CliPositionalArg', 'CliExplicitFlag', 'CliImplicitFlag', + 'CliMutuallyExclusiveGroup', 'InitSettingsSource', 'JsonConfigSettingsSource', + 'NoDecode', + 'ForceDecode', 'PyprojectTomlConfigSettingsSource', 'PydanticBaseSettingsSource', 'SecretsSettingsSource', diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 723d6d50..f376361e 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -4,7 +4,7 @@ from types import SimpleNamespace from typing import Any, ClassVar, TypeVar -from pydantic import AliasGenerator, ConfigDict +from pydantic import ConfigDict from pydantic._internal._config import config_keys from pydantic._internal._signature import _field_name_for_signature from pydantic._internal._utils import deep_update, is_model_class @@ -52,6 +52,7 @@ class SettingsConfigDict(ConfigDict, total=False): cli_flag_prefix_char: str cli_implicit_flags: bool | None cli_ignore_unknown_args: bool | None + cli_kebab_case: bool | None secrets_dir: PathType | None json_file: PathType | None json_file_encoding: str | None @@ -78,6 +79,7 @@ class SettingsConfigDict(ConfigDict, total=False): """ toml_file: PathType | None + enable_decoding: bool # Extend `config_keys` by pydantic settings config keys to @@ -133,6 +135,7 @@ class BaseSettings(BaseModel): _cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. (e.g. --flag, --no-flag). Defaults to `False`. _cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`. + _cli_kebab_case: CLI args use kebab case. Defaults to `False`. _secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`. """ @@ -160,6 +163,7 @@ def __init__( _cli_flag_prefix_char: str | None = None, _cli_implicit_flags: bool | None = None, _cli_ignore_unknown_args: bool | None = None, + _cli_kebab_case: bool | None = None, _secrets_dir: PathType | None = None, **values: Any, ) -> None: @@ -189,6 +193,7 @@ def __init__( _cli_flag_prefix_char=_cli_flag_prefix_char, _cli_implicit_flags=_cli_implicit_flags, _cli_ignore_unknown_args=_cli_ignore_unknown_args, + _cli_kebab_case=_cli_kebab_case, _secrets_dir=_secrets_dir, ) ) @@ -242,6 +247,7 @@ def _settings_build_values( _cli_flag_prefix_char: str | None = None, _cli_implicit_flags: bool | None = None, _cli_ignore_unknown_args: bool | None = None, + _cli_kebab_case: bool | None = None, _secrets_dir: PathType | None = None, ) -> dict[str, Any]: # Determine settings config values @@ -309,6 +315,7 @@ def _settings_build_values( if _cli_ignore_unknown_args is not None else self.model_config.get('cli_ignore_unknown_args') ) + cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else self.model_config.get('cli_kebab_case') secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir') @@ -371,6 +378,7 @@ def _settings_build_values( cli_flag_prefix_char=cli_flag_prefix_char, cli_implicit_flags=cli_implicit_flags, cli_ignore_unknown_args=cli_ignore_unknown_args, + cli_kebab_case=cli_kebab_case, case_sensitive=case_sensitive, ) sources = (cli_settings,) + sources @@ -418,13 +426,15 @@ def _settings_build_values( cli_flag_prefix_char='-', cli_implicit_flags=False, cli_ignore_unknown_args=False, + cli_kebab_case=False, json_file=None, json_file_encoding=None, yaml_file=None, yaml_file_encoding=None, toml_file=None, secrets_dir=None, - protected_namespaces=('model_', 'settings_'), + protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'), + enable_decoding=True, ) @@ -497,13 +507,13 @@ def run( class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore model_config = SettingsConfigDict( - alias_generator=AliasGenerator(lambda s: s.replace('_', '-')), nested_model_default_partial_update=True, case_sensitive=True, cli_hide_none_type=True, cli_avoid_json=True, cli_enforce_required=True, cli_implicit_flags=True, + cli_kebab_case=True, ) model = CliAppBaseSettings(**model_init_data) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index e0e08099..656a32f1 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -37,9 +37,8 @@ import typing_extensions from dotenv import dotenv_values -from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, TypeAdapter +from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, Secret, TypeAdapter from pydantic._internal._repr import Representation -from pydantic._internal._signature import _field_name_for_signature from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass from pydantic.dataclasses import is_pydantic_dataclass @@ -119,6 +118,18 @@ def import_azure_key_vault() -> None: ENV_FILE_SENTINEL: DotenvType = Path('') +class NoDecode: + """Annotation to prevent decoding of a field value.""" + + pass + + +class ForceDecode: + """Annotation to force decoding of a field value.""" + + pass + + class SettingsError(ValueError): pass @@ -150,6 +161,10 @@ def error(self, message: str) -> NoReturn: super().error(message) +class CliMutuallyExclusiveGroup(BaseModel): + pass + + T = TypeVar('T') CliSubCommand = Annotated[Union[T, None], _CliSubCommand] CliPositionalArg = Annotated[T, _CliPositionalArg] @@ -309,6 +324,12 @@ def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Returns: The decoded value for further preparation """ + if field and ( + NoDecode in field.metadata + or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata) + ): + return value + return json.loads(value) @abstractmethod @@ -336,10 +357,12 @@ def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partia ) if self.nested_model_default_partial_update: for field_name, field_info in settings_cls.model_fields.items(): + alias_names, *_ = _get_alias_names(field_name, field_info) + preferred_alias = alias_names[0] if is_dataclass(type(field_info.default)): - self.defaults[_field_name_for_signature(field_name, field_info)] = asdict(field_info.default) + self.defaults[preferred_alias] = asdict(field_info.default) elif is_model_class(type(field_info.default)): - self.defaults[_field_name_for_signature(field_name, field_info)] = field_info.default.model_dump() + self.defaults[preferred_alias] = field_info.default.model_dump() def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: # Nothing to do here. Only implement the return statement to make mypy happy @@ -349,7 +372,9 @@ def __call__(self) -> dict[str, Any]: return self.defaults def __repr__(self) -> str: - return f'DefaultSettingsSource(nested_model_default_partial_update={self.nested_model_default_partial_update})' + return ( + f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})' + ) class InitSettingsSource(PydanticBaseSettingsSource): @@ -383,7 +408,7 @@ def __call__(self) -> dict[str, Any]: ) def __repr__(self) -> str: - return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' + return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})' class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource): @@ -654,7 +679,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, a flag to determine whether value is complex. """ - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + field_infos = self._extract_field_info(field, field_name) + preferred_key, *_ = field_infos[0] + for field_key, env_name, value_is_complex in field_infos: # paths reversed to match the last-wins behaviour of `env_file` for secrets_path in reversed(self.secrets_paths): path = self.find_case_path(secrets_path, env_name, self.case_sensitive) @@ -663,17 +690,19 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, continue if path.is_file(): - return path.read_text().strip(), field_key, value_is_complex + if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)): + preferred_key = field_key + return path.read_text().strip(), preferred_key, value_is_complex else: warnings.warn( f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.', stacklevel=4, ) - return None, field_key, value_is_complex + return None, preferred_key, value_is_complex def __repr__(self) -> str: - return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' + return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})' class EnvSettingsSource(PydanticBaseEnvSettingsSource): @@ -718,12 +747,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, """ env_val: str | None = None - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + field_infos = self._extract_field_info(field, field_name) + preferred_key, *_ = field_infos[0] + for field_key, env_name, value_is_complex in field_infos: env_val = self.env_vars.get(env_name) if env_val is not None: + if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)): + preferred_key = field_key break - return env_val, field_key, value_is_complex + return env_val, preferred_key, value_is_complex def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: """ @@ -898,7 +931,7 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ def __repr__(self) -> str: return ( - f'EnvSettingsSource(env_nested_delimiter={self.env_nested_delimiter!r}, ' + f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, ' f'env_prefix_len={self.env_prefix_len!r})' ) @@ -1014,7 +1047,7 @@ def __call__(self) -> dict[str, Any]: def __repr__(self) -> str: return ( - f'DotEnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' + f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})' ) @@ -1048,6 +1081,7 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]): cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. (e.g. --flag, --no-flag). Defaults to `False`. cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`. + cli_kebab_case: CLI args use kebab case. Defaults to `False`. case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`. Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI subcommands. @@ -1078,6 +1112,7 @@ def __init__( cli_flag_prefix_char: str | None = None, cli_implicit_flags: bool | None = None, cli_ignore_unknown_args: bool | None = None, + cli_kebab_case: bool | None = None, case_sensitive: bool | None = True, root_parser: Any = None, parse_args_method: Callable[..., Any] | None = None, @@ -1137,6 +1172,9 @@ def __init__( if cli_ignore_unknown_args is not None else settings_cls.model_config.get('cli_ignore_unknown_args', False) ) + self.cli_kebab_case = ( + cli_kebab_case if cli_kebab_case is not None else settings_cls.model_config.get('cli_kebab_case', False) + ) case_sensitive = case_sensitive if case_sensitive is not None else True if not case_sensitive and root_parser is not None: @@ -1158,6 +1196,7 @@ def __init__( description=None if settings_cls.__doc__ is None else dedent(settings_cls.__doc__), formatter_class=formatter_class, prefix_chars=self.cli_flag_prefix_char, + allow_abbrev=False, ) if root_parser is None else root_parser @@ -1418,45 +1457,10 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}') elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False): raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}') - if is_model_class(type_) or is_pydantic_dataclass(type_): - sub_models.append(type_) # type: ignore + if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)): + sub_models.append(_strip_annotated(type_)) return sub_models - def _get_alias_names( - self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] - ) -> tuple[tuple[str, ...], bool]: - alias_names: list[str] = [] - is_alias_path_only: bool = True - if not any((field_info.alias, field_info.validation_alias)): - alias_names += [field_name] - is_alias_path_only = False - else: - new_alias_paths: list[AliasPath] = [] - for alias in (field_info.alias, field_info.validation_alias): - if alias is None: - continue - elif isinstance(alias, str): - alias_names.append(alias) - is_alias_path_only = False - elif isinstance(alias, AliasChoices): - for name in alias.choices: - if isinstance(name, str): - alias_names.append(name) - is_alias_path_only = False - else: - new_alias_paths.append(name) - else: - new_alias_paths.append(alias) - for alias_path in new_alias_paths: - name = cast(str, alias_path.path[0]) - name = name.lower() if not self.case_sensitive else name - alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list' - if not alias_names and is_alias_path_only: - alias_names.append(name) - if not self.case_sensitive: - alias_names = [alias_name.lower() for alias_name in alias_names] - return tuple(dict.fromkeys(alias_names)), is_alias_path_only - def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None: if _CliImplicitFlag in field_info.metadata: cli_flag_name = 'CliImplicitFlag' @@ -1481,7 +1485,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] if not field_info.is_required(): raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') else: - alias_names, *_ = self._get_alias_names(field_name, field_info, {}) + alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases') field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] @@ -1495,7 +1499,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] if not field_info.is_required(): raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value') else: - alias_names, *_ = self._get_alias_names(field_name, field_info, {}) + alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') positional_args.append((field_name, field_info)) @@ -1515,7 +1519,7 @@ def _connect_parser_method( if ( parser_method is not None and self.case_sensitive is False - and method_name == 'parsed_args_method' + and method_name == 'parse_args_method' and isinstance(self._root_parser, _CliInternalArgParser) ): @@ -1547,6 +1551,26 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any: else: return parser_method + def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]: + add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') + + def add_group_method(parser: Any, **kwargs: Any) -> Any: + if not kwargs.pop('_is_cli_mutually_exclusive_group'): + kwargs.pop('required') + return add_argument_group(parser, **kwargs) + else: + main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs} + main_group_kwargs['title'] += ' (mutually exclusive)' + group = add_argument_group(parser, **main_group_kwargs) + if not hasattr(group, 'add_mutually_exclusive_group'): + raise SettingsError( + 'cannot connect CLI settings source root parser: ' + 'group object is missing add_mutually_exclusive_group but is needed for connecting' + ) + return group.add_mutually_exclusive_group(**kwargs) + + return add_group_method + def _connect_root_parser( self, root_parser: T, @@ -1563,9 +1587,9 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace: self._root_parser = root_parser if parse_args_method is None: parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args - self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method') + self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method') self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method') - self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') + self._add_group = self._connect_group_method(add_argument_group_method) self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method') self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method') self._formatter_class = formatter_class @@ -1595,13 +1619,26 @@ def _add_parser_args( ) -> ArgumentParser: subparsers: Any = None alias_path_args: dict[str, str] = {} + # Ignore model default if the default is a model and not a subclass of the current model. + model_default = ( + None + if ( + (is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default))) + and not issubclass(type(model_default), model) + ) + else model_default + ) for field_name, field_info in self._sort_arg_fields(model): sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info) - alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args) + alias_names, is_alias_path_only = _get_alias_names( + field_name, field_info, alias_path_args=alias_path_args, case_sensitive=self.case_sensitive + ) preferred_alias = alias_names[0] if _CliSubCommand in field_info.metadata: for model in sub_models: - subcommand_alias = model.__name__ if len(sub_models) > 1 else preferred_alias + subcommand_alias = self._check_kebab_name( + model.__name__ if len(sub_models) > 1 else preferred_alias + ) subcommand_name = f'{arg_prefix}{subcommand_alias}' subcommand_dest = f'{arg_prefix}{preferred_alias}' self._cli_subcommands[f'{arg_prefix}:subcommand'][subcommand_name] = subcommand_dest @@ -1665,7 +1702,8 @@ def _add_parser_args( else f'{arg_prefix}{preferred_alias}' ) - if kwargs['dest'] in added_args: + arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, alias_names, added_args) + if not arg_names or (kwargs['dest'] in added_args): continue if is_append_action: @@ -1673,9 +1711,8 @@ def _add_parser_args( if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True): self._cli_dict_args[kwargs['dest']] = field_info.annotation - arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, alias_names) if _CliPositionalArg in field_info.metadata: - kwargs['metavar'] = preferred_alias.upper() + kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) arg_names = [kwargs['dest']] del kwargs['dest'] del kwargs['required'] @@ -1686,6 +1723,7 @@ def _add_parser_args( if is_parser_submodel: self._add_parser_submodels( parser, + model, sub_models, added_args, arg_prefix, @@ -1701,7 +1739,7 @@ def _add_parser_args( elif not is_alias_path_only: if group is not None: if isinstance(group, dict): - group = self._add_argument_group(parser, **group) + group = self._add_group(parser, **group) added_args += list(arg_names) self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs) else: @@ -1713,6 +1751,11 @@ def _add_parser_args( self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group) return parser + def _check_kebab_name(self, name: str) -> str: + if self.cli_kebab_case: + return name.replace('_', '-') + return name + def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None: if kwargs['metavar'] == 'bool': default = None @@ -1730,21 +1773,29 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode ) def _get_arg_names( - self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], alias_names: tuple[str, ...] + self, + arg_prefix: str, + subcommand_prefix: str, + alias_prefixes: list[str], + alias_names: tuple[str, ...], + added_args: list[str], ) -> list[str]: arg_names: list[str] = [] for prefix in [arg_prefix] + alias_prefixes: for name in alias_names: - arg_names.append( + arg_name = self._check_kebab_name( f'{prefix}{name}' if subcommand_prefix == self.env_prefix else f'{prefix.replace(subcommand_prefix, "", 1)}{name}' ) + if arg_name not in added_args: + arg_names.append(arg_name) return arg_names def _add_parser_submodels( self, parser: Any, + model: type[BaseModel], sub_models: list[type[BaseModel]], added_args: list[str], arg_prefix: str, @@ -1757,10 +1808,23 @@ def _add_parser_submodels( alias_names: tuple[str, ...], model_default: Any, ) -> None: + if issubclass(model, CliMutuallyExclusiveGroup): + # Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a + # mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion). + # Since nested models result in a group add, raise an exception for nested models in a mutually + # exclusive group. + raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup') + model_group: Any = None model_group_kwargs: dict[str, Any] = {} model_group_kwargs['title'] = f'{arg_names[0]} options' model_group_kwargs['description'] = field_info.description + model_group_kwargs['required'] = kwargs['required'] + model_group_kwargs['_is_cli_mutually_exclusive_group'] = any( + issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models + ) + if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1: + raise SettingsError('cannot use union with CliMutuallyExclusiveGroup') if self.cli_use_class_docs_for_groups and len(sub_models) == 1: model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__) @@ -1783,7 +1847,7 @@ def _add_parser_submodels( if not self.cli_avoid_json: added_args.append(arg_names[0]) kwargs['help'] = f'set {arg_names[0]} from JSON string' - model_group = self._add_argument_group(parser, **model_group_kwargs) + model_group = self._add_group(parser, **model_group_kwargs) self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs) for model in sub_models: self._add_parser_args( @@ -1809,7 +1873,7 @@ def _add_parser_alias_paths( if alias_path_args: context = parser if group is not None: - context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group + context = self._add_group(parser, **group) if isinstance(group, dict) else group is_nested_alias_path = arg_prefix.endswith('.') arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix for name, metavar in alias_path_args.items(): @@ -1871,7 +1935,8 @@ def _metavar_format_recurse(self, obj: Any) -> str: return self._metavar_format_choices([val.name for val in obj]) elif isinstance(obj, WithArgsTypes): return self._metavar_format_choices( - list(map(self._metavar_format_recurse, self._get_modified_args(obj))), obj_qualname=obj.__qualname__ + list(map(self._metavar_format_recurse, self._get_modified_args(obj))), + obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj), ) elif obj is type(None): return self.cli_parse_none_str @@ -1951,6 +2016,9 @@ def _read_file(self, file_path: Path) -> dict[str, Any]: with open(file_path, encoding=self.json_file_encoding) as json_file: return json.load(json_file) + def __repr__(self) -> str: + return f'{self.__class__.__name__}(json_file={self.json_file_path})' + class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin): """ @@ -1973,6 +2041,9 @@ def _read_file(self, file_path: Path) -> dict[str, Any]: return tomli.load(toml_file) return tomllib.load(toml_file) + def __repr__(self) -> str: + return f'{self.__class__.__name__}(toml_file={self.toml_file_path})' + class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource): """ @@ -2045,6 +2116,9 @@ def _read_file(self, file_path: Path) -> dict[str, Any]: with open(file_path, encoding=self.yaml_file_encoding) as yaml_file: return yaml.safe_load(yaml_file) or {} + def __repr__(self) -> str: + return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})' + class AzureKeyVaultMapping(Mapping[str, Optional[str]]): _loaded_secrets: dict[str, str | None] @@ -2107,7 +2181,7 @@ def _load_env_vars(self) -> Mapping[str, Optional[str]]: return AzureKeyVaultMapping(secret_client) def __repr__(self) -> str: - return f'AzureKeyVaultSettingsSource(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})' + return f'{self.__class__.__name__}(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})' def _get_env_var_key(key: str, case_sensitive: bool = False) -> str: @@ -2174,6 +2248,10 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> inner, *meta = get_args(annotation) return _annotation_is_complex(inner, meta) origin = get_origin(annotation) + + if origin is Secret: + return False + return ( _annotation_is_complex_inner(annotation) or _annotation_is_complex_inner(origin) @@ -2241,5 +2319,41 @@ def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]: raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') +def _get_alias_names( + field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] = {}, case_sensitive: bool = True +) -> tuple[tuple[str, ...], bool]: + alias_names: list[str] = [] + is_alias_path_only: bool = True + if not any((field_info.alias, field_info.validation_alias)): + alias_names += [field_name] + is_alias_path_only = False + else: + new_alias_paths: list[AliasPath] = [] + for alias in (field_info.alias, field_info.validation_alias): + if alias is None: + continue + elif isinstance(alias, str): + alias_names.append(alias) + is_alias_path_only = False + elif isinstance(alias, AliasChoices): + for name in alias.choices: + if isinstance(name, str): + alias_names.append(name) + is_alias_path_only = False + else: + new_alias_paths.append(name) + else: + new_alias_paths.append(alias) + for alias_path in new_alias_paths: + name = cast(str, alias_path.path[0]) + name = name.lower() if not case_sensitive else name + alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list' + if not alias_names and is_alias_path_only: + alias_names.append(name) + if not case_sensitive: + alias_names = [alias_name.lower() for alias_name in alias_names] + return tuple(dict.fromkeys(alias_names)), is_alias_path_only + + def _is_function(obj: Any) -> bool: return isinstance(obj, (FunctionType, BuiltinFunctionType)) diff --git a/pydantic_settings/version.py b/pydantic_settings/version.py index bff5353e..9ad573dc 100644 --- a/pydantic_settings/version.py +++ b/pydantic_settings/version.py @@ -1 +1 @@ -VERSION = '2.6.1' +VERSION = '2.7.0' diff --git a/tests/conftest.py b/tests/conftest.py index 2118f9dc..7a968c57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,3 +82,15 @@ def docs_test_env(): yield setenv setenv.clear() + + +@pytest.fixture +def cli_test_env(): + setenv = SetEnv() + + # envs for reproducible cli tests + setenv.set('COLUMNS', '80') + + yield setenv + + setenv.clear() diff --git a/tests/test_settings.py b/tests/test_settings.py index 937047b2..2a6578ba 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,4 +1,5 @@ import dataclasses +import json import os import pathlib import sys @@ -7,21 +8,26 @@ from enum import IntEnum from pathlib import Path from typing import Any, Callable, Dict, Generic, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union +from unittest import mock import pytest from annotated_types import MinLen from pydantic import ( AliasChoices, + AliasGenerator, AliasPath, BaseModel, Discriminator, Field, HttpUrl, Json, + PostgresDsn, RootModel, + Secret, SecretStr, Tag, ValidationError, + field_validator, ) from pydantic import ( dataclasses as pydantic_dataclasses, @@ -33,7 +39,9 @@ BaseSettings, DotEnvSettingsSource, EnvSettingsSource, + ForceDecode, InitSettingsSource, + NoDecode, PydanticBaseSettingsSource, SecretsSettingsSource, SettingsConfigDict, @@ -68,6 +76,12 @@ class SettingWithPopulateByName(BaseSettings): model_config = SettingsConfigDict(populate_by_name=True) +@pytest.fixture(autouse=True) +def clean_env(): + with mock.patch.dict(os.environ, clear=True): + yield + + def test_sub_env(env): env.set('apple', 'hello') s = SimpleSettings() @@ -614,6 +628,26 @@ def settings_customise_sources( assert s.model_dump() == s_final +def test_alias_nested_model_default_partial_update(): + class SubModel(BaseModel): + v1: str = 'default' + v2: bytes = b'hello' + v3: int + + class Settings(BaseSettings): + model_config = SettingsConfigDict( + nested_model_default_partial_update=True, alias_generator=AliasGenerator(lambda s: s.replace('_', '-')) + ) + + v0: str = 'ok' + sub_model: SubModel = SubModel(v1='top default', v3=33) + + assert Settings(**{'sub-model': {'v1': 'cli'}}).model_dump() == { + 'v0': 'ok', + 'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33}, + } + + def test_env_str(env): class Settings(BaseSettings): apple: str = Field(None, validation_alias='BOOM') @@ -1109,9 +1143,13 @@ class Settings(BaseSettings): @pytest.fixture -def home_tmp(): +def home_tmp(tmp_path, env): + env.set('HOME', str(tmp_path)) + env.set('USERPROFILE', str(tmp_path)) + env.set('HOMEPATH', str(tmp_path)) + tmp_filename = f'{uuid.uuid4()}.env' - home_tmp_path = Path.home() / tmp_filename + home_tmp_path = tmp_path / tmp_filename yield home_tmp_path, tmp_filename home_tmp_path.unlink() @@ -2168,26 +2206,28 @@ class Settings(BaseSettings): def test_protected_namespace_defaults(): # pydantic default with pytest.warns( - UserWarning, match='Field "model_prefixed_field" in Model has conflict with protected namespace "model_"' + UserWarning, + match='Field "model_dump_prefixed_field" in Model has conflict with protected namespace "model_dump"', ): class Model(BaseSettings): - model_prefixed_field: str + model_dump_prefixed_field: str # pydantic-settings default - with pytest.raises( - UserWarning, match='Field "settings_prefixed_field" in Model1 has conflict with protected namespace "settings_"' + with pytest.warns( + UserWarning, + match='Field "settings_customise_sources_prefixed_field" in Model1 has conflict with protected namespace "settings_customise_sources"', ): class Model1(BaseSettings): - settings_prefixed_field: str + settings_customise_sources_prefixed_field: str with pytest.raises( NameError, match=( 'Field "settings_customise_sources" conflicts with member > " - 'of protected namespace "settings_".' + 'of protected namespace "settings_customise_sources".' ), ): @@ -2824,3 +2864,82 @@ class Settings(BaseSettings): s = Settings() assert s.model_dump() == {'foo': 'test-foo'} + + +def test_parsing_secret_field(env): + class Settings(BaseSettings): + foo: Secret[int] + bar: Secret[PostgresDsn] + + env.set('foo', '123') + env.set('bar', 'postgres://user:password@localhost/dbname') + + s = Settings() + assert s.foo.get_secret_value() == 123 + assert s.bar.get_secret_value() == PostgresDsn('postgres://user:password@localhost/dbname') + + +def test_field_annotated_no_decode(env): + class Settings(BaseSettings): + a: List[str] # this field will be decoded because of default `enable_decoding=True` + b: Annotated[List[str], NoDecode] + + # decode the value here. the field value won't be decoded because of NoDecode + @field_validator('b', mode='before') + @classmethod + def decode_b(cls, v: str) -> List[str]: + return json.loads(v) + + env.set('a', '["one", "two"]') + env.set('b', '["1", "2"]') + + s = Settings() + assert s.model_dump() == {'a': ['one', 'two'], 'b': ['1', '2']} + + +def test_field_annotated_no_decode_and_disable_decoding(env): + class Settings(BaseSettings): + model_config = SettingsConfigDict(enable_decoding=False) + + a: Annotated[List[str], NoDecode] + + # decode the value here. the field value won't be decoded because of NoDecode + @field_validator('a', mode='before') + @classmethod + def decode_b(cls, v: str) -> List[str]: + return json.loads(v) + + env.set('a', '["one", "two"]') + + s = Settings() + assert s.model_dump() == {'a': ['one', 'two']} + + +def test_field_annotated_disable_decoding(env): + class Settings(BaseSettings): + model_config = SettingsConfigDict(enable_decoding=False) + + a: List[str] + + # decode the value here. the field value won't be decoded because of `enable_decoding=False` + @field_validator('a', mode='before') + @classmethod + def decode_b(cls, v: str) -> List[str]: + return json.loads(v) + + env.set('a', '["one", "two"]') + + s = Settings() + assert s.model_dump() == {'a': ['one', 'two']} + + +def test_field_annotated_force_decode_disable_decoding(env): + class Settings(BaseSettings): + model_config = SettingsConfigDict(enable_decoding=False) + + a: Annotated[List[str], ForceDecode] + + env.set('a', '["one", "two"]') + + s = Settings() + assert s.model_dump() == {'a': ['one', 'two']} diff --git a/tests/test_source_azure_key_vault.py b/tests/test_source_azure_key_vault.py index bd4ddaaa..7e9b203a 100644 --- a/tests/test_source_azure_key_vault.py +++ b/tests/test_source_azure_key_vault.py @@ -115,7 +115,7 @@ def settings_customise_sources( assert settings.sql_server_user == expected_secret_value assert settings.sql_server.password == expected_secret_value - def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name: str) -> KeyVaultSecret: + def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name: str): expected_secret_value = 'SecretValue' key_vault_secret = KeyVaultSecret(SecretProperties(), expected_secret_value) diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 7101d687..35bfcdac 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -10,11 +10,14 @@ import typing_extensions from pydantic import ( AliasChoices, + AliasGenerator, AliasPath, BaseModel, ConfigDict, DirectoryPath, + Discriminator, Field, + Tag, ValidationError, ) from pydantic import ( @@ -33,6 +36,7 @@ CLI_SUPPRESS, CliExplicitFlag, CliImplicitFlag, + CliMutuallyExclusiveGroup, CliPositionalArg, CliSettingsSource, CliSubCommand, @@ -44,6 +48,11 @@ ARGPARSE_OPTIONS_TEXT = 'options' if sys.version_info >= (3, 10) else 'optional arguments' +@pytest.fixture(autouse=True) +def cli_test_env_autouse(cli_test_env): + pass + + def foobar(a, b, c=4): pass @@ -74,34 +83,34 @@ class SettingWithIgnoreEmpty(BaseSettings): class CliDummyArgGroup(BaseModel, arbitrary_types_allowed=True): group: argparse._ArgumentGroup - def add_argument(self, *args, **kwargs) -> None: + def add_argument(self, *args: Any, **kwargs: Any) -> None: self.group.add_argument(*args, **kwargs) class CliDummySubParsers(BaseModel, arbitrary_types_allowed=True): sub_parser: argparse._SubParsersAction - def add_parser(self, *args, **kwargs) -> 'CliDummyParser': + def add_parser(self, *args: Any, **kwargs: Any) -> 'CliDummyParser': return CliDummyParser(parser=self.sub_parser.add_parser(*args, **kwargs)) class CliDummyParser(BaseModel, arbitrary_types_allowed=True): parser: argparse.ArgumentParser = Field(default_factory=lambda: argparse.ArgumentParser()) - def add_argument(self, *args, **kwargs) -> None: + def add_argument(self, *args: Any, **kwargs: Any) -> None: self.parser.add_argument(*args, **kwargs) - def add_argument_group(self, *args, **kwargs) -> CliDummyArgGroup: + def add_argument_group(self, *args: Any, **kwargs: Any) -> CliDummyArgGroup: return CliDummyArgGroup(group=self.parser.add_argument_group(*args, **kwargs)) - def add_subparsers(self, *args, **kwargs) -> CliDummySubParsers: + def add_subparsers(self, *args: Any, **kwargs: Any) -> CliDummySubParsers: return CliDummySubParsers(sub_parser=self.parser.add_subparsers(*args, **kwargs)) - def parse_args(self, *args, **kwargs) -> argparse.Namespace: + def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace: return self.parser.parse_args(*args, **kwargs) -def test_validation_alias_with_cli_prefix(): +def test_cli_validation_alias_with_cli_prefix(): class Settings(BaseSettings, cli_exit_on_error=False): foobar: str = Field(validation_alias='foo') @@ -113,6 +122,36 @@ class Settings(BaseSettings, cli_exit_on_error=False): assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar' +@pytest.mark.parametrize( + 'alias_generator', + [ + AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))), + AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)), + ], +) +def test_cli_alias_resolution_consistency_with_env(env, alias_generator): + class SubModel(BaseModel): + v1: str = 'model default' + + class Settings(BaseSettings): + model_config = SettingsConfigDict( + env_nested_delimiter='__', + nested_model_default_partial_update=True, + alias_generator=alias_generator, + ) + + sub_model: SubModel = SubModel(v1='top default') + + assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}} + + env.set('SUB_MODEL__V1', 'env default') + assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}} + + assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == { + 'sub_model': {'v1': 'cli default'} + } + + def test_cli_nested_arg(): class SubSubValue(BaseModel): v6: str @@ -442,6 +481,46 @@ class MultilineDoc(BaseSettings, cli_parse_args=True): ) +def test_cli_help_union_of_models(capsys, monkeypatch): + class Cat(BaseModel): + meow: str = 'meow' + + class Dog(BaseModel): + bark: str = 'bark' + + class Bird(BaseModel): + caww: str = 'caww' + tweet: str + + class Tiger(Cat): + roar: str = 'roar' + + class Car(BaseSettings, cli_parse_args=True): + driver: Union[Cat, Dog, Bird] = Tiger(meow='purr') + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + + with pytest.raises(SystemExit): + Car() + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] [--driver JSON] [--driver.meow str] [--driver.bark str] + [--driver.caww str] [--driver.tweet str] + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + +driver options: + --driver JSON set driver from JSON string + --driver.meow str (default: purr) + --driver.bark str (default: bark) + --driver.caww str (default: caww) + --driver.tweet str (ifdef: required) +""" + ) + + def test_cli_help_default_or_none_model(capsys, monkeypatch): class DeeperSubModel(BaseModel): flag: bool @@ -1781,11 +1860,11 @@ class Cfg(BaseSettings): args = ['--fruit', 'pear'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'bird', 'command': None, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'bird', 'command': None, } @@ -1793,28 +1872,28 @@ class Cfg(BaseSettings): arg_prefix = f'{prefix}.' if prefix else '' args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': None, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': None, } parsed_args = parser.parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat']) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'cat', 'command': None, } args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog', 'command', '--name', 'ralph', '--command', 'roll'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': {'name': 'ralph', 'command': 'roll'}, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': {'name': 'ralph', 'command': 'roll'}, } @@ -2040,3 +2119,259 @@ class Settings(BaseSettings, cli_parse_args=True): -h, --help show this help message and exit """ ) + + +def test_cli_mutually_exclusive_group(capsys, monkeypatch): + class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = 21 + diameter: Optional[float] = 22 + perimeter: Optional[float] = 23 + + class Settings(BaseModel): + circle_optional: Circle = Circle(radius=None, diameter=None, perimeter=24) + circle_required: Circle + + CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-optional.radius=1']).model_dump() == { + 'circle_optional': {'radius': 1, 'diameter': 22, 'perimeter': 24}, + 'circle_required': {'radius': 1, 'diameter': 22, 'perimeter': 23}, + } + + with pytest.raises(SystemExit): + CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-required.diameter=2']) + assert ( + 'error: argument --circle-required.diameter: not allowed with argument --circle-required.radius' + in capsys.readouterr().err + ) + + with pytest.raises(SystemExit): + CliApp.run( + Settings, + cli_args=['--circle-required.radius=1', '--circle-optional.radius=1', '--circle-optional.diameter=2'], + ) + assert ( + 'error: argument --circle-optional.diameter: not allowed with argument --circle-optional.radius' + in capsys.readouterr().err + ) + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + with pytest.raises(SystemExit): + CliApp.run(Settings) + usage = ( + """usage: example.py [-h] [--circle-optional.radius float | + --circle-optional.diameter float | + --circle-optional.perimeter float] + (--circle-required.radius float | + --circle-required.diameter float | + --circle-required.perimeter float)""" + if sys.version_info >= (3, 13) + else """usage: example.py [-h] + [--circle-optional.radius float | --circle-optional.diameter float | --circle-optional.perimeter float] + (--circle-required.radius float | --circle-required.diameter float | --circle-required.perimeter float)""" + ) + assert ( + capsys.readouterr().out + == f"""{usage} + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + +circle-optional options (mutually exclusive): + --circle-optional.radius float + (default: None) + --circle-optional.diameter float + (default: None) + --circle-optional.perimeter float + (default: 24.0) + +circle-required options (mutually exclusive): + --circle-required.radius float + (default: 21) + --circle-required.diameter float + (default: 22) + --circle-required.perimeter float + (default: 23) +""" + ) + + +def test_cli_mutually_exclusive_group_exceptions(): + class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = 21 + diameter: Optional[float] = 22 + perimeter: Optional[float] = 23 + + class Settings(BaseSettings): + circle: Circle + + parser = CliDummyParser() + with pytest.raises( + SettingsError, + match='cannot connect CLI settings source root parser: group object is missing add_mutually_exclusive_group but is needed for connecting', + ): + CliSettingsSource( + Settings, + root_parser=parser, + parse_args_method=CliDummyParser.parse_args, + add_argument_method=CliDummyParser.add_argument, + add_argument_group_method=CliDummyParser.add_argument_group, + add_parser_method=CliDummySubParsers.add_parser, + add_subparsers_method=CliDummyParser.add_subparsers, + ) + + class SubModel(BaseModel): + pass + + class SettingsInvalidUnion(BaseSettings): + union: Union[Circle, SubModel] + + with pytest.raises(SettingsError, match='cannot use union with CliMutuallyExclusiveGroup'): + CliApp.run(SettingsInvalidUnion) + + class CircleInvalidSubModel(Circle): + square: Optional[SubModel] = None + + class SettingsInvalidOptSubModel(BaseModel): + circle: CircleInvalidSubModel = CircleInvalidSubModel() + + class SettingsInvalidReqSubModel(BaseModel): + circle: CircleInvalidSubModel + + for settings in [SettingsInvalidOptSubModel, SettingsInvalidReqSubModel]: + with pytest.raises(SettingsError, match='cannot have nested models in a CliMutuallyExclusiveGroup'): + CliApp.run(settings) + + class CircleRequiredField(Circle): + length: float + + class SettingsOptCircleReqField(BaseModel): + circle: CircleRequiredField = CircleRequiredField(length=2) + + assert CliApp.run(SettingsOptCircleReqField, cli_args=[]).model_dump() == { + 'circle': {'diameter': 22.0, 'length': 2.0, 'perimeter': 23.0, 'radius': 21.0} + } + + class SettingsInvalidReqCircleReqField(BaseModel): + circle: CircleRequiredField + + with pytest.raises(ValueError, match='mutually exclusive arguments must be optional'): + CliApp.run(SettingsInvalidReqCircleReqField) + + +def test_cli_invalid_abbrev(): + class MySettings(BaseSettings): + bacon: str = '' + badger: str = '' + + with pytest.raises( + SettingsError, + match='error parsing CLI: unrecognized arguments: --bac cli abbrev are invalid for internal parser', + ): + CliApp.run( + MySettings, cli_args=['--bac', 'cli abbrev are invalid for internal parser'], cli_exit_on_error=False + ) + + +def test_cli_submodels_strip_annotated(): + class PolyA(BaseModel): + a: int = 1 + type: Literal['a'] = 'a' + + class PolyB(BaseModel): + b: str = '2' + type: Literal['b'] = 'b' + + def _get_type(model: Union[BaseModel, Dict]) -> str: + if isinstance(model, dict): + return model.get('type', 'a') + return model.type # type: ignore + + Poly = Annotated[Union[Annotated[PolyA, Tag('a')], Annotated[PolyB, Tag('b')]], Discriminator(_get_type)] + + class WithUnion(BaseSettings): + poly: Poly + + assert CliApp.run(WithUnion, ['--poly.type=a']).model_dump() == {'poly': {'a': 1, 'type': 'a'}} + + +def test_cli_kebab_case(capsys, monkeypatch): + class DeepSubModel(BaseModel): + deep_pos_arg: CliPositionalArg[str] + deep_arg: str + + class SubModel(BaseModel): + sub_subcmd: CliSubCommand[DeepSubModel] + sub_arg: str + + class Root(BaseModel): + root_subcmd: CliSubCommand[SubModel] + root_arg: str + + assert CliApp.run( + Root, + cli_args=[ + '--root-arg=hi', + 'root-subcmd', + '--sub-arg=hello', + 'sub-subcmd', + 'hey', + '--deep-arg=bye', + ], + ).model_dump() == { + 'root_arg': 'hi', + 'root_subcmd': { + 'sub_arg': 'hello', + 'sub_subcmd': {'deep_pos_arg': 'hey', 'deep_arg': 'bye'}, + }, + } + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + with pytest.raises(SystemExit): + CliApp.run(Root) + assert ( + capsys.readouterr().out + == f"""usage: example.py [-h] --root-arg str {{root-subcmd}} ... + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + --root-arg str (required) + +subcommands: + {{root-subcmd}} + root-subcmd +""" + ) + + m.setattr(sys, 'argv', ['example.py', 'root-subcmd', '--help']) + with pytest.raises(SystemExit): + CliApp.run(Root) + assert ( + capsys.readouterr().out + == f"""usage: example.py root-subcmd [-h] --sub-arg str {{sub-subcmd}} ... + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + --sub-arg str (required) + +subcommands: + {{sub-subcmd}} + sub-subcmd +""" + ) + + m.setattr(sys, 'argv', ['example.py', 'root-subcmd', 'sub-subcmd', '--help']) + with pytest.raises(SystemExit): + CliApp.run(Root) + assert ( + capsys.readouterr().out + == f"""usage: example.py root-subcmd sub-subcmd [-h] --deep-arg str DEEP-POS-ARG + +positional arguments: + DEEP-POS-ARG + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + --deep-arg str (required) +""" + ) diff --git a/tests/test_source_json.py b/tests/test_source_json.py index e348a6b6..c6360f95 100644 --- a/tests/test_source_json.py +++ b/tests/test_source_json.py @@ -3,6 +3,7 @@ """ import json +from pathlib import Path from typing import Tuple, Type, Union from pydantic import BaseModel @@ -15,6 +16,11 @@ ) +def test_repr() -> None: + source = JsonConfigSettingsSource(BaseSettings(), Path('config.json')) + assert repr(source) == 'JsonConfigSettingsSource(json_file=config.json)' + + def test_json_file(tmp_path): p = tmp_path / '.env' p.write_text( diff --git a/tests/test_source_toml.py b/tests/test_source_toml.py index 29186018..a7230ce2 100644 --- a/tests/test_source_toml.py +++ b/tests/test_source_toml.py @@ -3,6 +3,7 @@ """ import sys +from pathlib import Path from typing import Tuple, Type import pytest @@ -21,6 +22,11 @@ tomli = None +def test_repr() -> None: + source = TomlConfigSettingsSource(BaseSettings(), Path('config.toml')) + assert repr(source) == 'TomlConfigSettingsSource(toml_file=config.toml)' + + @pytest.mark.skipif(sys.version_info <= (3, 11) and tomli is None, reason='tomli/tomllib is not installed') def test_toml_file(tmp_path): p = tmp_path / '.env' diff --git a/tests/test_source_yaml.py b/tests/test_source_yaml.py index fd25de67..1929b9dc 100644 --- a/tests/test_source_yaml.py +++ b/tests/test_source_yaml.py @@ -2,6 +2,7 @@ Test pydantic_settings.YamlConfigSettingsSource. """ +from pathlib import Path from typing import Tuple, Type, Union import pytest @@ -20,6 +21,11 @@ yaml = None +def test_repr() -> None: + source = YamlConfigSettingsSource(BaseSettings(), Path('config.yaml')) + assert repr(source) == 'YamlConfigSettingsSource(yaml_file=config.yaml)' + + @pytest.mark.skipif(yaml, reason='PyYAML is installed') def test_yaml_not_installed(tmp_path): p = tmp_path / '.env'