-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a pydantic-based configuration setup (#353)
Proposal for the next generation of configuration. Switch from dacite to pydantic for deserialization, and split the config into two parts: a dedicated connection to CDF config, and one for application logic.
- Loading branch information
Showing
6 changed files
with
306 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
111 changes: 111 additions & 0 deletions
111
cognite/extractorutils/unstable/configuration/loaders.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import json | ||
from enum import Enum | ||
from io import StringIO | ||
from pathlib import Path | ||
from typing import Dict, Optional, TextIO, Type, TypeVar, Union | ||
|
||
from pydantic import ValidationError | ||
|
||
from cognite.client import CogniteClient | ||
from cognite.extractorutils.configtools.loaders import _load_yaml_dict_raw | ||
from cognite.extractorutils.exceptions import InvalidConfigError | ||
from cognite.extractorutils.unstable.configuration.models import ConfigModel | ||
|
||
_T = TypeVar("_T", bound=ConfigModel) | ||
|
||
|
||
class ConfigFormat(Enum):
    """Serialization formats that the configuration loaders in this module can parse."""

    # Parsed with the standard library ``json`` module in ``load_io``
    JSON = "json"
    # Parsed with ``_load_yaml_dict_raw`` in ``load_io``
    YAML = "yaml"
|
||
|
||
def load_file(path: Path, schema: Type[_T]) -> _T:
    """
    Load a configuration file from disk and validate it against ``schema``.

    The format is chosen from the file extension: ``.yaml`` / ``.yml`` are read as YAML and ``.json`` as JSON. The
    extension is compared case-insensitively, so e.g. ``config.YAML`` is accepted.

    Args:
        path: Location of the configuration file.
        schema: Config model class to validate the file content against.

    Returns:
        An instance of ``schema`` populated from the file.

    Raises:
        InvalidConfigError: If the extension is not recognised, or if the content fails validation.
    """
    suffix = path.suffix.lower()
    if suffix in (".yaml", ".yml"):
        file_format = ConfigFormat.YAML
    elif suffix == ".json":
        file_format = ConfigFormat.JSON
    else:
        raise InvalidConfigError(f"Unknown file type {path.suffix}")

    # Read as UTF-8 explicitly so the result does not depend on the platform's default locale encoding
    with open(path, "r", encoding="utf-8") as stream:
        return load_io(stream, file_format, schema)
|
||
|
||
def load_from_cdf(
    cognite_client: CogniteClient, external_id: str, schema: Type[_T], revision: Optional[int] = None
) -> _T:
    """
    Fetch a configuration from CDF and validate it against ``schema``.

    Sends a GET request to the project's ``odin/config`` endpoint (with the ``cdf-version: alpha`` header) and parses
    the ``config`` field of the JSON response as YAML.

    Args:
        cognite_client: Client used to perform the request; its configured project is used in the URL.
        external_id: Sent as the ``externalId`` query parameter.
        schema: Config model class to validate the fetched content against.
        revision: Sent as the ``revision`` query parameter when given. ``None`` omits the parameter.

    Returns:
        An instance of ``schema`` populated from the fetched configuration.

    Raises:
        InvalidConfigError: If the fetched content fails validation.
    """
    params: Dict[str, Union[str, int]] = {"externalId": external_id}
    # Compare against None rather than truthiness so that an explicitly requested revision 0 is not silently dropped
    if revision is not None:
        params["revision"] = revision

    response = cognite_client.get(
        f"/api/v1/projects/{cognite_client.config.project}/odin/config",
        params=params,
        headers={"cdf-version": "alpha"},
    )
    response.raise_for_status()
    data = response.json()

    return load_io(StringIO(data["config"]), ConfigFormat.YAML, schema)
|
||
|
||
def load_io(stream: TextIO, format: ConfigFormat, schema: Type[_T]) -> _T:
    """
    Parse a configuration from a text stream and validate it against ``schema``.

    Args:
        stream: Text stream containing the serialized configuration.
        format: Serialization format of the stream content.
        schema: Config model class to validate the parsed content against.

    Returns:
        An instance of ``schema`` populated from the stream.

    Raises:
        InvalidConfigError: If ``format`` is not a supported format, or if the content fails validation.
    """
    if format == ConfigFormat.JSON:
        data = json.load(stream)
    elif format == ConfigFormat.YAML:
        data = _load_yaml_dict_raw(stream)
    else:
        # Without this branch an unexpected value would surface as an UnboundLocalError on ``data`` below
        raise InvalidConfigError(f"Unknown config format {format}")

    # Drop the key vault sections before validation so they are not seen by the schema.
    # NOTE(review): presumably these sections are consumed by the YAML loader itself - confirm.
    # Only a dict can carry these keys; any other top-level value is passed on unchanged and rejected by validation.
    if isinstance(data, dict):
        for key in ("azure-keyvault", "key-vault"):
            data.pop(key, None)

    return load_dict(data, schema)
|
||
|
||
def _make_loc_str(loc: tuple) -> str: | ||
# Remove the body parameter if it is present | ||
if loc[0] == "body": | ||
loc = loc[1:] | ||
|
||
# Create a string from the loc parameter | ||
loc_str = "" | ||
needs_sep = False | ||
for lo in loc: | ||
if not needs_sep: | ||
loc_str = f"{loc_str}{lo}" | ||
needs_sep = True | ||
else: | ||
if isinstance(lo, int): | ||
loc_str = f"{loc_str}[{lo}]" | ||
else: | ||
loc_str = f"{loc_str}.{lo}" | ||
|
||
return loc_str | ||
|
||
|
||
def load_dict(data: dict, schema: Type[_T]) -> _T:
    """
    Validate an already parsed configuration against ``schema``.

    Args:
        data: Parsed configuration content.
        schema: Config model class to validate ``data`` against.

    Returns:
        An instance of ``schema`` populated from ``data``.

    Raises:
        InvalidConfigError: If validation fails. The message is the comma-separated list of per-field problems, and
            the same list is passed on as ``details``.
    """
    try:
        return schema.model_validate(data)

    except ValidationError as validation_error:
        problems = []

        for entry in validation_error.errors():
            location = entry.get("loc")
            if location is None:
                continue

            where = _make_loc_str(location)

            # Errors raised from custom validators carry the original exception in the error context
            cause = entry["ctx"]["error"] if "ctx" in entry and "error" in entry["ctx"] else None

            if isinstance(cause, (ValueError, AssertionError)):
                problems.append(f"{where}: {str(cause)}")
            elif entry.get("type") == "json_invalid":
                problems.append(f"{entry.get('msg')}: {where}")
            else:
                problems.append(f"{where}: {entry.get('msg')}")

        raise InvalidConfigError(", ".join(problems), details=problems) from validation_error
159 changes: 159 additions & 0 deletions
159
cognite/extractorutils/unstable/configuration/models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import re | ||
from datetime import timedelta | ||
from enum import Enum | ||
from pathlib import Path | ||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union | ||
|
||
from humps import kebabize | ||
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler | ||
from pydantic_core import CoreSchema, core_schema | ||
|
||
from cognite.extractorutils.exceptions import InvalidConfigError | ||
|
||
|
||
class ConfigModel(BaseModel):
    """
    Base class for the configuration models in this module.

    Field aliases are generated with ``kebabize`` (so ``base_url`` is read from ``base-url``), fields can also be
    populated by their Python name, and keys that do not correspond to a field are rejected.
    """

    model_config = ConfigDict(
        alias_generator=kebabize,
        populate_by_name=True,
        extra="forbid",
        # arbitrary_types_allowed=True,
    )
|
||
|
||
class _ClientCredentialsConfig(ConfigModel):
    """Authentication settings selected when ``type`` is ``client-credentials``."""

    # Discriminator value identifying this variant of AuthenticationConfig
    type: Literal["client-credentials"]
    client_id: str
    client_secret: str
    token_url: str
    scopes: List[str]
    # Optional parameters, left as None when not configured
    resource: Optional[str] = None
    audience: Optional[str] = None
|
||
|
||
class _ClientCertificateConfig(ConfigModel):
    """Authentication settings selected when ``type`` is ``client-certificate``."""

    # Discriminator value identifying this variant of AuthenticationConfig
    type: Literal["client-certificate"]
    client_id: str
    # File system path to the certificate
    certificate_path: Path
    scopes: List[str]
|
||
|
||
# Tagged union of the supported authentication variants; the ``type`` field decides which model is validated against
AuthenticationConfig = Annotated[Union[_ClientCredentialsConfig, _ClientCertificateConfig], Field(discriminator="type")]
|
||
|
||
class TimeIntervalConfig:
    """
    Configuration parameter for setting a time interval

    An interval is given either as a plain number, interpreted as seconds, or as a whole number followed by one of
    the units ``s``, ``m``, ``h`` or ``d`` (optionally separated by spaces or tabs), e.g. ``30``, ``45s``, ``5m``,
    ``2 h`` or ``1d``.

    Two intervals compare (and hash) equal when they span the same number of seconds, regardless of how they were
    written.
    """

    def __init__(self, expression: str) -> None:
        """
        Args:
            expression: The interval expression to parse.

        Raises:
            InvalidConfigError: If the expression is not a valid interval.
        """
        self._interval, self._expression = TimeIntervalConfig._parse_expression(expression)

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
        # Validate the raw value as a string or integer first, then construct the interval from the validated value
        return core_schema.no_info_after_validator_function(cls, handler(Union[str, int]))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TimeIntervalConfig):
            return NotImplemented
        return self._interval == other._interval

    def __hash__(self) -> int:
        # Hash on the same value that __eq__ compares on
        return hash(self._interval)

    @classmethod
    def _parse_expression(cls, expression: str) -> tuple[int, str]:
        """
        Parse an interval expression.

        Args:
            expression: The interval expression to parse.

        Returns:
            A tuple of the interval length in seconds and the expression it was parsed from. For a plain number the
            returned expression has an ``s`` appended.

        Raises:
            InvalidConfigError: If the expression is neither a plain number nor ``<number><unit>``.
        """
        # First, try to parse pure number and assume seconds (for backwards compatibility)
        try:
            return int(expression), f"{expression}s"
        except ValueError:
            pass

        # Surrounding whitespace is ignored, as it already is for plain numbers. The whole remaining expression must
        # match, so that trailing text is rejected instead of being misread (e.g. "5ms" parsed as five minutes).
        stripped = expression.strip()
        match = re.fullmatch(r"(\d+)[ \t]*(s|m|h|d)", stripped)
        if not match:
            raise InvalidConfigError("Invalid interval pattern")

        number, unit = match.groups()
        numeric_unit = {"s": 1, "m": 60, "h": 60 * 60, "d": 60 * 60 * 24}[unit]

        return int(number) * numeric_unit, stripped

    @property
    def seconds(self) -> int:
        """The interval length in seconds."""
        return self._interval

    @property
    def minutes(self) -> float:
        """The interval length in minutes."""
        return self._interval / 60

    @property
    def hours(self) -> float:
        """The interval length in hours."""
        return self._interval / (60 * 60)

    @property
    def days(self) -> float:
        """The interval length in days."""
        return self._interval / (60 * 60 * 24)

    @property
    def timedelta(self) -> timedelta:
        """The interval as a ``datetime.timedelta``."""
        days = self._interval // (60 * 60 * 24)
        seconds = self._interval % (60 * 60 * 24)
        return timedelta(days=days, seconds=seconds)

    def __int__(self) -> int:
        return int(self._interval)

    def __float__(self) -> float:
        return float(self._interval)

    def __str__(self) -> str:
        return self._expression

    def __repr__(self) -> str:
        return self._expression
