Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 168 additions & 7 deletions scripts/postprocess_generated_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Applied to `_models.py`:
- Fix discriminator field names that use camelCase instead of snake_case (known issue with discriminators on schemas
referenced from array items).
- Add `alias_generator=to_camel` to every model's `ConfigDict` and drop the per-field `Field(alias=...)` whenever the
alias is just the camelCase of the snake_case field name. Only irregular conversions (all-caps usage keys,
`gitHubGistUrl`, `schema`, ...) keep an explicit alias. The wire format is unchanged: an explicit `Field(alias=...)`
still wins over the generator in Pydantic, and `populate_by_name=True` keeps snake_case input working.
- Rewrite every `class X(StrEnum)` as `X = Literal[...]` so downstream code can pass plain strings
(and reuse the named alias in resource-client signatures) instead of enum members.
- Move the resulting `X = Literal[...]` definitions into `_literals.py`, leaving `_models.py` importing them — so
Expand All @@ -28,6 +32,8 @@
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic.alias_generators import to_camel

if TYPE_CHECKING:
from apify_client._docs import GroupName

Expand Down Expand Up @@ -98,6 +104,154 @@ def fix_discriminators(content: str) -> str:
return content


def _ensure_to_camel_import(content: str) -> str:
    """Add `from pydantic.alias_generators import to_camel` after the `from pydantic import ...` statement.

    The first `from pydantic import ...` statement found in `content` is used as the anchor. Both the single-line
    form and the parenthesized form that a formatter wraps over several lines are recognised; in the latter case the
    new import goes after the closing parenthesis rather than inside the parentheses, where it would be a syntax
    error. Idempotent: content that already carries the import is returned untouched, and content without any
    `from pydantic import ...` anchor is returned unchanged as well.
    """
    if 'from pydantic.alias_generators import to_camel' in content:
        return content
    return re.sub(
        # Optional `( ... )` group swallows a wrapped import list; `[^\n]*\n` then consumes the rest of the line
        # (the names of a single-line import, or a trailing comment after the closing parenthesis).
        r'(from pydantic import (?:\([^)]*\))?[^\n]*\n)',
        r'\1from pydantic.alias_generators import to_camel\n',
        content,
        count=1,
    )


def _add_alias_generator_to_configs(content: str) -> str:
    """Add `alias_generator=to_camel` to every model's `ConfigDict(extra='allow', populate_by_name=True)`.

    Each matching config is rewritten to a single line, whether it was spread over several lines with a trailing
    comma or was already written on one line without it. Idempotent: a config that already carries
    `alias_generator` has it between `populate_by_name=True` and the closing paren, so the pattern no longer
    matches. Configs with any other set of keyword arguments are left untouched.
    """
    return re.sub(
        # `,?` tolerates both the magic-trailing-comma layout and the compact single-line layout.
        r"model_config = ConfigDict\(\s*extra='allow',\s*populate_by_name=True,?\s*\)",
        "model_config = ConfigDict(extra='allow', populate_by_name=True, alias_generator=to_camel)",
        content,
    )


def _class_uses_camel_generator(class_node: ast.ClassDef) -> bool:
    """Tell whether `class_node` assigns `model_config = ConfigDict(..., alias_generator=to_camel)` in its body.

    Both plain and annotated assignments to `model_config` are considered. The first such assignment whose value is
    a `ConfigDict(...)` call decides the answer; assignments with any other value are skipped.
    """
    for statement in class_node.body:
        if isinstance(statement, ast.AnnAssign):
            assigned = [statement.target]
        elif isinstance(statement, ast.Assign):
            assigned = statement.targets
        else:
            continue

        names_model_config = any(isinstance(target, ast.Name) and target.id == 'model_config' for target in assigned)
        if not names_model_config:
            continue

        call = statement.value
        if not isinstance(call, ast.Call):
            continue
        if not (isinstance(call.func, ast.Name) and call.func.id == 'ConfigDict'):
            continue

        # This is the model's ConfigDict: look for the exact `alias_generator=to_camel` keyword.
        for keyword in call.keywords:
            if keyword.arg != 'alias_generator':
                continue
            if isinstance(keyword.value, ast.Name) and keyword.value.id == 'to_camel':
                return True
        return False
    return False


def _annotation_without_regular_alias(annotation: ast.expr, field_name: str) -> str | None:
    """Unparse `annotation` with its regular `Field(alias=...)` removed; return None when nothing should change.

    An alias counts as "regular" when it equals `to_camel(field_name)`, i.e. the model's `alias_generator` would
    produce it anyway. Any other alias is irregular and the annotation is left alone (None is returned). None is
    also returned when the annotation is not an `Annotated[...]` subscript with a tuple slice, or when none of its
    metadata entries is a `Field(...)` call carrying an alias. If the alias was the only content of the `Field`
    call, the call is removed entirely, and an `Annotated[T]` left with just the type collapses to `T`.
    """
    # Guard clauses: only `Annotated[<type>, <metadata>, ...]` is eligible.
    if not isinstance(annotation, ast.Subscript):
        return None
    if not (isinstance(annotation.value, ast.Name) and annotation.value.id == 'Annotated'):
        return None
    if not isinstance(annotation.slice, ast.Tuple):
        return None

    members = annotation.slice.elts

    # Locate the first metadata entry (position 0 is the type itself) that is a `Field(...)` call with an alias.
    aliased_at: int | None = None
    for position, member in enumerate(members):
        if position == 0 or not isinstance(member, ast.Call):
            continue
        if not (isinstance(member.func, ast.Name) and member.func.id == 'Field'):
            continue
        if _extract_alias_from_field_call(member) is None:
            continue
        aliased_at = position
        break
    if aliased_at is None:
        return None

    field_call = members[aliased_at]
    assert isinstance(field_call, ast.Call)  # noqa: S101
    if to_camel(field_name) != _extract_alias_from_field_call(field_call):
        # Irregular alias: the generator would not reproduce it, so it has to stay explicit.
        return None

    kept_keywords = [keyword for keyword in field_call.keywords if keyword.arg != 'alias']
    rebuilt = list(members)
    if not kept_keywords and not field_call.args:
        # The alias was all the `Field` call had; drop the empty call altogether.
        del rebuilt[aliased_at]
    else:
        rebuilt[aliased_at] = ast.Call(func=field_call.func, args=list(field_call.args), keywords=kept_keywords)

    if len(rebuilt) == 1:
        # Only the type is left, so the `Annotated[...]` wrapper is no longer needed.
        return ast.unparse(rebuilt[0])
    return ast.unparse(
        ast.Subscript(
            value=annotation.value,
            slice=ast.Tuple(elts=rebuilt, ctx=ast.Load()),
            ctx=ast.Load(),
        )
    )


