Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 275 additions & 24 deletions toolshed/check_cython_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
"""
Tool to check for Cython ABI changes in a given package.

There are different types of ABI changes, only one of which is covered by this tool:
Cython must be installed in your venv to run this script.

There are different types of ABI changes, some of which are covered by this tool:

- cdef function signatures (capsule strings) — covered here
- cdef class struct size (tp_basicsize) — not covered
- cdef class vtable layout / method reordering — not covered, and this one fails as silent UB rather than an import-time error
- Fused specialization ordering — partially covered (reorders manifest as capsule-name deltas, but the mapping is non-obvious)
- cdef class struct size (tp_basicsize) — covered here
- cdef struct / ctypedef struct field layout — covered here (via .pxd parsing)
- cdef class vtable layout / method reordering — not covered, and this one fails
as silent UB rather than an import-time error
- Fused specialization ordering — partially covered (reorders manifest as
capsule-name deltas, but the mapping is non-obvious)

The workflow is basically:

Expand All @@ -21,22 +26,28 @@
package is installed), where `package_name` is the import path to the package,
e.g. `cuda.bindings`:

python check_cython_abi.py generate <package_name> <dir>
python check_cython_abi.py generate <package_name> <dir>

3) Checkout a version with the changes to be tested, and build and install.

4) Check the ABI against the previously generated files by running:

python check_cython_abi.py check <package_name> <dir>
python check_cython_abi.py check <package_name> <dir>
"""

import ctypes
import importlib
import json
import sys
import sysconfig
from io import StringIO
from pathlib import Path

from Cython.Compiler import Parsing
from Cython.Compiler.Scanning import FileSourceDescriptor, PyrexScanner
from Cython.Compiler.Symtab import ModuleScope
from Cython.Compiler.TreeFragment import StringParseContext

EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX")
ABI_SUFFIX = ".abi.json"

Expand Down Expand Up @@ -66,12 +77,12 @@ def import_from_path(root_package: str, root_dir: Path, path: Path) -> object:


