|
| 1 | +#!/usr/bin/env python3 |
"""Inject `model_config = ConfigDict(extra='forbid')` into every generated
Pydantic BaseModel class. RootModel subclasses are deliberately left alone.
| 4 | +
|
| 5 | +datamodel-code-generator does not emit a config block when the source |
| 6 | +OpenAPI spec lacks `additionalProperties: false`. Springdoc never emits |
| 7 | +that key, so we patch every generated class here. |
| 8 | +
|
| 9 | +This implements policies P1 (response extras forbidden) and P2 (request |
| 10 | +extras forbidden) from `mini/cowork/design/040-codegen-policies.md`. |
| 11 | +
|
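The transform is purely syntactic: scan each line, find `class Foo(BaseModel):`
and inject `model_config = ConfigDict(...)` at the top of that class body.
`class Foo(RootModel[...]):` headers are not matched, because Pydantic rejects
`extra` on root models. StrEnum members whose names shadow inherited `str`
methods additionally get a `# type: ignore[...]` comment appended.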
| 15 | +
|
| 16 | +Idempotent: skips classes that already declare `model_config`. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import re |
| 22 | +import sys |
| 23 | +from pathlib import Path |
| 24 | + |
# RootModel subclasses cannot set `extra='forbid'` (Pydantic raises
# `root-model-extra`), so skip them. Their behavior is governed by the
# inner type, which on its own enforces strict validation.
#
# Only top-level (column 0) headers whose sole base is `BaseModel` match.
# Groups: 1 = class name, 2 = the literal base name `BaseModel`.
CLASS_RE = re.compile(r"^class\s+([A-Za-z_][\w]*)\s*\(\s*(BaseModel)\s*\)\s*:\s*$")
# The line spliced into each matched class. It carries the conventional
# four-space class-body indent: a shallower indent followed by four-space
# fields would make the patched module fail with IndentationError.
CONFIG_LINE = "    model_config = ConfigDict(extra='forbid')"
| 30 | + |
| 31 | + |
# StrEnum members that shadow inherited str methods need a `# type: ignore`
# because mypy thinks they're overriding the base method with an incompatible
# type. Listed explicitly so we get failures (instead of silent no-ops) when
# datamodel-codegen renames things.
# NOTE(review): no code visible in this file raises when a listed member is
# never encountered -- confirm where that failure is expected to surface.
STR_ENUM_COLLISIONS = {
    # member name -> mypy ignore code
    "count": "assignment",
    "index": "assignment",
    "title": "assignment",
    "lower": "assignment",
    "upper": "assignment",
    "format": "assignment",
}

# Matches a top-level `class Name(StrEnum):` header; group 1 is the class name.
STR_ENUM_RE = re.compile(r"^class\s+([A-Za-z_][\w]*)\s*\(\s*StrEnum\s*\)\s*:\s*$")
# Matches an indented `name = value` line. Group 1 is the indentation, group 2
# the member name (must start with a lowercase letter or underscore), group 3
# the value expression with surrounding whitespace trimmed.
STR_ENUM_MEMBER_RE = re.compile(r"^(\s+)([a-z_][\w]*)\s*=\s*(.+?)\s*$")
| 48 | + |
| 49 | + |
# An indented `model_config = ...` / `model_config: ... = ...` class attribute.
_MODEL_CONFIG_RE = re.compile(r"^[ \t]+model_config\s*[:=]")
# Start of a `from pydantic import ...` statement; group 1 is "(" when the
# imported names are parenthesised (possibly spanning several lines).
_PYDANTIC_IMPORT_RE = re.compile(
    r"^[ \t]*from pydantic import(?=[ \t(\\])[ \t]*(\(?)", re.MULTILINE
)
# A call to `ConfigDict(...)`, i.e. evidence that the name must be importable.
_CONFIGDICT_USE_RE = re.compile(r"\bConfigDict\s*\(")
_CONFIGDICT_NAME_RE = re.compile(r"\bConfigDict\b")
# Opening of a triple-quoted docstring; group 1 is the quote style used.
_DOCSTRING_OPEN_RE = re.compile(r'[ \t]*[rRuU]?("""|' r"''')")


def _next_nonblank(lines: list[str], start: int) -> int:
    """Return the index of the first non-blank line at/after `start` (or len(lines))."""
    idx = start
    while idx < len(lines) and not lines[idx].strip():
        idx += 1
    return idx


def _docstring_end(lines: list[str], start: int) -> int:
    """Return the index just past the docstring that opens on `lines[start]`.

    Returns `start` unchanged when that line does not open a triple-quoted
    string, or when the string is never closed.
    """
    opening = _DOCSTRING_OPEN_RE.match(lines[start])
    if opening is None:
        return start
    quote = opening.group(1)
    if quote in lines[start][opening.end():]:
        return start + 1  # one-line docstring
    for idx in range(start + 1, len(lines)):
        if quote in lines[idx]:
            return idx + 1
    return start


def _ensure_configdict_import(source: str) -> str:
    """Add `ConfigDict` to the first `from pydantic import` statement if missing.

    Handles both `from pydantic import A, B` and the parenthesised
    `from pydantic import (\\n    A,\\n)` layout. The source is returned
    unchanged when some pydantic import already names `ConfigDict` or when
    there is no such import statement at all.
    """
    target = None
    for match in _PYDANTIC_IMPORT_RE.finditer(source):
        terminator = ")" if match.group(1) else "\n"
        stop = source.find(terminator, match.end())
        names = source[match.end(): stop if stop != -1 else len(source)]
        if _CONFIGDICT_NAME_RE.search(names):
            return source
        if target is None:
            target = match
    if target is None:
        return source
    at = target.end()
    # Directly before a newline/space (multi-line parenthesised form) the
    # separating space would only be trailing or doubled whitespace.
    addition = "ConfigDict, " if source[at:at + 1].strip() else "ConfigDict,"
    return source[:at] + addition + source[at:]


def inject(source: str) -> tuple[str, int]:
    """Patch generated Pydantic source text and report how much was changed.

    Two line-based rewrites are applied in a single pass:

    * every top-level class matched by `CLASS_RE` that does not already
      declare `model_config` gets the `CONFIG_LINE` statement. It is placed
      after the class docstring when there is one (so the docstring stays
      the first statement), otherwise directly below the header. A bare
      `pass` body is replaced rather than kept.
    * inside classes matched by `STR_ENUM_RE`, members named in
      `STR_ENUM_COLLISIONS` get a `# type: ignore[<code>]` comment unless
      the line already carries a `type: ignore`.

    If the result calls `ConfigDict(...)`, the name is added to the existing
    `from pydantic import` statement when it is not imported there yet.

    Args:
        source: Full text of the generated module.

    Returns:
        `(new_source, patch_count)`, where `patch_count` is the number of
        classes that received `model_config` plus the number of enum members
        that received a `type: ignore` comment. Running the function on its
        own output yields the same text and a count of 0.
    """
    lines = source.splitlines(keepends=True)
    total = len(lines)
    statement = CONFIG_LINE.strip()
    # Used only when a class has no body line to copy the indentation from.
    fallback_indent = CONFIG_LINE[: len(CONFIG_LINE) - len(CONFIG_LINE.lstrip())] or "    "

    out: list[str] = []
    patched = 0
    in_str_enum = False
    i = 0
    while i < total:
        line = lines[i]
        text = line.rstrip("\r\n")
        eol = line[len(text):] or "\n"

        if STR_ENUM_RE.match(text):
            in_str_enum = True
            out.append(line)
            i += 1
            continue

        if in_str_enum:
            if line.strip() and not line.startswith((" ", "\t")):
                # Dedented code ends the enum body; this line is handled by
                # the BaseModel logic below.
                in_str_enum = False
            else:
                member = STR_ENUM_MEMBER_RE.match(text)
                if (
                    member
                    and member.group(2) in STR_ENUM_COLLISIONS
                    and "type: ignore" not in line
                ):
                    code = STR_ENUM_COLLISIONS[member.group(2)]
                    # Two spaces before an inline comment (PEP 8), and the
                    # line keeps its own line ending.
                    line = f"{text}  # type: ignore[{code}]{eol}"
                    patched += 1
                out.append(line)
                i += 1
                continue

        if not CLASS_RE.match(text):
            out.append(line)
            i += 1
            continue

        # BaseModel class header. Re-emitting it as text + eol guarantees it is
        # newline-terminated even when it is the last line of the file.
        out.append(text + eol)
        i += 1

        first = _next_nonblank(lines, i)
        if first < total and lines[first].startswith((" ", "\t")):
            body_line = lines[first]
            indent = body_line[: len(body_line) - len(body_line.lstrip(" \t"))]
            insert_at = _docstring_end(lines, first)
            if insert_at == first:
                insert_at = i  # no docstring: inject right under the header
        else:
            indent = fallback_indent
            insert_at = i

        # Idempotency: leave the class alone if its body (from the insertion
        # point up to the next dedented line) already assigns model_config.
        already_configured = False
        scan = insert_at
        while scan < total:
            candidate = lines[scan]
            if candidate.strip() and not candidate.startswith((" ", "\t")):
                break
            if _MODEL_CONFIG_RE.match(candidate):
                already_configured = True
                break
            scan += 1
        if already_configured:
            continue

        # Copy the docstring (if any) through verbatim, then add the config.
        out.extend(lines[i:insert_at])
        if not out[-1].endswith(("\n", "\r")):
            out[-1] += "\n"
        out.append(f"{indent}{statement}\n")
        patched += 1
        i = insert_at

        # An otherwise empty class body is a bare `pass`; the injected line
        # takes its place. Exact match only: fields such as `passed: ...`
        # must not be mistaken for it.
        following = _next_nonblank(lines, i)
        if (
            following < total
            and lines[following].startswith((" ", "\t"))
            and lines[following].strip() == "pass"
        ):
            out.extend(lines[i:following])
            i = following + 1

    result = "".join(out)
    if _CONFIGDICT_USE_RE.search(result):
        result = _ensure_configdict_import(result)
    return result, patched
| 115 | + |
| 116 | + |
def main() -> int:
    """Command-line entry point: patch the generated module named in argv.

    Expects exactly one argument, the path of the generated file. The file is
    read and (only when `inject` changed something) rewritten in place as
    UTF-8, independent of the platform's locale encoding, so non-ASCII text
    in the generated source survives the round trip.

    Returns:
        0 on success; 1 on a usage error or when the path is not an existing
        regular file (a message is printed to stderr in both cases).
    """
    if len(sys.argv) != 2:
        print("usage: inject_strict_config.py <path-to-_generated.py>", file=sys.stderr)
        return 1
    path = Path(sys.argv[1])
    # is_file() also rejects directories, which exists() would let through to
    # read_text() and end in a traceback instead of a clean error.
    if not path.is_file():
        print(f"error: file not found: {path}", file=sys.stderr)
        return 1
    src = path.read_text(encoding="utf-8")
    new_src, modified = inject(src)
    if new_src != src:
        path.write_text(new_src, encoding="utf-8")
    print(f"inject_strict_config: patched {modified} class(es) in {path}")
    return 0
| 131 | + |
| 132 | + |
| 133 | +if __name__ == "__main__": |
| 134 | + sys.exit(main()) |
0 commit comments