def _strip_regular_field_aliases(content: str) -> str:
    """Remove each per-field `Field(alias=...)` whose alias is simply the camelCase of the field name.

    Works on the encoded bytes of `content`, because `ast` reports column offsets as UTF-8 byte offsets; this keeps
    the splices correct for multi-line annotations and non-ASCII text. Only annotated fields declared directly on
    top-level classes that carry `alias_generator=to_camel` are rewritten; `model_config` itself is skipped.
    """
    module = ast.parse(content)
    buffer = bytearray(content.encode())

    # Byte offset at which every (1-indexed) source line begins.
    newline = ord('\n')
    line_byte_starts = [0]
    line_byte_starts.extend(position + 1 for position, byte in enumerate(buffer) if byte == newline)

    camel_classes = (
        node for node in module.body if isinstance(node, ast.ClassDef) and _class_uses_camel_generator(node)
    )

    edits: list[tuple[int, int, bytes]] = []
    for class_node in camel_classes:
        for stmt in class_node.body:
            if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
                continue
            if stmt.target.id == 'model_config':
                continue

            annotation = stmt.annotation
            rewritten = _annotation_without_regular_alias(annotation, stmt.target.id)
            if rewritten is None:
                continue

            assert annotation.end_lineno is not None  # noqa: S101
            assert annotation.end_col_offset is not None  # noqa: S101
            begin = line_byte_starts[annotation.lineno - 1] + annotation.col_offset
            finish = line_byte_starts[annotation.end_lineno - 1] + annotation.end_col_offset
            edits.append((begin, finish, rewritten.encode()))

    # Apply from the end of the file backwards so the offsets of pending edits are not shifted.
    edits.sort(key=lambda edit: edit[0], reverse=True)
    for begin, finish, payload in edits:
        buffer[begin:finish] = payload

    return buffer.decode()


def apply_camel_alias_generator(content: str) -> str:
    """Replace per-field camelCase aliases with a shared `alias_generator=to_camel`, keeping irregular aliases.

    Runs three passes in order: make sure `to_camel` is imported, add `alias_generator=to_camel` to the model
    `ConfigDict`s, and finally strip the `Field(alias=...)` entries the generator makes redundant. The wire format
    is meant to stay the same: Pydantic lets an explicit alias override the generator, and `populate_by_name=True`
    keeps snake_case input working.
    """
    passes = (
        _ensure_to_camel_import,
        _add_alias_generator_to_configs,
        _strip_regular_field_aliases,
    )
    # Order matters: the stripping pass only touches classes whose config already carries the generator.
    for rewrite in passes:
        content = rewrite(content)
    return content


def convert_enums_to_literals(content: str) -> str:
"""Rewrite every `class X(StrEnum): ...` into an `X = Literal[...]` alias.

Expand Down Expand Up @@ -404,25 +558,31 @@ def _extract_alias_from_field_call(field_call: ast.Call) -> str | None:
def _extract_class_field_aliases(class_node: ast.ClassDef) -> dict[str, str]:
    """Return `{snake_field: api_field}` for every annotated field declared on `class_node`.

    The API spelling is resolved in priority order: an explicit `Field(alias=...)` wins; otherwise, on a model that
    carries `alias_generator=to_camel`, the name is run through `to_camel` (matching Pydantic at runtime); otherwise
    the field maps to itself (single-word fields like `url`, `id`, or models without the generator).
    `model_config` is not a field and is skipped.
    """
    uses_camel = _class_uses_camel_generator(class_node)
    aliases: dict[str, str] = {}
    for stmt in class_node.body:
        if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
            continue
        field_name = stmt.target.id
        if field_name == 'model_config':
            continue

        # Walk the annotation to find a nested `Field(alias='...')` call inside `Annotated[...]`; the first
        # `Field` call that actually carries an alias decides.
        explicit_alias: str | None = None
        for sub in ast.walk(stmt.annotation):
            if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name) and sub.func.id == 'Field':
                explicit_alias = _extract_alias_from_field_call(sub)
                if explicit_alias is not None:
                    break

        if explicit_alias is not None:
            api_name = explicit_alias
        elif uses_camel:
            api_name = to_camel(field_name)
        else:
            api_name = field_name
        aliases[field_name] = api_name
    return aliases

Expand Down Expand Up @@ -588,6 +748,7 @@ def postprocess_models(models_path: Path, literals_path: Path) -> list[Path]:
"""
original = models_path.read_text()
fixed = fix_discriminators(original)
fixed = apply_camel_alias_generator(fixed)
fixed = convert_enums_to_literals(fixed)
fixed = add_docs_group_decorators(fixed, 'Models')
models_content, literals_content = split_literals_to_file(fixed)
Expand Down
Loading