Skip to content

Commit

Permalink
Add a pydantic-based configuration setup (#353)
Browse files Browse the repository at this point in the history
Proposal for the next generation of configuration. Switch from dacite
to pydantic for deserialization, and split the config into two parts: a
dedicated connection to CDF config, and one for application logic.
  • Loading branch information
mathialo authored Sep 9, 2024
1 parent f33dba9 commit 39053c6
Show file tree
Hide file tree
Showing 6 changed files with 306 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cognite/extractorutils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.


from typing import List, Optional


class InvalidConfigError(Exception):
"""
Exception thrown from ``load_yaml`` and ``load_yaml_dict`` if config file is invalid. This can be due to
Expand All @@ -22,9 +25,10 @@ class InvalidConfigError(Exception):
* Unknown fields
"""

def __init__(self, message: str):
def __init__(self, message: str, details: Optional[List[str]] = None):
super(InvalidConfigError, self).__init__()
self.message = message
self.details = details

def __str__(self) -> str:
return f"Invalid config: {self.message}"
Expand Down
Empty file.
111 changes: 111 additions & 0 deletions cognite/extractorutils/unstable/configuration/loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import json
from enum import Enum
from io import StringIO
from pathlib import Path
from typing import Dict, Optional, TextIO, Type, TypeVar, Union

from pydantic import ValidationError

from cognite.client import CogniteClient
from cognite.extractorutils.configtools.loaders import _load_yaml_dict_raw
from cognite.extractorutils.exceptions import InvalidConfigError
from cognite.extractorutils.unstable.configuration.models import ConfigModel

_T = TypeVar("_T", bound=ConfigModel)


class ConfigFormat(Enum):
    """Config file formats understood by ``load_file`` and ``load_io``."""

    JSON = "json"
    YAML = "yaml"


def load_file(path: Path, schema: Type[_T]) -> _T:
    """
    Load and validate a config file, inferring the format from the file extension.

    Args:
        path: Path to a ``.yaml``, ``.yml`` or ``.json`` config file.
        schema: Pydantic model class to validate the config against.

    Returns:
        A validated instance of ``schema``.

    Raises:
        InvalidConfigError: If the file extension is not recognized, or if the
            content is not valid for ``schema``.
    """
    # Named file_format to avoid shadowing the ``format`` builtin
    if path.suffix in [".yaml", ".yml"]:
        file_format = ConfigFormat.YAML
    elif path.suffix == ".json":
        file_format = ConfigFormat.JSON
    else:
        raise InvalidConfigError(f"Unknown file type {path.suffix}")

    with open(path, "r") as stream:
        return load_io(stream, file_format, schema)


def load_from_cdf(
    cognite_client: CogniteClient, external_id: str, schema: Type[_T], revision: Optional[int] = None
) -> _T:
    """
    Fetch a config from CDF and validate it against ``schema``.

    Args:
        cognite_client: Authenticated client used to reach the config endpoint.
        external_id: External ID of the config to fetch.
        schema: Pydantic model class to validate the config against.
        revision: Specific revision to fetch. If omitted, the endpoint's
            default (presumably the latest revision — verify against the API) is used.

    Returns:
        A validated instance of ``schema``.

    Raises:
        InvalidConfigError: If the fetched config is not valid for ``schema``.
    """
    params: Dict[str, Union[str, int]] = {"externalId": external_id}
    # Explicit None check so that revision 0 is not silently dropped by truthiness
    if revision is not None:
        params["revision"] = revision
    response = cognite_client.get(
        f"/api/v1/projects/{cognite_client.config.project}/odin/config",
        params=params,
        headers={"cdf-version": "alpha"},  # endpoint requires the alpha API version header
    )
    response.raise_for_status()
    data = response.json()
    # The response wraps the raw YAML document in a "config" field
    return load_io(StringIO(data["config"]), ConfigFormat.YAML, schema)


def load_io(stream: TextIO, format: ConfigFormat, schema: Type[_T]) -> _T:
    """
    Load and validate a config from an open text stream.

    Args:
        stream: Text stream to read the config document from.
        format: Format of the stream content (JSON or YAML).
        schema: Pydantic model class to validate the config against.

    Returns:
        A validated instance of ``schema``.

    Raises:
        InvalidConfigError: If the content is not valid for ``schema``, or the
            format is not a known ``ConfigFormat`` member.
    """
    if format == ConfigFormat.JSON:
        data = json.load(stream)

    elif format == ConfigFormat.YAML:
        data = _load_yaml_dict_raw(stream)

        # Key vault sections are consumed by the YAML pre-processing and are not
        # part of the config schema, so drop them before validation
        if "azure-keyvault" in data:
            data.pop("azure-keyvault")
        if "key-vault" in data:
            data.pop("key-vault")

    else:
        # Guard against future ConfigFormat members: without this, ``data`` would
        # be unbound and the call below would raise a confusing NameError
        raise InvalidConfigError(f"Unknown config format: {format}")

    return load_dict(data, schema)


def _make_loc_str(loc: tuple) -> str:
# Remove the body parameter if it is present
if loc[0] == "body":
loc = loc[1:]

# Create a string from the loc parameter
loc_str = ""
needs_sep = False
for lo in loc:
if not needs_sep:
loc_str = f"{loc_str}{lo}"
needs_sep = True
else:
if isinstance(lo, int):
loc_str = f"{loc_str}[{lo}]"
else:
loc_str = f"{loc_str}.{lo}"

return loc_str


def load_dict(data: dict, schema: Type[_T]) -> _T:
    """
    Validate a raw config dict against a pydantic model.

    Translates pydantic's ``ValidationError`` into an ``InvalidConfigError``
    with one human-readable message per validation issue.

    Args:
        data: Raw (already parsed) config data.
        schema: Pydantic model class to validate against.

    Returns:
        A validated instance of ``schema``.

    Raises:
        InvalidConfigError: If validation fails.
    """
    try:
        return schema.model_validate(data)

    except ValidationError as validation_error:
        messages = []
        for error in validation_error.errors():
            location = error.get("loc")
            if location is None:
                continue

            location_str = _make_loc_str(location)

            # Errors raised from custom validators carry the original exception
            # in ctx; prefer its message for ValueError/AssertionError
            wrapped = error["ctx"].get("error") if "ctx" in error else None
            if isinstance(wrapped, (ValueError, AssertionError)):
                messages.append(f"{location_str}: {wrapped}")
                continue

            message = error.get("msg")
            if error.get("type") == "json_invalid":
                messages.append(f"{message}: {location_str}")
            else:
                messages.append(f"{location_str}: {message}")

        raise InvalidConfigError(", ".join(messages), details=messages) from validation_error
159 changes: 159 additions & 0 deletions cognite/extractorutils/unstable/configuration/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import re
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from humps import kebabize
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from cognite.extractorutils.exceptions import InvalidConfigError


class ConfigModel(BaseModel):
    """
    Base class for all configuration models.

    Configures pydantic to accept kebab-case keys (as written in config files)
    in addition to the snake_case attribute names, and to reject unknown fields.
    """

    model_config = ConfigDict(
        alias_generator=kebabize,
        populate_by_name=True,
        extra="forbid",
    )


class _ClientCredentialsConfig(ConfigModel):
    """OAuth2 client-credentials authentication configuration."""

    # Discriminator value used by AuthenticationConfig to select this model
    type: Literal["client-credentials"]
    client_id: str
    client_secret: str
    token_url: str
    scopes: List[str]
    resource: Optional[str] = None
    audience: Optional[str] = None


class _ClientCertificateConfig(ConfigModel):
    """Client-certificate authentication configuration."""

    # Discriminator value used by AuthenticationConfig to select this model
    type: Literal["client-certificate"]
    client_id: str
    certificate_path: Path
    scopes: List[str]


# Discriminated union: the "type" field selects which credential model is validated
AuthenticationConfig = Annotated[Union[_ClientCredentialsConfig, _ClientCertificateConfig], Field(discriminator="type")]


class TimeIntervalConfig:
    """
    Configuration parameter for setting a time interval.

    Accepts either a plain number (interpreted as seconds, for backwards
    compatibility) or an expression such as ``"30s"``, ``"5m"``, ``"2h"`` or
    ``"1d"``.
    """

    def __init__(self, expression: str) -> None:
        self._interval, self._expression = TimeIntervalConfig._parse_expression(expression)

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
        # Let pydantic validate str/int input, then coerce it via the constructor
        return core_schema.no_info_after_validator_function(cls, handler(Union[str, int]))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TimeIntervalConfig):
            return NotImplemented
        # Intervals compare by duration, not by spelling ("60s" == "1m")
        return self._interval == other._interval

    def __hash__(self) -> int:
        return hash(self._interval)

    @classmethod
    def _parse_expression(cls, expression: str) -> tuple[int, str]:
        """Parse an interval expression into ``(seconds, expression)``.

        Raises:
            InvalidConfigError: If the expression matches no known pattern.
        """
        # First, try to parse pure number and assume seconds (for backwards compatibility)
        try:
            return int(expression), f"{expression}s"
        except ValueError:
            pass

        match = re.match(r"(\d+)[ \t]*(s|m|h|d)", expression)
        if not match:
            # Include the offending expression so the user can locate it in the config
            raise InvalidConfigError(f"Invalid interval pattern: {expression}")

        number, unit = match.groups()
        numeric_unit = {"s": 1, "m": 60, "h": 60 * 60, "d": 60 * 60 * 24}[unit]

        return int(number) * numeric_unit, expression

    @property
    def seconds(self) -> int:
        """The interval as a whole number of seconds."""
        return self._interval

    @property
    def minutes(self) -> float:
        """The interval in minutes."""
        return self._interval / 60

    @property
    def hours(self) -> float:
        """The interval in hours."""
        return self._interval / (60 * 60)

    @property
    def days(self) -> float:
        """The interval in days."""
        return self._interval / (60 * 60 * 24)

    @property
    def timedelta(self) -> timedelta:
        """The interval as a ``datetime.timedelta``."""
        days = self._interval // (60 * 60 * 24)
        seconds = self._interval % (60 * 60 * 24)
        return timedelta(days=days, seconds=seconds)

    def __int__(self) -> int:
        return int(self._interval)

    def __float__(self) -> float:
        return float(self._interval)

    def __str__(self) -> str:
        return self._expression

    def __repr__(self) -> str:
        return self._expression


