diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 6018634b4c48..eeea2f39f849 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -233,6 +233,14 @@ def __init__( self.visitor = visitor + # Class body context: tracks ClassVar names defined so far when processing + # a class body, so that intra-class references (e.g. C = A | B where A is + # a ClassVar defined earlier in the same class) can be resolved correctly. + # Without this, mypyc looks up such names in module globals, which fails. + self.class_body_classvars: dict[str, None] = {} + self.class_body_obj: Value | None = None + self.class_body_ir: ClassIR | None = None + # This list operates similarly to a function call stack for nested functions. Whenever a # function definition begins to be generated, a FuncInfo instance is added to the stack, # and information about that function (e.g. whether it is nested, its environment class to diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 03b24cefb710..8b6a83a44950 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -137,6 +137,13 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: else: cls_builder = NonExtClassBuilder(builder, cdef) + # Set up class body context so that intra-class ClassVar references + # (e.g. C = A | B where A is defined earlier in the same class) can be + # resolved from the class being built instead of module globals. + builder.class_body_classvars = {} + builder.class_body_obj = cls_builder.class_body_obj() + builder.class_body_ir = ir + for stmt in cdef.defs.body: if ( isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef)) @@ -179,6 +186,9 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: # We want to collect class variables in a dictionary for both real # non-extension classes and fake dataclass ones. cls_builder.add_attr(lvalue, stmt) + # Track this ClassVar so subsequent class body statements can reference it. + if is_class_var(lvalue) or stmt.is_final_def: + builder.class_body_classvars[lvalue.name] = None elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr): # Docstring. Ignore @@ -186,6 +196,11 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: else: builder.error("Unsupported statement in class body", stmt.line) + # Clear class body context (nested classes are rejected above, so no need to save/restore). + builder.class_body_classvars = {} + builder.class_body_obj = None + builder.class_body_ir = None + # Generate implicit property setters/getters for name, decl in ir.method_decls.items(): if decl.implicit and decl.is_prop_getter: @@ -232,12 +247,23 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: def finalize(self, ir: ClassIR) -> None: """Perform any final operations to complete the class IR""" + def class_body_obj(self) -> Value | None: + """Return the object to use for loading class attributes during class body init. + + For extension classes, this is the type object. For non-extension classes, + this is the class dict. Returns None if not applicable. + """ + return None + class NonExtClassBuilder(ClassBuilder): def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: super().__init__(builder, cdef) self.non_ext = self.create_non_ext_info() + def class_body_obj(self) -> Value | None: + return self.non_ext.dict + def create_non_ext_info(self) -> NonExtClassInfo: non_ext_bases = populate_non_ext_bases(self.builder, self.cdef) non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases) @@ -293,6 +319,9 @@ def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: # If the class is not decorated, generate an extension class for it. self.type_obj: Value = allocate_class(builder, cdef) + def class_body_obj(self) -> Value | None: + return self.type_obj + def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool: """Controls whether to skip generating a default for an attribute.""" return False diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index f45319b7ef2b..f1591ff9069d 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -213,6 +213,17 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: else: return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line) + # If we're evaluating a class body and this name is a ClassVar defined earlier + # in the same class, load it from the class being built (type object for ext classes, + # class dict for non-ext classes) instead of module globals. + if builder.class_body_obj is not None and expr.name in builder.class_body_classvars: + if builder.class_body_ir is not None and builder.class_body_ir.is_ext_class: + return builder.py_get_attr(builder.class_body_obj, expr.name, expr.line) + else: + return builder.primitive_op( + dict_get_item_op, [builder.class_body_obj, builder.load_str(expr.name)], expr.line + ) + return builder.load_global(expr) diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index c515d057f2a0..22b8d888c168 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -2901,3 +2901,300 @@ class InvalidKwarg: @mypyc_attr(str()) # E: All "mypyc_attr" positional arguments must be string literals. class InvalidLiteral: pass + +[case testClassVarSelfReferenceExt_withgil_toplevel] +from typing import ClassVar, Set + +class Ext: + A: ClassVar[Set[int]] = {1, 2} + B: ClassVar[Set[int]] = A | {3} +[out] +def __top_level__(): + r0, r1 :: object + r2 :: bit + r3 :: str + r4, r5 :: object + r6 :: str + r7 :: dict + r8, r9 :: object + r10 :: str + r11, r12 :: object + r13, r14 :: bool + r15 :: str + r16 :: tuple + r17 :: i32 + r18 :: bit + r19 :: dict + r20 :: str + r21 :: i32 + r22 :: bit + r23 :: object + r24 :: set + r25 :: object + r26 :: i32 + r27 :: bit + r28 :: object + r29 :: i32 + r30 :: bit + r31 :: str + r32 :: i32 + r33 :: bit + r34 :: object + r35 :: str + r36 :: object + r37, r38 :: set + r39 :: object + r40 :: i32 + r41 :: bit + r42 :: object + r43 :: set + r44 :: str + r45 :: i32 + r46 :: bit + r47 :: bool +L0: + r0 = builtins :: module + r1 = load_address _Py_NoneStruct + r2 = r0 != r1 + if r2 goto L2 else goto L1 :: bool +L1: + r3 = 'builtins' + r4 = PyImport_Import(r3) + builtins = r4 :: module +L2: + r5 = ('ClassVar', 'Set') + r6 = 'typing' + r7 = __main__.globals :: static + r8 = CPyImport_ImportFromMany(r6, r5, r5, r7) + typing = r8 :: module + r9 = :: object + r10 = '__main__' + r11 = __main__.Ext_template :: type + r12 = CPyType_FromTemplate(r11, r9, r10) + r13 = Ext_trait_vtable_setup() + r14 = Ext_coroutine_setup(r12) + r15 = '__mypyc_attrs__' + r16 = CPyTuple_LoadEmptyTupleConstant() + r17 = PyObject_SetAttr(r12, r15, r16) + r18 = r17 >= 0 :: signed + __main__.Ext = r12 :: type + r19 = __main__.globals :: static + r20 = 'Ext' + r21 = PyDict_SetItem(r19, r20, r12) + r22 = r21 >= 0 :: signed + r23 = __main__.Ext :: type + r24 = PySet_New(0) + r25 = object 1 + r26 = PySet_Add(r24, r25) + r27 = r26 >= 0 :: signed + r28 = object 2 + r29 = PySet_Add(r24, r28) + r30 = r29 >= 0 :: signed + r31 = 'A' + r32 = PyObject_SetAttr(r23, r31, r24) + r33 = r32 >= 0 :: signed + r34 = __main__.Ext :: type + r35 = 'A' + r36 = CPyObject_GetAttr(r12, r35) + r37 = cast(set, r36) + r38 = PySet_New(0) + r39 = object 3 + r40 = PySet_Add(r38, r39) + r41 = r40 >= 0 :: signed + r42 = PyNumber_Or(r37, r38) + r43 = cast(set, r42) + r44 = 'B' + r45 = PyObject_SetAttr(r34, r44, r43) + r46 = r45 >= 0 :: signed + r47 = CPy_InitSubclass(r12) + return 1 + +[case testClassVarSelfReferenceNonExt_withgil_toplevel] +from typing import ClassVar, Set +from mypy_extensions import mypyc_attr + +@mypyc_attr(native_class=False) +class NonExt: + A: ClassVar[Set[str]] = {"a"} + B: ClassVar[Set[str]] = A | {"b"} +[out] +def __top_level__(): + r0, r1 :: object + r2 :: bit + r3 :: str + r4, r5 :: object + r6 :: str + r7 :: dict + r8, r9 :: object + r10 :: str + r11 :: dict + r12 :: object + r13 :: tuple + r14, r15 :: object + r16 :: str + r17 :: bool + r18, r19 :: str + r20 :: object + r21 :: object[2] + r22 :: object_ptr + r23 :: object + r24, r25, r26, r27 :: dict + r28 :: object + r29 :: str + r30 :: i32 + r31 :: bit + r32 :: str + r33 :: set + r34 :: i32 + r35 :: bit + r36 :: str + r37 :: i32 + r38 :: bit + r39 :: object + r40 :: str + r41 :: i32 + r42 :: bit + r43 :: str + r44 :: object + r45 :: set + r46 :: str + r47 :: set + r48 :: i32 + r49 :: bit + r50 :: object + r51 :: set + r52 :: str + r53 :: i32 + r54 :: bit + r55, r56 :: str + r57 :: i32 + r58 :: bit + r59, r60 :: str + r61 :: i32 + r62 :: bit + r63, r64 :: str + r65 :: i32 + r66 :: bit + r67 :: object[3] + r68 :: object_ptr + r69 :: object + r70 :: dict + r71 :: str + r72, r73 :: object + r74 :: object[1] + r75 :: object_ptr + r76, r77 :: object + r78 :: object[1] + r79 :: object_ptr + r80 :: object + r81 :: dict + r82 :: str + r83 :: i32 + r84 :: bit + r85 :: object +L0: + r0 = builtins :: module + r1 = load_address _Py_NoneStruct + r2 = r0 != r1 + if r2 goto L2 else goto L1 :: bool +L1: + r3 = 'builtins' + r4 = PyImport_Import(r3) + builtins = r4 :: module +L2: + r5 = ('ClassVar', 'Set') + r6 = 'typing' + r7 = __main__.globals :: static + r8 = CPyImport_ImportFromMany(r6, r5, r5, r7) + typing = r8 :: module + r9 = ('mypyc_attr',) + r10 = 'mypy_extensions' + r11 = __main__.globals :: static + r12 = CPyImport_ImportFromMany(r10, r9, r9, r11) + mypy_extensions = r12 :: module + r13 = CPyTuple_LoadEmptyTupleConstant() + r14 = load_address PyType_Type + r15 = CPy_CalculateMetaclass(r14, r13) + r16 = '__prepare__' + r17 = PyObject_HasAttr(r15, r16) + if r17 goto L3 else goto L4 :: bool +L3: + r18 = 'NonExt' + r19 = '__prepare__' + r20 = CPyObject_GetAttr(r15, r19) + r21 = [r18, r13] + r22 = load_address r21 + r23 = PyObject_Vectorcall(r20, r22, 2, 0) + keep_alive r18, r13 + r24 = cast(dict, r23) + r25 = r24 + goto L5 +L4: + r26 = PyDict_New() + r25 = r26 +L5: + r27 = PyDict_New() + r28 = load_address PySet_Type + r29 = 'A' + r30 = PyDict_SetItem(r27, r29, r28) + r31 = r30 >= 0 :: signed + r32 = 'a' + keep_alive r14, r13 + r33 = PySet_New(0) + r34 = PySet_Add(r33, r32) + r35 = r34 >= 0 :: signed + r36 = 'A' + r37 = CPyDict_SetItem(r25, r36, r33) + r38 = r37 >= 0 :: signed + r39 = load_address PySet_Type + r40 = 'B' + r41 = PyDict_SetItem(r27, r40, r39) + r42 = r41 >= 0 :: signed + r43 = 'A' + r44 = CPyDict_GetItem(r25, r43) + r45 = cast(set, r44) + r46 = 'b' + r47 = PySet_New(0) + r48 = PySet_Add(r47, r46) + r49 = r48 >= 0 :: signed + r50 = PyNumber_Or(r45, r47) + r51 = cast(set, r50) + r52 = 'B' + r53 = CPyDict_SetItem(r25, r52, r51) + r54 = r53 >= 0 :: signed + r55 = 'NonExt' + r56 = '__annotations__' + r57 = CPyDict_SetItem(r25, r56, r27) + r58 = r57 >= 0 :: signed + r59 = 'mypyc filler docstring' + r60 = '__doc__' + r61 = CPyDict_SetItem(r25, r60, r59) + r62 = r61 >= 0 :: signed + r63 = '__main__' + r64 = '__module__' + r65 = CPyDict_SetItem(r25, r64, r63) + r66 = r65 >= 0 :: signed + r67 = [r55, r13, r25] + r68 = load_address r67 + r69 = PyObject_Vectorcall(r15, r68, 3, 0) + keep_alive r55, r13, r25 + r70 = __main__.globals :: static + r71 = 'mypyc_attr' + r72 = CPyDict_GetItem(r70, r71) + r73 = box(bool, 0) + r74 = [r73] + r75 = load_address r74 + r76 = ('native_class',) + r77 = PyObject_Vectorcall(r72, r75, 0, r76) + keep_alive r73 + r78 = [r69] + r79 = load_address r78 + r80 = PyObject_Vectorcall(r77, r79, 1, 0) + keep_alive r69 + __main__.NonExt = r80 :: type + r81 = __main__.globals :: static + r82 = 'NonExt' + r83 = PyDict_SetItem(r81, r82, r80) + r84 = r83 >= 0 :: signed + r85 = __main__.NonExt :: type + return 1 diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 0c1557b9eed9..6a27ffa0ebd6 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -5814,3 +5814,79 @@ from native import Base Sub = type("Sub", (Base,), {}) s = Sub() assert s.method() == "a1" + +[case testClassVarSelfReference] +# ClassVar initializers that reference other ClassVars from the same class. +# In CPython, the class body executes as a function where earlier assignments +# are available to later ones. mypyc must replicate this by loading from the +# class being built (type object for ext classes, class dict for non-ext) +# instead of module globals. +from typing import ClassVar, Dict, Set + +class Ext: + A: ClassVar[Set[int]] = {1, 2, 3} + B: ClassVar[Set[int]] = {4, 5, 6} + C: ClassVar[Set[int]] = A | B + +class ExtChained: + X: ClassVar[Set[int]] = {1, 2} + Y: ClassVar[Set[int]] = X | {3} + Z: ClassVar[Set[int]] = Y | {4} + +class ExtDict: + BASE: ClassVar[Dict[str, int]] = {"a": 1, "b": 2} + EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "c": 3} + +class ExtSub(Ext): + E: ClassVar[Set[int]] = {7, 8} + +[file driver.py] +from native import Ext, ExtChained, ExtDict, ExtSub + +assert Ext.A == {1, 2, 3} +assert Ext.B == {4, 5, 6} +assert Ext.C == {1, 2, 3, 4, 5, 6} + +assert ExtChained.X == {1, 2} +assert ExtChained.Y == {1, 2, 3} +assert ExtChained.Z == {1, 2, 3, 4} + +assert ExtDict.BASE == {"a": 1, "b": 2} +assert ExtDict.EXTENDED == {"a": 1, "b": 2, "c": 3} + +assert ExtSub.C == {1, 2, 3, 4, 5, 6} +assert ExtSub.E == {7, 8} + +[case testClassVarSelfReferenceNonExt] +# Same as testClassVarSelfReference but for non-extension classes. +from typing import ClassVar, Dict, Set +from mypy_extensions import mypyc_attr + +@mypyc_attr(native_class=False) +class NonExt: + A: ClassVar[Set[str]] = {"a", "b"} + B: ClassVar[Set[str]] = {"c"} + C: ClassVar[Set[str]] = A | B + +@mypyc_attr(native_class=False) +class NonExtDict: + BASE: ClassVar[Dict[str, int]] = {"x": 1} + EXTENDED: ClassVar[Dict[str, int]] = {**BASE, "y": 2} + +@mypyc_attr(native_class=False) +class NonExtChained: + X: ClassVar[Set[int]] = {10} + Y: ClassVar[Set[int]] = X | {20} + Z: ClassVar[Set[int]] = Y | {30} + +[file driver.py] +from native import NonExt, NonExtDict, NonExtChained + +assert NonExt.A == {"a", "b"} +assert NonExt.B == {"c"} +assert NonExt.C == {"a", "b", "c"} + +assert NonExtDict.BASE == {"x": 1} +assert NonExtDict.EXTENDED == {"x": 1, "y": 2} + +assert NonExtChained.Z == {10, 20, 30}