Skip to content
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ jobs:
- name: Run DuckDB lookup exception handling tests
run: |
.venv/bin/pytest tests/test_lookup.py -v
- name: Run DLT Pydantic exception handling (override) test
run: |
.venv/bin/pytest tests/test_dlt_pydantic_override.py -v
49 changes: 29 additions & 20 deletions src/openhound/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import dlt
import duckdb
import typer
from dlt.common.libs import pydantic as dlt_pydantic
from dlt.common.pipeline import LoadInfo
from dlt.extract.resource import DltResource
from dlt.extract.source import DltSource

from openhound.cli.collect import collect
from openhound.cli.convert import convert
from openhound.cli.preproc import preprocess
from openhound.core import validate
from openhound.core.asset import BaseAsset, EdgeDef, NodeDef
from openhound.core.collect import CollectContext, Collector
from openhound.core.convert import ConvertContext, Converter, Method
Expand Down Expand Up @@ -46,12 +48,14 @@
class Contract(str, Enum):
evolve = "evolve"
freeze = "freeze"
discard_value = "discard_value"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discard_value is a supported contract but not supported when using Pydantic models, which is what we use for all the collectors

discard_row = "discard_row"


class OpenHound:
def __init__(self, name: str, source_kind: str, help: str = "OpenGraph collector"):
dlt_pydantic.create_list_model = validate.create_list_model
dlt_pydantic._classify_validation_errors = validate._classify_validation_errors

