Skip to content

Commit 4f54611

Browse files
ctruedenclaude
andcommitted
Rewrite type references in stubs to include prefix
Stubs are placed in the scyjava.types namespace by default, but have imports with a toplevel namespace; for example, when generating stubs for org.scijava:scijava-common, the file: src/scijava/types/org/scijava/__init__.pyi would declare imports: import java.lang import java.util import org.scijava.annotations import org.scijava.app import org.scijava.cache import org.scijava.command and so on. But these imports yield errors when browsing in an IDE, and prevent the type checker from resolving all the types properly. This commit makes the following changes to address the problem: - Add a python_package_prefix parameter to generate_stubs() function - Add a _rewrite_stub_imports() function that: - Rewrites import foo.bar.X → import {python_package_prefix}.foo.bar.X - Rewrites type references similarly to have the python_package_prefix - Update the CLI to automatically pass python_package_prefix="scyjava.types" when using the default location - If --output-python-path gives a different prefix, use that instead - Add a new test_stubgen_type_references to validate rewriting behavior So then the following example above gets rewritten to be: import java.lang import java.util import scyjava.stubs.org.scijava.annotations import scyjava.stubs.org.scijava.app import scyjava.stubs.org.scijava.cache import scyjava.stubs.org.scijava.command Note that the java.lang and java.util imports are not rewritten, because they are not among the classes whose stubs are being generated. With this change, IDEs now autocomplete expressions involving these Java classes correctly, even when chained; for example: import scyjava.stubs.org.scijava.Context Context().getServiceIndex().getA shows getAll as a valid completion, because getAll() is a member method of a supertype of ServiceIndex, the class returned by getServiceIndex(). Before this patch, such chained type completions did not work. Co-authored-by: Claude <noreply@anthropic.com>
1 parent ed0add6 commit 4f54611

File tree

3 files changed

+222
-4
lines changed

3 files changed

+222
-4
lines changed

src/scyjava/_stubs/_cli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def main() -> None:
102102
if not output_dir.exists():
103103
output_dir.mkdir(parents=True, exist_ok=True)
104104

105+
# Determine the Python package prefix for import rewriting
106+
python_package_prefix = args.output_python_path or _derive_python_prefix(
107+
args.output_dir
108+
)
109+
105110
generate_stubs(
106111
endpoints=args.endpoints,
107112
prefixes=args.prefix,
@@ -110,9 +115,23 @@ def main() -> None:
110115
include_javadoc=args.with_javadoc,
111116
add_runtime_imports=args.runtime_imports,
112117
remove_namespace_only_stubs=args.remove_namespace_only_stubs,
118+
python_package_prefix=python_package_prefix,
113119
)
114120

115121

122+
def _derive_python_prefix(output_dir: str | None) -> str:
123+
"""Derive the Python package prefix from the output directory.
124+
125+
If output_dir is None, defaults to 'scyjava.types'.
126+
"""
127+
if output_dir:
128+
# For a filesystem path, we can't reliably derive the Python prefix
129+
# Return empty string to skip import rewriting
130+
return ""
131+
# Default case: stubs go to scyjava.types
132+
return "scyjava.types"
133+
134+
116135
def _get_ouput_dir(output_dir: str | None, python_path: str | None) -> Path:
117136
if out_dir := output_dir:
118137
return Path(out_dir)

src/scyjava/_stubs/_genstubs.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def generate_stubs(
4141
include_javadoc: bool = True,
4242
add_runtime_imports: bool = True,
4343
remove_namespace_only_stubs: bool = False,
44+
python_package_prefix: str = "",
4445
) -> None:
4546
"""Generate stubs for the given maven endpoints.
4647
@@ -82,6 +83,11 @@ def generate_stubs(
8283
merge the generated stubs with other stubs in the same namespace. Without this,
8384
the `__init__.pyi` for any given module will be whatever whatever the *last*
8485
stub generator wrote to it (and therefore inaccurate).
86+
python_package_prefix : str, optional
87+
The Python package prefix under which stubs are being installed. For example,
88+
if stubs are being installed to `scyjava.types.org.scijava...`, this should be
89+
"scyjava.types". This is used to rewrite imports in the stub files so that
90+
type checkers can properly resolve cross-references. Defaults to "".
8591
"""
8692
try:
8793
import stubgenj
@@ -136,10 +142,20 @@ def _patched_start(*args: Any, **kwargs: Any) -> None:
136142
)
137143

