diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 043e7f0..3a75f7a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -67,3 +67,7 @@ jobs:
       - name: Run DuckDB lookup exception handling tests
         run: |
           .venv/bin/pytest tests/test_lookup.py -v
+
+      - name: Run DLT Pydantic exception handling (override) test
+        run: |
+          .venv/bin/pytest tests/test_dlt_pydantic_override.py -v
diff --git a/src/openhound/core/app.py b/src/openhound/core/app.py
index 8b66861..971aecf 100644
--- a/src/openhound/core/app.py
+++ b/src/openhound/core/app.py
@@ -6,6 +6,7 @@
 import dlt
 import duckdb
 import typer
+from dlt.common.libs import pydantic as dlt_pydantic
 from dlt.common.pipeline import LoadInfo
 from dlt.extract.resource import DltResource
 from dlt.extract.source import DltSource
@@ -13,6 +14,7 @@
 from openhound.cli.collect import collect
 from openhound.cli.convert import convert
 from openhound.cli.preproc import preprocess
+from openhound.core import validate
 from openhound.core.asset import BaseAsset, EdgeDef, NodeDef
 from openhound.core.collect import CollectContext, Collector
 from openhound.core.convert import ConvertContext, Converter, Method
@@ -46,12 +48,14 @@
 class Contract(str, Enum):
     evolve = "evolve"
     freeze = "freeze"
-    discard_value = "discard_value"
     discard_row = "discard_row"
 
 
 class OpenHound:
     def __init__(self, name: str, source_kind: str, help: str = "OpenGraph collector"):
+        dlt_pydantic.create_list_model = validate.create_list_model
+        dlt_pydantic._classify_validation_errors = validate._classify_validation_errors
+
         self.name = name
         self.source_kind = source_kind
         self.help = help