self.name = name
self.source_kind = source_kind
self.help = help
Expand All @@ -67,9 +71,6 @@ def __init__(self, name: str, source_kind: str, help: str = "OpenGraph collector
self.dlt_source: DltSource | None = None
self.dlt_resources: list[DltResource] = []
self.dlt_transformers: list[DltResource] = []
self.table_contract: Contract = Contract.evolve
self.data_type_contract: Contract = Contract.freeze
self.columns_contract: Contract = Contract.evolve

# Store the graph definitions for this source
self.assets: list[BaseAsset] = []
Expand All @@ -96,30 +97,38 @@ def wrapper(
progress: Progress = typer.Option(
Progress.tqdm, help="Select progress tracker option"
),
tables: Contract = typer.Option(
Contract.evolve,
help="Contract applied when data contains newly seen resources/tables previously not collected",
),
columns: Contract = typer.Option(
Contract.evolve,
help="Contract applied when data contains values/keys not found in the Pydantic model",
),
data_type: Contract = typer.Option(
Contract.freeze,
help="Contract applied when fields do not match the data types defined in the Pydantic model",
),
tables_contract: Annotated[
Contract,
typer.Option(
help="DLT contract applied when data contains newly seen resources/tables previously not collected",
),
] = Contract.evolve,
columns_contract: Annotated[
Contract,
typer.Option(
help="DLT contract applied when data contains values/keys not found in the Pydantic model",
),
] = Contract.evolve,
data_type_contract: Annotated[
Contract,
typer.Option(
help="DLT contract applied when fields do not match the data types defined in the Pydantic model",
),
] = Contract.discard_row,
) -> LoadInfo | None:
schema_contract = {
"tables": tables_contract,
"columns": columns_contract,
"data_type": data_type_contract,
}
collector = Collector(
name=self.name,
output_path=output_path,
resources=resources,
progress=progress,
schema_contract=schema_contract,
)

# TODO: Implement data/table/column contracts
# self.data_type_contract = data_type
# self.columns_contract = columns
# self.table_contract = tables
ctx = CollectContext(pipeline=collector)
source_method: DltSource = func(ctx)
if source_method:
Expand Down
15 changes: 14 additions & 1 deletion src/openhound/core/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

logger = logging.getLogger(__name__)

DEFAULT_SCHEMA_CONTRACT = {
"tables": "evolve",
"columns": "evolve",
"data_type": "discard_row",
}


class Collector(BasePipeline):
def __init__(
Expand All @@ -22,11 +28,15 @@ def __init__(
output_path: Path,
resources: list[str] | None = None,
progress: Progress = Progress.tqdm,
schema_contract: dict | None = None,
):
self.name = name
self.output_path = output_path
self.resources = resources if resources else []
self.progress = progress
self.schema_contract = (
schema_contract if schema_contract is not None else DEFAULT_SCHEMA_CONTRACT
)

@property
def pipeline(self) -> Pipeline:
Expand Down Expand Up @@ -56,7 +66,10 @@ def run(self, source_object: DltSource, **kwargs) -> LoadInfo:

logger_override.set_handler(self.name)
return self._run(
all_resources, write_disposition="replace", loader_file_format="jsonl"
all_resources,
write_disposition="replace",
loader_file_format="jsonl",
schema_contract=self.schema_contract,
)


Expand Down
128 changes: 128 additions & 0 deletions src/openhound/core/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# This is a patched version of DLT's model validator which adds logging when resources fail
# pydantic validation and the schema contract is set to discard_row
# https://github.com/dlt-hub/dlt/blob/devel/dlt/common/libs/pydantic.py

import logging
from typing import Any, Type

from dlt.common.schema import DataValidationError
from dlt.common.schema.typing import TSchemaEvolutionMode
from dlt.common.typing import TDataItem
from pydantic import BaseModel, ValidationError, create_model
from pydantic.functional_validators import WrapValidator
from typing_extensions import Annotated

logger = logging.getLogger("dlt")


def create_list_model(
model: type[BaseModel],
column_mode: TSchemaEvolutionMode = "freeze",
data_mode: TSchemaEvolutionMode = "freeze",
) -> type[BaseModel]:
"""Creates a model from `model` for validating list of items in batch."""
if column_mode == "discard_row" or data_mode == "discard_row":

def _lenient_item_validator(value: Any, handler: Any) -> BaseModel | None:
try:
return handler(value)
except ValidationError as val_err:
for err in val_err.errors():
if err["type"] == "model_type":
raise

logger.warning(
"DLT discarded row during listed Pydantic validation",
extra={
"resource": model.__name__,
"pydantic_errors": val_err.errors(
include_input=False, include_context=False
),
},
)
return None
except Exception:
return None

item_type = Annotated[model | None, WrapValidator(_lenient_item_validator)] # type: ignore[valid-type]
return create_model(
"LenientList" + model.__name__,
items=(list[item_type], ...), # type: ignore[valid-type]
)

return create_model(
"List" + model.__name__,
items=(list[model], ...), # type: ignore[valid-type]
)


def _classify_validation_errors(
table_name: str,
model: Type[BaseModel],
item: TDataItem,
exc: ValidationError,
column_mode: TSchemaEvolutionMode,
data_mode: TSchemaEvolutionMode,
) -> None:
"""Classifies validation errors and raises DataValidationError for freeze mode.
For discard_row mode, returns without raising so the caller can discard the item.
For model_type errors (item is not a mapping), always re-raises.
"""
for err in exc.errors():
if err["type"] == "model_type":
raise exc
if err["type"] == "extra_forbidden":
if column_mode == "freeze":
raise DataValidationError(
None,
table_name,
str(err["loc"]),
"columns",
"freeze",
model,
{"columns": "freeze"},
item,
err["msg"],
) from exc
elif column_mode == "discard_row":
logger.warning(
"DLT discarded row during Pydantic validation",
extra={
"resource": table_name,
"pydantic_errors": exc.errors(
include_input=False, include_context=False
),
},
)
return
raise NotImplementedError(
f"`{column_mode=:}` not implemented for Pydantic validation"
)
else:
if data_mode == "freeze":
raise DataValidationError(
None,
table_name,
str(err["loc"]),
"data_type",
"freeze",
model,
{"data_type": "freeze"},
item,
err["msg"],
) from exc
elif data_mode == "discard_row":
logger.warning(
"DLT discarded row during Pydantic validation",
extra={
"resource": table_name,
"pydantic_errors": exc.errors(
include_input=False, include_context=False
),
},
)
return
raise NotImplementedError(
f"`{data_mode=:}` not implemented for Pydantic validation"
)
59 changes: 59 additions & 0 deletions tests/test_dlt_pydantic_override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import logging

import pytest
from dlt.common.libs import pydantic as dlt_pydantic
from dlt.common.schema.exceptions import DataValidationError
from dlt.extract.validation import PydanticValidator
from pydantic import BaseModel

from openhound.core import validate


class ExampleModel(BaseModel):
id: int
required_field: str


@pytest.fixture(autouse=True)
def patch_dlt_validation(monkeypatch):
monkeypatch.setattr(dlt_pydantic, "create_list_model", validate.create_list_model)
monkeypatch.setattr(
dlt_pydantic,
"_classify_validation_errors",
validate._classify_validation_errors,
)


def _validator(column_mode: str = "evolve", data_mode: str = "discard_row"):
validator = PydanticValidator(ExampleModel, column_mode, data_mode)
validator.table_name = "users"
return validator


def test_dlt_pydantic_discard_log(caplog):
"""Checks if the validator (with discard_row) logs data type errors"""
caplog.set_level(logging.WARNING, logger="dlt")

result = _validator()(dict(id="not-an-int"))

assert result is None
assert len(caplog.records) == 1
record = caplog.records[0]

assert record.resource == "users"
assert record.pydantic_errors[0]["type"] == "int_parsing"
assert record.pydantic_errors[0]["loc"] == ("id",)
assert record.pydantic_errors[1]["type"] == "missing"
assert record.pydantic_errors[1]["loc"] == ("required_field",)


def test_dlt_pydantic_freeze_exception(caplog):
"""Check if DLT raises an exception instead of continuing when the data mode is set to freeze"""
caplog.set_level(logging.WARNING, logger="dlt")

with pytest.raises(DataValidationError) as exc_info:
_validator(data_mode="freeze")(dict(id="not-an-int"))

assert "id" in exc_info.value.data_item
assert exc_info.value.data_item["id"] == "not-an-int"
assert exc_info.value.contract_mode == "freeze"
Loading