138144
output_dir = Path(output_dir)
145+
if python_package_prefix:
146+
logger.info(
147+
"Rewriting stub imports with Python package prefix: %s",
148+
python_package_prefix,
149+
)
150+
139151
if add_runtime_imports:
140152
logger.info("Adding runtime imports to generated stubs")
141153

142154
for stub in output_dir.rglob("*.pyi"):
155+
# Rewrite imports if a Python package prefix was specified
156+
if python_package_prefix:
157+
_rewrite_stub_imports(stub, python_package_prefix)
158+
143159
stub_ast = ast.parse(stub.read_text())
144160
members = {node.name for node in stub_ast.body if hasattr(node, "name")}
145161
if members == {"__module_protocol__"}:
@@ -178,6 +194,113 @@ def _patched_start(*args: Any, **kwargs: Any) -> None:
178194
"""
179195

180196

197+
def _rewrite_stub_imports(stub_path: Path, python_package_prefix: str) -> None:
198+
"""Rewrite imports in a stub file to use the full Python package path.
199+
200+
When stubs are generated into a subdirectory like scyjava/types, they need to have
201+
their imports rewritten so that type checkers can resolve cross-references. This
202+
function transforms imports like:
203+
204+
import org.scijava.object
205+
206+
into:
207+
208+
import scyjava.types.org.scijava.object
209+
210+
and transforms type references like:
211+
212+
org.scijava.object.ObjectIndex
213+
214+
into:
215+
216+
scyjava.types.org.scijava.object.ObjectIndex
217+
"""
218+
import re
219+
220+
content = stub_path.read_text()
221+
222+
# Split into lines for import processing
223+
lines = content.split("\n")
224+
new_lines = []
225+
import_patterns = [] # Patterns to replace in annotations
226+
227+
i = 0
228+
while i < len(lines):
229+
line = lines[i]
230+
stripped = line.strip()
231+
232+
# Handle import statements
233+
if stripped.startswith("import ") and (
234+
"org." in stripped or "java." in stripped
235+
):
236+
# Parse "import org.scijava.service"
237+
match = stripped.split()
238+
if len(match) >= 2:
239+
module_name = match[1]
240+
if (
241+
not module_name.startswith(".")
242+
and "scyjava.types" not in module_name
243+
):
244+
# Only rewrite org.* imports (not java.*)
245+
if module_name.startswith("org."):
246+
new_module = f"{python_package_prefix}.{module_name}"
247+
# Preserve indentation
248+
indent = line[: len(line) - len(line.lstrip())]
249+
new_lines.append(f"{indent}import {new_module}")
250+
# Record this pattern for later annotation rewriting
251+
import_patterns.append((module_name, new_module))
252+
i += 1
253+
continue
254+
255+
# Handle "from X import Y" statements
256+
elif stripped.startswith("from ") and (" import " in stripped):
257+
if "org." in stripped or "java." in stripped:
258+
parts = stripped.split(" import ")
259+
if len(parts) == 2:
260+
module_part = parts[0].replace("from ", "").strip()
261+
imports_part = parts[1].strip()
262+
263+
if (
264+
not module_part.startswith(".")
265+
and "scyjava.types" not in module_part
266+
):
267+
if module_part.startswith("org."):
268+
new_module = f"{python_package_prefix}.{module_part}"
269+
indent = line[: len(line) - len(line.lstrip())]
270+
new_lines.append(
271+
f"{indent}from {new_module} import {imports_part}"
272+
)
273+
import_patterns.append((module_part, new_module))
274+
i += 1
275+
continue
276+
277+
new_lines.append(line)
278+
i += 1
279+
280+
# Reconstruct content with rewritten imports
281+
new_content = "\n".join(new_lines)
282+
283+
# Now rewrite type annotations that reference org.* packages
284+
# Only replace in type hints, not in already-rewritten import statements
285+
# We do this by replacing the old module names with new ones, but being careful
286+
# to not double-replace
287+
for old_prefix, new_prefix in import_patterns:
288+
# Replace qualified names like "org.scijava.service.ServiceIndex"
289+
# but NOT names that already have the prefix
290+
# Pattern: old_prefix followed by a dot and word characters
291+
# Use negative lookbehind to avoid replacing if already prefixed
292+
pattern = (
293+
r"(?<!"
294+
+ re.escape(python_package_prefix)
295+
+ r"\.)"
296+
+ re.escape(old_prefix)
297+
+ r"(?=\.\w)"
298+
)
299+
new_content = re.sub(pattern, new_prefix, new_content)
300+
301+
stub_path.write_text(new_content)
302+
303+
181304
def ruff_check(output: Path, select: str = "E,W,F,I,UP,C4,B,RUF,TC,TID") -> None:
182305
"""Run ruff check and format on the generated stubs."""
183306
if not shutil.which("ruff"):

tests/test_stubgen.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import ast
34
import sys
4-
from typing import TYPE_CHECKING
5+
from pathlib import Path
56
from unittest.mock import patch
67

78
import jpype
@@ -10,9 +11,6 @@
1011
import scyjava
1112
from scyjava._stubs import _cli
1213

13-
if TYPE_CHECKING:
14-
from pathlib import Path
15-
1614

1715
@pytest.mark.skipif(
1816
scyjava.config.mode != scyjava.config.Mode.JPYPE,
@@ -56,3 +54,81 @@ def test_stubgen(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
5654
func = Function(1)
5755
mock_start_jvm.assert_called_once()
5856
assert isinstance(func, jpype.JObject)
57+
58+
59+
@pytest.mark.skipif(
60+
scyjava.config.mode != scyjava.config.Mode.JPYPE,
61+
reason="Stubgen not supported in JEP",
62+
)
63+
def test_stubgen_type_references(
64+
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
65+
) -> None:
66+
"""Test that generated stubs have properly qualified type references.
67+
68+
This validates that when stubs are generated with a Python package prefix,
69+
all type references are properly rewritten so type checkers can resolve them.
70+
"""
71+
import tempfile
72+
73+
# Generate stubs with --output-python-path so a prefix is used
74+
# (rather than --output-dir which doesn't imply a Python module path)
75+
stubs_module = "test_stubs"
76+
77+
# Create a temporary directory and add it to sys.path
78+
with tempfile.TemporaryDirectory() as tmpdir_str:
79+
tmpdir = Path(tmpdir_str)
80+
original_path = sys.path.copy()
81+
82+
try:
83+
sys.path.insert(0, str(tmpdir))
84+
85+
# Create the parent module package
86+
stubs_pkg = tmpdir / stubs_module
87+
stubs_pkg.mkdir()
88+
(stubs_pkg / "__init__.py").touch()
89+
90+
monkeypatch.setattr(
91+
sys,
92+
"argv",
93+
[
94+
"scyjava-stubgen",
95+
"org.scijava:parsington:3.1.0",
96+
"--output-python-path",
97+
stubs_module,
98+
],
99+
)
100+
_cli.main()
101+
102+
# Check that import statements were rewritten with the prefix
103+
init_stub = stubs_pkg / "org" / "scijava" / "parsington" / "__init__.pyi"
104+
assert init_stub.exists(), f"Expected stub file {init_stub} not found"
105+
106+
content = init_stub.read_text()
107+
stub_ast = ast.parse(content)
108+
109+
# Find all Import and ImportFrom nodes
110+
imports = [
111+
node
112+
for node in ast.walk(stub_ast)
113+
if isinstance(node, (ast.Import, ast.ImportFrom))
114+
]
115+
116+
# Collect imported module names
117+
imported_modules = set()
118+
for imp in imports:
119+
if isinstance(imp, ast.Import):
120+
for alias in imp.names:
121+
imported_modules.add(alias.name)
122+
elif isinstance(imp, ast.ImportFrom) and imp.module:
123+
imported_modules.add(imp.module)
124+
125+
# Verify that bare org.scijava.* imports don't exist (they should be prefixed)
126+
org_imports = {
127+
m for m in imported_modules if m and m.startswith("org.scijava.")
128+
}
129+
assert not org_imports, (
130+
f"Found unrewritten org.scijava imports in {init_stub}: {org_imports}. "
131+
f"These should have been prefixed with '{stubs_module}.'"
132+
)
133+
finally:
134+
sys.path = original_path

0 commit comments

Comments
 (0)