class _ConnectionParameters(ConfigModel):
    """Low-level HTTP/connection tuning parameters for the CDF connection."""

    gzip_compression: bool = False
    # HTTP status codes to retry on — presumably urllib3 "status_forcelist"
    # semantics; verify against how the client is constructed
    status_forcelist: List[int] = Field(default_factory=lambda: [429, 502, 503, 504])
    max_retries: int = 10
    max_retries_connect: int = 3
    max_retry_backoff: TimeIntervalConfig = Field(default_factory=lambda: TimeIntervalConfig("30s"))
    max_connection_pool_size: int = 50
    ssl_verify: bool = True
    proxies: Dict[str, str] = Field(default_factory=dict)


class ConnectionConfig(ConfigModel):
    """Configuration of the connection from the extractor to CDF."""

    project: str
    base_url: str

    # NOTE(review): presumably the external ID of the extraction pipeline this
    # extractor is associated with — confirm against callers
    extraction_pipeline: str

    authentication: AuthenticationConfig

    # Optional low-level tuning; defaults are applied when the section is absent
    connection: _ConnectionParameters = Field(default_factory=_ConnectionParameters)


class LogLevel(Enum):
    """Log severity levels, mirroring the standard library ``logging`` level names."""

    CRITICAL = "CRITICAL"
    ERROR = "ERROR"
    WARNING = "WARNING"
    INFO = "INFO"
    DEBUG = "DEBUG"