|
||
|
||
class _ConnectionParameters(ConfigModel):
    """Tunable connection parameters. Every field has a default, so the whole section can be omitted."""

    gzip_compression: bool = False
    # NOTE(review): presumably the HTTP status codes that trigger a retry - confirm against where this is consumed
    status_forcelist: List[int] = Field(default_factory=lambda: [429, 502, 503, 504])
    max_retries: int = 10
    max_retries_connect: int = 3
    # Accepts the interval syntax of TimeIntervalConfig, defaults to 30 seconds
    max_retry_backoff: TimeIntervalConfig = Field(default_factory=lambda: TimeIntervalConfig("30s"))
    max_connection_pool_size: int = 50
    ssl_verify: bool = True
    # A fresh empty dict is created per instance
    proxies: Dict[str, str] = Field(default_factory=dict)
|
||
|
||
class ConnectionConfig(ConfigModel):
    """
    Configuration of the connection to CDF.

    ``project``, ``base_url``, ``extraction_pipeline`` and ``authentication`` are required; ``connection`` falls back
    to the defaults of ``_ConnectionParameters`` when omitted.
    """

    project: str
    base_url: str

    extraction_pipeline: str

    # One of the authentication variants, selected by its ``type`` field
    authentication: AuthenticationConfig

    connection: _ConnectionParameters = Field(default_factory=_ConnectionParameters)