@@ -67,9 +71,6 @@ def __init__(self, name: str, source_kind: str, help: str = "OpenGraph collector
         self.dlt_source: DltSource | None = None
         self.dlt_resources: list[DltResource] = []
         self.dlt_transformers: list[DltResource] = []
-        self.table_contract: Contract = Contract.evolve
-        self.data_type_contract: Contract = Contract.freeze
-        self.columns_contract: Contract = Contract.evolve
 
         # Store the graph definitions for this source
         self.assets: list[BaseAsset] = []
@@ -96,30 +97,38 @@ def wrapper(
                 progress: Progress = typer.Option(
                     Progress.tqdm, help="Select progress tracker option"
                 ),
-                tables: Contract = typer.Option(
-                    Contract.evolve,
-                    help="Contract applied when data contains newly seen resources/tables previously not collected",
-                ),
-                columns: Contract = typer.Option(
-                    Contract.evolve,
-                    help="Contract applied when data contains values/keys not found in the Pydantic model",
-                ),
-                data_type: Contract = typer.Option(
-                    Contract.freeze,
-                    help="Contract applied when fields do not match the data types defined in the Pydantic model",
-                ),
+                tables_contract: Annotated[
+                    Contract,
+                    typer.Option(
+                        help="DLT contract applied when data contains newly seen resources/tables previously not collected",
+                    ),
+                ] = Contract.evolve,
+                columns_contract: Annotated[
+                    Contract,
+                    typer.Option(
+                        help="DLT contract applied when data contains values/keys not found in the Pydantic model",
+                    ),
+                ] = Contract.evolve,
+                data_type_contract: Annotated[
+                    Contract,
+                    typer.Option(
+                        help="DLT contract applied when fields do not match the data types defined in the Pydantic model",
+                    ),
+                ] = Contract.discard_row,
             ) -> LoadInfo | None:
+                schema_contract = {
+                    "tables": tables_contract,
+                    "columns": columns_contract,
+                    "data_type": data_type_contract,
+                }
                 collector = Collector(
                     name=self.name,
                     output_path=output_path,
                     resources=resources,
                     progress=progress,
+                    schema_contract=schema_contract,
                 )
 
-                # TODO: Implement data/table/column contracts
-                # self.data_type_contract = data_type
-                # self.columns_contract = columns
-                # self.table_contract = tables
                 ctx = CollectContext(pipeline=collector)
                 source_method: DltSource = func(ctx)
                 if source_method:
diff --git a/src/openhound/core/collect.py b/src/openhound/core/collect.py
index 7d6664d..8cdc376 100644
--- a/src/openhound/core/collect.py
+++ b/src/openhound/core/collect.py
@@ -14,6 +14,12 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_SCHEMA_CONTRACT = {
+    "tables": "evolve",
+    "columns": "evolve",
+    "data_type": "discard_row",
+}
+
 
 class Collector(BasePipeline):
     def __init__(
@@ -22,11 +28,16 @@ def __init__(
         output_path: Path,
         resources: list[str] | None = None,
         progress: Progress = Progress.tqdm,
+        schema_contract: dict | None = None,
     ):
         self.name = name
         self.output_path = output_path
         self.resources = resources if resources else []
         self.progress = progress
+        # Copy so instances never alias the module-level default contract
+        self.schema_contract = dict(
+            DEFAULT_SCHEMA_CONTRACT if schema_contract is None else schema_contract
+        )
 
     @property
     def pipeline(self) -> Pipeline:
@@ -56,7 +67,10 @@ def run(self, source_object: DltSource, **kwargs) -> LoadInfo:
         logger_override.set_handler(self.name)
 
         return self._run(
-            all_resources, write_disposition="replace", loader_file_format="jsonl"
+            all_resources,
+            write_disposition="replace",
+            loader_file_format="jsonl",
+            schema_contract=self.schema_contract,
         )
 
 
diff --git a/src/openhound/core/validate.py b/src/openhound/core/validate.py
new file mode 100644
index 0000000..b7eec48
--- /dev/null
+++ b/src/openhound/core/validate.py
@@ -0,0 +1,142 @@
+# This is a patched version of DLT's model validator which adds logging when resources fail
+# pydantic validation and the schema contract is set to discard_row
+# https://github.com/dlt-hub/dlt/blob/devel/dlt/common/libs/pydantic.py
+
+import logging
+from typing import Any, Type
+
+from dlt.common.schema import DataValidationError
+from dlt.common.schema.typing import TSchemaEvolutionMode
+from dlt.common.typing import TDataItem
+from pydantic import BaseModel, ValidationError, create_model
+from pydantic.functional_validators import WrapValidator
+from typing_extensions import Annotated
+
+logger = logging.getLogger("dlt")
+
+
+def create_list_model(
+    model: type[BaseModel],
+    column_mode: TSchemaEvolutionMode = "freeze",
+    data_mode: TSchemaEvolutionMode = "freeze",
+) -> type[BaseModel]:
+    """Creates a model from `model` for validating a list of items in batch.
+
+    When either mode is `discard_row`, items that fail validation are logged and
+    replaced by `None` instead of failing the whole batch; otherwise a strict
+    list model is returned.
+    """
+    if column_mode == "discard_row" or data_mode == "discard_row":
+
+        def _lenient_item_validator(value: Any, handler: Any) -> BaseModel | None:
+            """Validates one item, returning `None` (after logging) when it is discarded."""
+            try:
+                return handler(value)
+            except ValidationError as val_err:
+                # An item that is not a mapping at all is never discarded
+                for err in val_err.errors():
+                    if err["type"] == "model_type":
+                        raise
+
+                logger.warning(
+                    "DLT discarded row during listed Pydantic validation",
+                    extra={
+                        "resource": model.__name__,
+                        "pydantic_errors": val_err.errors(
+                            include_input=False, include_context=False
+                        ),
+                    },
+                )
+                return None
+            except Exception:
+                # Non-pydantic failures also drop the row; log them so the
+                # discard is never silent
+                logger.warning(
+                    "DLT discarded row during listed Pydantic validation",
+                    extra={"resource": model.__name__},
+                    exc_info=True,
+                )
+                return None
+
+        item_type = Annotated[model | None, WrapValidator(_lenient_item_validator)]  # type: ignore[valid-type]
+        return create_model(
+            "LenientList" + model.__name__,
+            items=(list[item_type], ...),  # type: ignore[valid-type]
+        )
+
+    return create_model(
+        "List" + model.__name__,
+        items=(list[model], ...),  # type: ignore[valid-type]
+    )
+
+
+def _classify_validation_errors(
+    table_name: str,
+    model: Type[BaseModel],
+    item: TDataItem,
+    exc: ValidationError,
+    column_mode: TSchemaEvolutionMode,
+    data_mode: TSchemaEvolutionMode,
+) -> None:
+    """Classifies validation errors and raises DataValidationError for freeze mode.
+
+    For discard_row mode, returns without raising so the caller can discard the item.
+    For model_type errors (item is not a mapping), always re-raises.
+    """
+    for err in exc.errors():
+        if err["type"] == "model_type":
+            raise exc
+        if err["type"] == "extra_forbidden":
+            if column_mode == "freeze":
+                raise DataValidationError(
+                    None,
+                    table_name,
+                    str(err["loc"]),
+                    "columns",
+                    "freeze",
+                    model,
+                    {"columns": "freeze"},
+                    item,
+                    err["msg"],
+                ) from exc
+            elif column_mode == "discard_row":
+                logger.warning(
+                    "DLT discarded row during Pydantic validation",
+                    extra={
+                        "resource": table_name,
+                        "pydantic_errors": exc.errors(
+                            include_input=False, include_context=False
+                        ),
+                    },
+                )
+                return
+            raise NotImplementedError(
+                f"`{column_mode=:}` not implemented for Pydantic validation"
+            )
+        else:
+            if data_mode == "freeze":
+                raise DataValidationError(
+                    None,
+                    table_name,
+                    str(err["loc"]),
+                    "data_type",
+                    "freeze",
+                    model,
+                    {"data_type": "freeze"},
+                    item,
+                    err["msg"],
+                ) from exc
+            elif data_mode == "discard_row":
+                logger.warning(
+                    "DLT discarded row during Pydantic validation",
+                    extra={
+                        "resource": table_name,
+                        "pydantic_errors": exc.errors(
+                            include_input=False, include_context=False
+                        ),
+                    },
+                )
+                return
+            raise NotImplementedError(
+                f"`{data_mode=:}` not implemented for Pydantic validation"
+            )
diff --git a/tests/test_dlt_pydantic_override.py b/tests/test_dlt_pydantic_override.py
new file mode 100644
index 0000000..8cf32e0
--- /dev/null
+++ b/tests/test_dlt_pydantic_override.py
@@ -0,0 +1,59 @@
+import logging
+
+import pytest
+from dlt.common.libs import pydantic as dlt_pydantic
+from dlt.common.schema.exceptions import DataValidationError
+from dlt.extract.validation import PydanticValidator
+from pydantic import BaseModel
+
+from openhound.core import validate
+
+
+class ExampleModel(BaseModel):
+    id: int
+    required_field: str
+
+
+@pytest.fixture(autouse=True)
+def patch_dlt_validation(monkeypatch):
+    monkeypatch.setattr(dlt_pydantic, "create_list_model", validate.create_list_model)
+    monkeypatch.setattr(
+        dlt_pydantic,
+        "_classify_validation_errors",
+        validate._classify_validation_errors,
+    )
+
+
+def _validator(column_mode: str = "evolve", data_mode: str = "discard_row"):
+    validator = PydanticValidator(ExampleModel, column_mode, data_mode)
+    validator.table_name = "users"
+    return validator
+
+
+def test_dlt_pydantic_discard_log(caplog):
+    """Checks if the validator (with discard_row) logs data type errors"""
+    caplog.set_level(logging.WARNING, logger="dlt")
+
+    result = _validator()(dict(id="not-an-int"))
+
+    assert result is None
+    assert len(caplog.records) == 1
+    record = caplog.records[0]
+
+    assert record.resource == "users"
+    assert record.pydantic_errors[0]["type"] == "int_parsing"
+    assert record.pydantic_errors[0]["loc"] == ("id",)
+    assert record.pydantic_errors[1]["type"] == "missing"
+    assert record.pydantic_errors[1]["loc"] == ("required_field",)
+
+
+def test_dlt_pydantic_freeze_exception(caplog):
+    """Check if DLT raises an exception instead of continuing when the data mode is set to freeze"""
+    caplog.set_level(logging.WARNING, logger="dlt")
+
+    with pytest.raises(DataValidationError) as exc_info:
+        _validator(data_mode="freeze")(dict(id="not-an-int"))
+
+    assert "id" in exc_info.value.data_item
+    assert exc_info.value.data_item["id"] == "not-an-int"
+    assert exc_info.value.contract_mode == "freeze"