class LogFileHandlerConfig(ConfigModel):
    """Configuration of a file-based log handler."""

    path: Path
    level: LogLevel
    # NOTE(review): retention unit is not evident here — presumably days of
    # rotated logs to keep; confirm where the handler is set up
    retention: int = 7


class LogConsoleHandlerConfig(ConfigModel):
    """Configuration of a console (stdout/stderr) log handler."""

    level: LogLevel


# Union of all supported log handler configurations
LogHandlerConfig = Union[LogFileHandlerConfig, LogConsoleHandlerConfig]


class ExtractorConfig(ConfigModel):
    """Base class for application-logic (extractor-specific) configuration."""

    # Defaults to a single console handler at INFO level
    log_handlers: List[LogHandlerConfig] = Field(default_factory=lambda: [LogConsoleHandlerConfig(level=LogLevel.INFO)])
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ azure-identity = "^1.14.0"
azure-keyvault-secrets = "^4.7.0"
orjson = "^3.10.3"
httpx = "^0.27.0"
pydantic = "^2.8.2"
pyhumps = "^3.8.0"

[tool.poetry.extras]
experimental = ["cognite-sdk-experimental"]
Expand Down
29 changes: 29 additions & 0 deletions tests/test_unstable/test_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from io import StringIO

from cognite.extractorutils.unstable.configuration.loaders import ConfigFormat, load_io
from cognite.extractorutils.unstable.configuration.models import ConnectionConfig

CONFIG_EXAMPLE_ONLY_REQUIRED = """
project: test-project
base-url: https://baseurl.com
extraction-pipeline: test-pipeline
authentication:
type: client-credentials
client-id: testid
client-secret: very_secret123
token-url: https://get-a-token.com/token
scopes:
- scopea
"""


def test_load_from_io() -> None:
    """Load a minimal kebab-case YAML config and verify every required field round-trips."""
    stream = StringIO(CONFIG_EXAMPLE_ONLY_REQUIRED)
    config = load_io(stream, ConfigFormat.YAML, ConnectionConfig)

    assert config.project == "test-project"
    assert config.base_url == "https://baseurl.com"
    # Previously untested fields: pipeline ID and the rest of the credentials
    assert config.extraction_pipeline == "test-pipeline"
    assert config.authentication.type == "client-credentials"
    assert config.authentication.client_id == "testid"
    assert config.authentication.client_secret == "very_secret123"
    assert config.authentication.token_url == "https://get-a-token.com/token"
    assert config.authentication.scopes == ["scopea"]

0 comments on commit 39053c6

Please sign in to comment.