|
||
|
||
class LogLevel(Enum):
    """Log levels that can be selected for a log handler. Each value equals the member's own name."""

    CRITICAL = "CRITICAL"
    ERROR = "ERROR"
    WARNING = "WARNING"
    INFO = "INFO"
    DEBUG = "DEBUG"
|
||
|
||
class LogFileHandlerConfig(ConfigModel):
    """Configuration of a log handler that writes to a file."""

    # File system path of the log file
    path: Path
    level: LogLevel
    # NOTE(review): unit is not visible here, presumably the number of days to keep log files - confirm
    retention: int = 7
|
||
|
||
class LogConsoleHandlerConfig(ConfigModel):
    """Configuration of a log handler that writes to the console."""

    level: LogLevel
|
||
|
||
# Either kind of log handler. The union has no discriminator field; the two models differ in that only the file
# handler has ``path`` / ``retention``, and both models forbid unknown keys.
LogHandlerConfig = Union[LogFileHandlerConfig, LogConsoleHandlerConfig]
|
||
|
||
class ExtractorConfig(ConfigModel):
    """Base configuration for extractor application logic."""

    # Defaults to a single console handler at INFO level when no handlers are configured
    log_handlers: List[LogHandlerConfig] = Field(default_factory=lambda: [LogConsoleHandlerConfig(level=LogLevel.INFO)])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from io import StringIO | ||