def so_path_to_abi_path(so_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Translate a compiled extension path under *build_dir* into the
    corresponding ABI-record path under *abi_dir*.

    The relative directory is preserved; only the filename changes from
    ``<name><EXT_SUFFIX>`` to ``<name>.abi.json``.
    """
    relative_dir = so_path.parent.relative_to(build_dir)
    abi_name = short_stem(so_path.name) + ABI_SUFFIX
    return abi_dir / relative_dir / abi_name


def abi_path_to_so_path(abi_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Inverse of ``so_path_to_abi_path``: map an ABI-record path under
    *abi_dir* back to the compiled extension path under *build_dir*.
    """
    relative_dir = abi_path.parent.relative_to(abi_dir)
    so_name = short_stem(abi_path.name) + EXT_SUFFIX
    return build_dir / relative_dir / so_name


Expand All @@ -80,16 +91,244 @@ def is_cython_module(module: object) -> bool:
return hasattr(module, "__pyx_capi__")


def module_to_json(module: object) -> dict:
"""
Converts extracts information about a Cython-compiled .so into JSON-serializable information.
######################################################################################
# STRUCTS


def get_cdef_classes(module: object) -> dict:
    """Collect ``__basicsize__`` for each type defined in *module*.

    Only classes whose ``__module__`` matches the module's own name are
    recorded, which filters out types re-exported from elsewhere.
    Returns ``{class_name: {"basicsize": int}}`` in sorted-name order.
    """
    owning_module = module.__name__
    classes = {}
    for attr_name in sorted(dir(module)):
        candidate = getattr(module, attr_name, None)
        if not isinstance(candidate, type):
            continue
        if getattr(candidate, "__module__", None) != owning_module:
            continue
        if not hasattr(candidate, "__basicsize__"):
            continue
        classes[attr_name] = {"basicsize": candidate.__basicsize__}
    return classes


def _format_base_type_name(bt: object) -> str:
"""Format a Cython base type AST node into a type name string."""
cls = type(bt).__name__
if cls == "CSimpleBaseTypeNode":
return bt.name
if cls == "CComplexBaseTypeNode":
inner = _format_base_type_name(bt.base_type)
return _unwrap_declarator(inner, bt.declarator)[0]
return cls


def _unwrap_declarator(type_str: str, decl: object) -> tuple[str, str]:
"""Unwrap nested Cython declarator nodes to get (type_string, field_name)."""
cls = type(decl).__name__
if cls == "CNameDeclaratorNode":
return type_str, decl.name
if cls == "CPtrDeclaratorNode":
return _unwrap_declarator(f"{type_str}*", decl.base)
if cls == "CReferenceDeclaratorNode":
return _unwrap_declarator(f"{type_str}&", decl.base)
if cls == "CArrayDeclaratorNode":
dim = getattr(decl, "dimension", None)
size = getattr(dim, "value", "") if dim is not None else ""
return _unwrap_declarator(f"{type_str}[{size}]", decl.base)
return type_str, ""


def _extract_fields_from_cvardef(node: object) -> list:
    """Return ``[type, name]`` pairs declared by a single CVarDefNode."""
    # One CVarDefNode can declare several fields (``cdef int a, *b``);
    # the base type is shared, each declarator decorates it differently.
    base_type_name = _format_base_type_name(node.base_type)
    pairs = []
    for declarator in node.declarators:
        type_str, field_name = _unwrap_declarator(base_type_name, declarator)
        # Unrecognised/anonymous declarators yield "" and are skipped.
        if field_name:
            pairs.append([type_str, field_name])
    return pairs


def _collect_cvardef_fields(node: object) -> list:
    """Recursively gather CVarDefNode fields at or below *node*.

    Nested struct/class/function definitions are not descended into, so
    only attributes belonging to the current definition are reported.
    """
    skip_kinds = ("CStructOrUnionDefNode", "CClassDefNode", "CFuncDefNode")
    collected = []
    if type(node).__name__ == "CVarDefNode":
        collected += _extract_fields_from_cvardef(node)

    for attr_name in getattr(node, "child_attrs", []):
        value = getattr(node, attr_name, None)
        if value is None:
            continue
        # child_attrs entries may hold a single node or a list of nodes.
        children = value if isinstance(value, list) else [value]
        for child in children:
            if hasattr(child, "child_attrs") and type(child).__name__ not in skip_kinds:
                collected += _collect_cvardef_fields(child)
    return collected


def _collect_structs_from_tree(node: object) -> dict:
"""Walk a Cython AST and collect struct/class field definitions."""
result = {}
cls = type(node).__name__

if cls == "CStructOrUnionDefNode":
fields = []
for attr in node.attributes:
if type(attr).__name__ == "CVarDefNode":
fields.extend(_extract_fields_from_cvardef(attr))
if fields:
result[node.name] = {"fields": fields}

elif cls == "CClassDefNode":
fields = _collect_cvardef_fields(node.body)
if fields:
result[node.class_name] = {"fields": fields}

for attr_name in getattr(node, "child_attrs", []):
child = getattr(node, attr_name, None)
if child is None:
continue
if isinstance(child, list):
for item in child:
if hasattr(item, "child_attrs"):
result.update(_collect_structs_from_tree(item))
elif hasattr(child, "child_attrs"):
result.update(_collect_structs_from_tree(child))

return result


class _PxdParseContext(StringParseContext):
    """Parse context that never fails a module lookup.

    ``find_module`` is invoked for every cimport/include the parser
    encounters; answering each request with a fresh, empty ModuleScope
    keeps parsing going without resolving anything against a real
    compilation environment.  Fields are later extracted from the AST,
    not from resolved symbol tables, so dummy scopes are sufficient.
    """

    def find_module(
        self,
        module_name,
        from_module=None,  # noqa: ARG002
        pos=None,  # noqa: ARG002
        need_pxd=1,  # noqa: ARG002
        absolute_fallback=True,  # noqa: ARG002
        relative_import=False,  # noqa: ARG002
    ):
        # Ignore every hint argument and hand back a dummy scope.
        dummy_scope = ModuleScope(module_name, parent_module=None, context=self)
        return dummy_scope


def parse_pxd_structs(pxd_path: Path) -> dict:
    """Parse struct and cdef class field definitions from a .pxd file.

    Uses Cython's own parser (in .pxd mode) for reliable extraction.
    cimport lines in the top-level file are stripped since they are
    unresolvable without the full compilation context; included files
    are handled via a lenient context that returns dummy scopes.

    Returns a dict mapping struct/class name to {"fields": [[type, name], ...]}.
    """
    # NOTE: stray diff-residue lines referencing an undefined ``module``
    # were removed here; this function only reads *pxd_path*.
    text = pxd_path.read_text(encoding="utf-8")

    # Blank out cimport lines rather than deleting them so the scanner's
    # reported line numbers still match the file on disk.
    lines = text.splitlines()
    cleaned = "\n".join("" if (" cimport " in ln or ln.lstrip().startswith("cimport ")) else ln for ln in lines)

    name = pxd_path.stem
    context = _PxdParseContext(name, include_directories=[str(pxd_path.parent)])
    code_source = FileSourceDescriptor(str(pxd_path))
    scope = context.find_module(name, pos=(code_source, 1, 0), need_pxd=False)

    scanner = PyrexScanner(
        StringIO(cleaned),
        code_source,
        source_encoding="UTF-8",
        scope=scope,
        context=context,
        initial_pos=(code_source, 1, 0),
    )
    # pxd=1 puts the parser in declaration-file mode (no executable bodies).
    tree = Parsing.p_module(scanner, pxd=1, full_module_name=name)
    tree.scope = scope

    return _collect_structs_from_tree(tree)


def get_structs(module: object) -> dict:
    """Collect struct/class ABI information for *module*.

    Combines ``__basicsize__`` of extension types (read directly from the
    compiled module) with field layouts parsed from a .pxd file sitting
    next to the .so, when one exists.

    Returns a name-sorted dict: {name: {"basicsize": ..., "fields": ...}}.
    """
    structs = get_cdef_classes(module)

    # The original ``if so_path is not None`` guard was dead (Path(...)
    # never returns None); the real hazard is a module without __file__
    # (e.g. namespace packages), so guard the attribute access instead.
    module_file = getattr(module, "__file__", None)
    if module_file is not None:
        so_path = Path(module_file)
        pxd_path = so_path.parent / f"{short_stem(so_path.name)}.pxd"
        if pxd_path.is_file():
            for name, info in parse_pxd_structs(pxd_path).items():
                # Merge field info into existing basicsize entries;
                # create a fresh entry for pxd-only names.
                structs.setdefault(name, {}).update(info)

    # Sorted keys keep diffs of the generated JSON small.
    return dict(sorted(structs.items()))


def _report_field_changes(name: str, expected_fields: list, found_fields: list) -> None:
"""Print detailed field-level differences for a struct."""
expected_dict = {f[1]: f[0] for f in expected_fields}
found_dict = {f[1]: f[0] for f in found_fields}

for field_name, field_type in expected_dict.items():
if field_name not in found_dict:
print(f" Struct {name}: removed field '{field_name}'")
elif found_dict[field_name] != field_type:
print(
f" Struct {name}: field '{field_name}' type changed from '{field_type}' to '{found_dict[field_name]}'"
)
for field_name in found_dict:
if field_name not in expected_dict:
print(f" Struct {name}: added field '{field_name}'")

expected_common = [f[1] for f in expected_fields if f[1] in found_dict]
found_common = [f[1] for f in found_fields if f[1] in expected_dict]
if expected_common != found_common:
print(f" Struct {name}: fields were reordered")


def check_structs(expected: dict, found: dict) -> tuple[bool, bool]:
    """Compare recorded struct/class ABI info against the current build.

    *expected*/*found* map struct names to info dicts that may carry
    "basicsize" and/or "fields" keys.  Removals, size changes and layout
    changes are errors; newly added structs are allowed changes.

    Returns (has_errors, has_allowed_changes).

    NOTE: stray diff-residue lines (a fragment of the old module_to_json
    return statement) were removed from the middle of this function.
    """
    has_errors = False
    has_allowed_changes = False

    for name, expected_info in expected.items():
        if name not in found:
            print(f" Missing struct/class: {name}")
            has_errors = True
            continue
        found_info = found[name]

        if "basicsize" in expected_info:
            if "basicsize" not in found_info:
                print(f" Struct {name}: basicsize no longer available")
                has_errors = True
            elif found_info["basicsize"] != expected_info["basicsize"]:
                print(
                    f" Struct {name}: basicsize changed from {expected_info['basicsize']} to {found_info['basicsize']}"
                )
                has_errors = True

        if "fields" in expected_info:
            if "fields" not in found_info:
                print(f" Struct {name}: field information no longer available")
                has_errors = True
            elif found_info["fields"] != expected_info["fields"]:
                _report_field_changes(name, expected_info["fields"], found_info["fields"])
                has_errors = True

    # Structs that only exist in the new build are additive and allowed.
    for name in found:
        if name not in expected:
            print(f" Added struct/class: {name}")
            has_allowed_changes = True

    return has_errors, has_allowed_changes


######################################################################################
# FUNCTIONS


def get_functions(module: object) -> dict:
    """Map each ``__pyx_capi__`` entry name to its capsule signature string.

    Keys are emitted in sorted order so the generated JSON diffs cleanly.
    """
    capi = module.__pyx_capi__
    return {entry: get_capsule_name(capi[entry]) for entry in sorted(capi)}


def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bool, bool]:
Expand All @@ -109,17 +348,29 @@ def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bo
return has_errors, has_allowed_changes


######################################################################################
# MAIN


def compare(expected: dict, found: dict) -> tuple[bool, bool]:
    """Run every section checker and fold their results together.

    Returns (has_errors, has_allowed_changes), each True if any section
    reported the corresponding condition.
    """
    section_checkers = {"functions": check_functions, "structs": check_structs}

    has_errors = False
    has_allowed_changes = False
    for section, checker in section_checkers.items():
        errors, allowed = checker(expected[section], found[section])
        has_errors = has_errors or errors
        has_allowed_changes = has_allowed_changes or allowed

    return has_errors, has_allowed_changes


def module_to_json(module: object) -> dict:
    """Summarize a Cython-compiled module as a JSON-serializable dict.

    The dict has a "functions" section (capsule signatures) and a
    "structs" section (class sizes and field layouts).
    """
    return {
        "functions": get_functions(module),
        "structs": get_structs(module),
    }


def check(package: str, abi_dir: Path) -> bool:
build_dir = get_package_path(package)

Expand Down Expand Up @@ -168,7 +419,7 @@ def check(package: str, abi_dir: Path) -> bool:
return False


def regenerate(package: str, abi_dir: Path) -> bool:
def generate(package: str, abi_dir: Path) -> bool:
if abi_dir.is_dir():
print(f"ABI directory {abi_dir} already exists. Please remove it before regenerating.")
return True
Expand Down Expand Up @@ -199,10 +450,10 @@ def regenerate(package: str, abi_dir: Path) -> bool:

subparsers = parser.add_subparsers()

regen_parser = subparsers.add_parser("generate", help="Regenerate the ABI files")
regen_parser.set_defaults(func=regenerate)
regen_parser.add_argument("package", help="Python package to collect data from")
regen_parser.add_argument("dir", help="Output directory to save data to")
gen_parser = subparsers.add_parser("generate", help="Regenerate the ABI files")
gen_parser.set_defaults(func=generate)
gen_parser.add_argument("package", help="Python package to collect data from")
gen_parser.add_argument("dir", help="Output directory to save data to")

check_parser = subparsers.add_parser("check", help="Check the API against existing ABI files")
check_parser.set_defaults(func=check)
Expand Down
Loading