Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,14 @@ def __init__(

self.visitor = visitor

# Class body context: tracks ClassVar names defined so far when processing
# a class body, so that intra-class references (e.g. C = A | B where A is
# a ClassVar defined earlier in the same class) can be resolved correctly.
# Without this, mypyc looks up such names in module globals, which fails.
self.class_body_classvars: dict[str, None] = {}
self.class_body_obj: Value | None = None
self.class_body_is_ext: bool = False

# This list operates similarly to a function call stack for nested functions. Whenever a
# function definition begins to be generated, a FuncInfo instance is added to the stack,
# and information about that function (e.g. whether it is nested, its environment class to
Expand Down
32 changes: 32 additions & 0 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
else:
cls_builder = NonExtClassBuilder(builder, cdef)

# Set up class body context so that intra-class ClassVar references
# (e.g. C = A | B where A is defined earlier in the same class) can be
# resolved from the class being built instead of module globals.
saved_classvars = builder.class_body_classvars
saved_obj = builder.class_body_obj
saved_is_ext = builder.class_body_is_ext
builder.class_body_classvars = {}
builder.class_body_obj = cls_builder.class_body_obj()
builder.class_body_is_ext = ir.is_ext_class

for stmt in cdef.defs.body:
if (
isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef))
Expand Down Expand Up @@ -179,13 +189,21 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
# We want to collect class variables in a dictionary for both real
# non-extension classes and fake dataclass ones.
cls_builder.add_attr(lvalue, stmt)
# Track this ClassVar so subsequent class body statements can reference it.
if is_class_var(lvalue) or stmt.is_final_def:
builder.class_body_classvars[lvalue.name] = None

elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
# Docstring. Ignore
pass
else:
builder.error("Unsupported statement in class body", stmt.line)

# Restore previous class body context (handles nested classes).
builder.class_body_classvars = saved_classvars
builder.class_body_obj = saved_obj
builder.class_body_is_ext = saved_is_ext

# Generate implicit property setters/getters
for name, decl in ir.method_decls.items():
if decl.implicit and decl.is_prop_getter:
Expand Down Expand Up @@ -232,12 +250,23 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
def finalize(self, ir: ClassIR) -> None:
"""Perform any final operations to complete the class IR"""

def class_body_obj(self) -> Value | None:
"""Return the object to use for loading class attributes during class body init.

For extension classes, this is the type object. For non-extension classes,
this is the class dict. Returns None if not applicable.
"""
return None


class NonExtClassBuilder(ClassBuilder):
def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
super().__init__(builder, cdef)
self.non_ext = self.create_non_ext_info()

def class_body_obj(self) -> Value | None:
return self.non_ext.dict

def create_non_ext_info(self) -> NonExtClassInfo:
non_ext_bases = populate_non_ext_bases(self.builder, self.cdef)
non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases)
Expand Down Expand Up @@ -293,6 +322,9 @@ def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
# If the class is not decorated, generate an extension class for it.
self.type_obj: Value = allocate_class(builder, cdef)

def class_body_obj(self) -> Value | None:
return self.type_obj

def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
"""Controls whether to skip generating a default for an attribute."""
return False
Expand Down
11 changes: 11 additions & 0 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
else:
return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line)

# If we're evaluating a class body and this name is a ClassVar defined earlier
# in the same class, load it from the class being built (type object for ext classes,
# class dict for non-ext classes) instead of module globals.
if builder.class_body_obj is not None and expr.name in builder.class_body_classvars:
if builder.class_body_is_ext:
return builder.py_get_attr(builder.class_body_obj, expr.name, expr.line)
else:
return builder.primitive_op(
dict_get_item_op, [builder.class_body_obj, builder.load_str(expr.name)], expr.line
)

return builder.load_global(expr)


Expand Down
76 changes: 76 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -5774,3 +5774,79 @@ from native import Concrete
c = Concrete()
assert c.value() == 42
assert c.derived() == 42

[case testClassVarSelfReference]
# ClassVar initializers that reference other ClassVars from the same class.
# In CPython, the class body executes as a function where earlier assignments
# are available to later ones. mypyc must replicate this by loading from the
# class being built (type object for ext classes, class dict for non-ext)
# instead of module globals.
from typing import ClassVar, Dict, Set

class Ext:
A: ClassVar[Set[int]] = {1, 2, 3}
B: ClassVar[Set[int]] = {4, 5, 6}
C: ClassVar[Set[int]] = A | B

class ExtChained:
X: ClassVar[Set[int]] = {1, 2}
Y: ClassVar[Set[int]] = X | {3}
Z: ClassVar[Set[int]] = Y | {4}

class ExtDict:
BASE: ClassVar[Dict[str, int]] = {"a": 1, "b": 2}
EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "c": 3}

class ExtSub(Ext):
E: ClassVar[Set[int]] = {7, 8}

[file driver.py]
from native import Ext, ExtChained, ExtDict, ExtSub

assert Ext.A == {1, 2, 3}
assert Ext.B == {4, 5, 6}
assert Ext.C == {1, 2, 3, 4, 5, 6}

assert ExtChained.X == {1, 2}
assert ExtChained.Y == {1, 2, 3}
assert ExtChained.Z == {1, 2, 3, 4}

assert ExtDict.BASE == {"a": 1, "b": 2}
assert ExtDict.EXTENDED == {"a": 1, "b": 2, "c": 3}

assert ExtSub.C == {1, 2, 3, 4, 5, 6}
assert ExtSub.E == {7, 8}

[case testClassVarSelfReferenceNonExt]
# Same as testClassVarSelfReference but for non-extension classes.
from typing import ClassVar, Dict, Set
from mypy_extensions import mypyc_attr

@mypyc_attr(allow_interpreted_subclasses=True)
class NonExt:
A: ClassVar[Set[str]] = {"a", "b"}
B: ClassVar[Set[str]] = {"c"}
C: ClassVar[Set[str]] = A | B

@mypyc_attr(allow_interpreted_subclasses=True)
class NonExtDict:
BASE: ClassVar[Dict[str, int]] = {"x": 1}
EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "y": 2}

@mypyc_attr(allow_interpreted_subclasses=True)
class NonExtChained:
X: ClassVar[Set[int]] = {10}
Y: ClassVar[Set[int]] = X | {20}
Z: ClassVar[Set[int]] = Y | {30}

[file driver.py]
from native import NonExt, NonExtDict, NonExtChained

assert NonExt.A == {"a", "b"}
assert NonExt.B == {"c"}
assert NonExt.C == {"a", "b", "c"}

assert NonExtDict.BASE == {"x": 1}
assert NonExtDict.EXTENDED == {"x": 1, "y": 2}

assert NonExtChained.Z == {10, 20, 30}
Loading