|
||
from cognite.extractorutils.unstable.configuration.loaders import ConfigFormat, load_io | ||
from cognite.extractorutils.unstable.configuration.models import ConnectionConfig | ||
|
||
# YAML connection configuration containing only the required fields of ConnectionConfig, using client credentials.
# NOTE(review): the keys below ``authentication:`` (and the ``scopes`` list item) are shown without indentation here;
# the nesting appears to have been stripped when this text was extracted - confirm against the repository.
CONFIG_EXAMPLE_ONLY_REQUIRED = """
project: test-project
base-url: https://baseurl.com
extraction-pipeline: test-pipeline
authentication:
type: client-credentials
client-id: testid
client-secret: very_secret123
token-url: https://get-a-token.com/token
scopes:
- scopea
"""
|
||
|
||
def test_load_from_io() -> None:
    """Load the minimal example configuration from a text stream and check the parsed values."""
    loaded = load_io(StringIO(CONFIG_EXAMPLE_ONLY_REQUIRED), ConfigFormat.YAML, ConnectionConfig)
    auth = loaded.authentication

    # Top-level fields, including one read through its kebab-case alias
    assert loaded.project == "test-project"
    assert loaded.base_url == "https://baseurl.com"

    # The discriminated union resolved to the client credentials variant
    assert auth.type == "client-credentials"
    assert auth.client_secret == "very_secret123"