Skip to content

Commit f2db975

Browse files
authored
[mypyc] Allow primitives to depend on optional C source files (#20379)
This way we don't need to compile all support rt library code every time mypyc is run. Some primitives are not used very often, but may have relatively complex implementations. Update `bytes.translate` as an example.
1 parent 88cf8e0 commit f2db975

File tree

20 files changed

+269
-118
lines changed

20 files changed

+269
-118
lines changed

mypyc/analysis/capsule_deps.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
from __future__ import annotations
22

3+
from mypyc.ir.deps import Dependency
34
from mypyc.ir.func_ir import FuncIR
45
from mypyc.ir.ops import CallC, PrimitiveOp
56

67

7-
def find_implicit_capsule_dependencies(fn: FuncIR) -> set[str] | None:
8-
"""Find implicit dependencies on capsules that need to be imported.
8+
def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None:
9+
"""Find implicit dependencies that need to be imported.
910
1011
Using primitives or types defined in librt submodules such as "librt.base64"
11-
requires a capsule import.
12+
requires dependency imports (e.g., capsule imports).
1213
1314
Note that a module can depend on a librt module even if it doesn't explicitly
1415
import it, for example via re-exported names or via return types of functions
1516
defined in other modules.
1617
"""
17-
deps: set[str] | None = None
18+
deps: set[Dependency] | None = None
1819
for block in fn.blocks:
1920
for op in block.ops:
2021
# TODO: Also determine implicit type object dependencies (e.g. cast targets)
21-
if isinstance(op, CallC) and op.capsule is not None:
22-
if deps is None:
23-
deps = set()
24-
deps.add(op.capsule)
22+
if isinstance(op, CallC) and op.dependencies is not None:
23+
for dep in op.dependencies:
24+
if deps is None:
25+
deps = set()
26+
deps.add(dep)
2527
else:
2628
assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
2729
return deps

mypyc/build.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from mypyc.codegen import emitmodule
4040
from mypyc.common import IS_FREE_THREADED, RUNTIME_C_FILES, shared_lib_name
4141
from mypyc.errors import Errors
42+
from mypyc.ir.deps import SourceDep
4243
from mypyc.ir.pprint import format_modules
4344
from mypyc.namegen import exported_name
4445
from mypyc.options import CompilerOptions
@@ -282,14 +283,14 @@ def generate_c(
282283
groups: emitmodule.Groups,
283284
fscache: FileSystemCache,
284285
compiler_options: CompilerOptions,
285-
) -> tuple[list[list[tuple[str, str]]], str]:
286+
) -> tuple[list[list[tuple[str, str]]], str, list[SourceDep]]:
286287
"""Drive the actual core compilation step.
287288
288289
The groups argument describes how modules are assigned to C
289290
extension modules. See the comments on the Groups type in
290291
mypyc.emitmodule for details.
291292
292-
Returns the C source code and (for debugging) the pretty printed IR.
293+
Returns the C source code, (for debugging) the pretty printed IR, and list of SourceDeps.
293294
"""
294295
t0 = time.time()
295296

@@ -325,7 +326,10 @@ def generate_c(
325326
if options.mypyc_annotation_file:
326327
generate_annotated_html(options.mypyc_annotation_file, result, modules, mapper)
327328

328-
return ctext, "\n".join(format_modules(modules))
329+
# Collect SourceDep dependencies
330+
source_deps = sorted(emitmodule.collect_source_dependencies(modules), key=lambda d: d.path)
331+
332+
return ctext, "\n".join(format_modules(modules)), source_deps
329333

330334

331335
def build_using_shared_lib(
@@ -486,9 +490,9 @@ def mypyc_build(
486490
*,
487491
separate: bool | list[tuple[list[str], str | None]] = False,
488492
only_compile_paths: Iterable[str] | None = None,
489-
skip_cgen_input: Any | None = None,
493+
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
490494
always_use_shared_lib: bool = False,
491-
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]]]:
495+
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
492496
"""Do the front and middle end of mypyc building, producing and writing out C source."""
493497
fscache = FileSystemCache()
494498
mypyc_sources, all_sources, options = get_mypy_config(
@@ -511,14 +515,16 @@ def mypyc_build(
511515

512516
# We let the test harness just pass in the c file contents instead
513517
# so that it can do a corner-cutting version without full stubs.
518+
source_deps: list[SourceDep] = []
514519
if not skip_cgen_input:
515-
group_cfiles, ops_text = generate_c(
520+
group_cfiles, ops_text, source_deps = generate_c(
516521
all_sources, options, groups, fscache, compiler_options=compiler_options
517522
)
518523
# TODO: unique names?
519524
write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
520525
else:
521-
group_cfiles = skip_cgen_input
526+
group_cfiles = skip_cgen_input[0]
527+
source_deps = [SourceDep(d) for d in skip_cgen_input[1]]
522528

523529
# Write out the generated C and collect the files for each group
524530
# Should this be here??
@@ -535,7 +541,7 @@ def mypyc_build(
535541
deps = [os.path.join(compiler_options.target_dir, dep) for dep in get_header_deps(cfiles)]
536542
group_cfilenames.append((cfilenames, deps))
537543

538-
return groups, group_cfilenames
544+
return groups, group_cfilenames, source_deps
539545

540546

541547
def mypycify(
@@ -548,7 +554,7 @@ def mypycify(
548554
strip_asserts: bool = False,
549555
multi_file: bool = False,
550556
separate: bool | list[tuple[list[str], str | None]] = False,
551-
skip_cgen_input: Any | None = None,
557+
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
552558
target_dir: str | None = None,
553559
include_runtime_files: bool | None = None,
554560
strict_dunder_typing: bool = False,
@@ -633,7 +639,7 @@ def mypycify(
633639
)
634640

635641
# Generate all the actual important C code
636-
groups, group_cfilenames = mypyc_build(
642+
groups, group_cfilenames, source_deps = mypyc_build(
637643
paths,
638644
only_compile_paths=only_compile_paths,
639645
compiler_options=compiler_options,
@@ -708,11 +714,19 @@ def mypycify(
708714
# compiler invocations.
709715
shared_cfilenames = []
710716
if not compiler_options.include_runtime_files:
711-
for name in RUNTIME_C_FILES:
717+
# Collect all files to copy: runtime files + conditional source files
718+
files_to_copy = list(RUNTIME_C_FILES)
719+
for source_dep in source_deps:
720+
files_to_copy.append(source_dep.path)
721+
files_to_copy.append(source_dep.get_header())
722+
723+
# Copy all files
724+
for name in files_to_copy:
712725
rt_file = os.path.join(build_dir, name)
713726
with open(os.path.join(include_dir(), name), encoding="utf-8") as f:
714727
write_file(rt_file, f.read())
715-
shared_cfilenames.append(rt_file)
728+
if name.endswith(".c"):
729+
shared_cfilenames.append(rt_file)
716730

717731
extensions = []
718732
for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):

mypyc/codegen/emitmodule.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from mypy.options import Options
2828
from mypy.plugin import Plugin, ReportConfigContext
2929
from mypy.util import hash_digest, json_dumps
30-
from mypyc.analysis.capsule_deps import find_implicit_capsule_dependencies
30+
from mypyc.analysis.capsule_deps import find_implicit_op_dependencies
3131
from mypyc.codegen.cstring import c_string_initializer
3232
from mypyc.codegen.emit import (
3333
Emitter,
@@ -56,6 +56,7 @@
5656
short_id_from_name,
5757
)
5858
from mypyc.errors import Errors
59+
from mypyc.ir.deps import LIBRT_BASE64, LIBRT_STRINGS, SourceDep
5960
from mypyc.ir.func_ir import FuncIR
6061
from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules
6162
from mypyc.ir.ops import DeserMaps, LoadLiteral
@@ -263,9 +264,9 @@ def compile_scc_to_ir(
263264
# Switch to lower abstraction level IR.
264265
lower_ir(fn, compiler_options)
265266
# Calculate implicit module dependencies (needed for librt)
266-
capsules = find_implicit_capsule_dependencies(fn)
267-
if capsules is not None:
268-
module.capsules.update(capsules)
267+
deps = find_implicit_op_dependencies(fn)
268+
if deps is not None:
269+
module.dependencies.update(deps)
269270
# Perform optimizations.
270271
do_copy_propagation(fn, compiler_options)
271272
do_flag_elimination(fn, compiler_options)
@@ -427,6 +428,16 @@ def load_scc_from_cache(
427428
return modules
428429

429430

431+
def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]:
432+
"""Collect all SourceDep dependencies from all modules."""
433+
source_deps: set[SourceDep] = set()
434+
for module in modules.values():
435+
for dep in module.dependencies:
436+
if isinstance(dep, SourceDep):
437+
source_deps.add(dep)
438+
return source_deps
439+
440+
430441
def compile_modules_to_c(
431442
result: BuildResult, compiler_options: CompilerOptions, errors: Errors, groups: Groups
432443
) -> tuple[ModuleIRs, list[FileContents], Mapper]:
@@ -560,6 +571,10 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
560571
if self.compiler_options.include_runtime_files:
561572
for name in RUNTIME_C_FILES:
562573
base_emitter.emit_line(f'#include "{name}"')
574+
# Include conditional source files
575+
source_deps = collect_source_dependencies(self.modules)
576+
for source_dep in sorted(source_deps, key=lambda d: d.path):
577+
base_emitter.emit_line(f'#include "{source_dep.path}"')
563578
base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
564579
base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"')
565580
emitter = base_emitter
@@ -611,10 +626,14 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
611626
ext_declarations.emit_line("#include <CPy.h>")
612627
if self.compiler_options.depends_on_librt_internal:
613628
ext_declarations.emit_line("#include <librt_internal.h>")
614-
if any("librt.base64" in mod.capsules for mod in self.modules.values()):
629+
if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()):
615630
ext_declarations.emit_line("#include <librt_base64.h>")
616-
if any("librt.strings" in mod.capsules for mod in self.modules.values()):
631+
if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()):
617632
ext_declarations.emit_line("#include <librt_strings.h>")
633+
# Include headers for conditional source files
634+
source_deps = collect_source_dependencies(self.modules)
635+
for source_dep in sorted(source_deps, key=lambda d: d.path):
636+
ext_declarations.emit_line(f'#include "{source_dep.get_header()}"')
618637

619638
declarations = Emitter(self.context)
620639
declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
@@ -1072,11 +1091,11 @@ def emit_module_exec_func(
10721091
emitter.emit_line("if (import_librt_internal() < 0) {")
10731092
emitter.emit_line("return -1;")
10741093
emitter.emit_line("}")
1075-
if "librt.base64" in module.capsules:
1094+
if LIBRT_BASE64 in module.dependencies:
10761095
emitter.emit_line("if (import_librt_base64() < 0) {")
10771096
emitter.emit_line("return -1;")
10781097
emitter.emit_line("}")
1079-
if "librt.strings" in module.capsules:
1098+
if LIBRT_STRINGS in module.dependencies:
10801099
emitter.emit_line("if (import_librt_strings() < 0) {")
10811100
emitter.emit_line("return -1;")
10821101
emitter.emit_line("}")

mypyc/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@
6565
BITMAP_TYPE: Final = "uint32_t"
6666
BITMAP_BITS: Final = 32
6767

68-
# Runtime C library files
68+
# Runtime C library files that are always included (some ops may bring
69+
# extra dependencies via mypyc.ir.SourceDep)
6970
RUNTIME_C_FILES: Final = [
7071
"init.c",
7172
"getargs.c",

mypyc/ir/deps.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Final
2+
3+
4+
class Capsule:
5+
"""Defines a C extension capsule that a primitive may require."""
6+
7+
def __init__(self, name: str) -> None:
8+
# Module fullname, e.g. 'librt.base64'
9+
self.name: Final = name
10+
11+
def __repr__(self) -> str:
12+
return f"Capsule(name={self.name!r})"
13+
14+
def __eq__(self, other: object) -> bool:
15+
return isinstance(other, Capsule) and self.name == other.name
16+
17+
def __hash__(self) -> int:
18+
return hash(("Capsule", self.name))
19+
20+
21+
class SourceDep:
22+
"""Defines a C source file that a primitive may require.
23+
24+
Each source file must also have a corresponding .h file (replace .c with .h)
25+
that gets implicitly #included if the source is used.
26+
"""
27+
28+
def __init__(self, path: str) -> None:
29+
# Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
30+
self.path: Final = path
31+
32+
def __repr__(self) -> str:
33+
return f"SourceDep(path={self.path!r})"
34+
35+
def __eq__(self, other: object) -> bool:
36+
return isinstance(other, SourceDep) and self.path == other.path
37+
38+
def __hash__(self) -> int:
39+
return hash(("SourceDep", self.path))
40+
41+
def get_header(self) -> str:
42+
"""Get the header file path by replacing .c with .h"""
43+
return self.path.replace(".c", ".h")
44+
45+
46+
Dependency = Capsule | SourceDep
47+
48+
49+
LIBRT_STRINGS: Final = Capsule("librt.strings")
50+
LIBRT_BASE64: Final = Capsule("librt.base64")
51+
52+
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")

mypyc/ir/module_ir.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from mypyc.common import JsonDict
66
from mypyc.ir.class_ir import ClassIR
7+
from mypyc.ir.deps import Capsule, Dependency, SourceDep
78
from mypyc.ir.func_ir import FuncDecl, FuncIR
89
from mypyc.ir.ops import DeserMaps
910
from mypyc.ir.rtypes import RType, deserialize_type
@@ -30,17 +31,25 @@ def __init__(
3031
# These are only visible in the module that defined them, so no need
3132
# to serialize.
3233
self.type_var_names = type_var_names
33-
# Capsules needed by the module, specified via module names such as "librt.base64"
34-
self.capsules: set[str] = set()
34+
# Dependencies needed by the module (such as capsules or source files)
35+
self.dependencies: set[Dependency] = set()
3536

3637
def serialize(self) -> JsonDict:
38+
# Serialize dependencies as a list of dicts with type information
39+
serialized_deps = []
40+
for dep in sorted(self.dependencies, key=lambda d: (type(d).__name__, str(d))):
41+
if isinstance(dep, Capsule):
42+
serialized_deps.append({"type": "Capsule", "name": dep.name})
43+
elif isinstance(dep, SourceDep):
44+
serialized_deps.append({"type": "SourceDep", "path": dep.path})
45+
3746
return {
3847
"fullname": self.fullname,
3948
"imports": self.imports,
4049
"functions": [f.serialize() for f in self.functions],
4150
"classes": [c.serialize() for c in self.classes],
4251
"final_names": [(k, t.serialize()) for k, t in self.final_names],
43-
"capsules": sorted(self.capsules),
52+
"dependencies": serialized_deps,
4453
}
4554

4655
@classmethod
@@ -53,7 +62,16 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
5362
[(k, deserialize_type(t, ctx)) for k, t in data["final_names"]],
5463
[],
5564
)
56-
module.capsules = set(data["capsules"])
65+
66+
# Deserialize dependencies
67+
deps: set[Dependency] = set()
68+
for dep_dict in data["dependencies"]:
69+
if dep_dict["type"] == "Capsule":
70+
deps.add(Capsule(dep_dict["name"]))
71+
elif dep_dict["type"] == "SourceDep":
72+
deps.add(SourceDep(dep_dict["path"]))
73+
module.dependencies = deps
74+
5775
return module
5876

5977

0 commit comments

Comments
 (0)