diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py
index 2ee09115..760cd123 100644
--- a/src/fromager/resolver.py
+++ b/src/fromager/resolver.py
@@ -10,6 +10,7 @@
 import logging
 import os
 import re
+import threading
 import typing
 from collections.abc import Iterable
 from operator import attrgetter
@@ -422,7 +423,21 @@ def get_project_from_pypi(
 
 
 class BaseProvider(ExtrasProvider):
+    """Base class for Fromager's dependency resolver (resolvelib + extras).
+
+    Subclasses implement ``find_candidates``, ``cache_key``, and
+    ``provider_description`` to list versions from PyPI, a version map, etc.
+
+    Candidate lists are cached per package in one global dict, with a lock per
+    package so parallel work on different packages does not clash.
+
+    ``find_matches`` keeps only versions that fit the requirements and
+    constraints, then picks newest first.
+    """
+
     resolver_cache: typing.ClassVar[ResolverCache] = {}
+    _cache_locks: typing.ClassVar[dict[str, threading.Lock]] = {}
+    _meta_lock: typing.ClassVar[threading.Lock] = threading.Lock()
     provider_description: typing.ClassVar[str]
 
     def __init__(
@@ -465,16 +480,20 @@ def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
 
     @classmethod
     def clear_cache(cls, identifier: str | None = None) -> None:
-        """Clear global resolver cache
+        """Clear global resolver cache and associated per-identifier locks.
 
         ``None`` clears all caches, an ``identifier`` string clears the
         cache for an identifier. Raises :exc:`KeyError` for unknown
         identifiers.
""" - if identifier is None: - cls.resolver_cache.clear() - else: - cls.resolver_cache.pop(canonicalize_name(identifier)) + with cls._meta_lock: + if identifier is None: + cls.resolver_cache.clear() + cls._cache_locks.clear() + else: + canon_name = canonicalize_name(identifier) + cls.resolver_cache.pop(canon_name) + cls._cache_locks.pop(canon_name, None) def get_extras_for( self, @@ -552,46 +571,85 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]: # return candidate.dependencies return [] - def _get_cached_candidates(self, identifier: str) -> list[Candidate]: - """Get list of cached candidates for identifier and provider + def _get_identifier_lock(self, identifier: str) -> threading.Lock: + """Get or create a per-identifier lock for thread-safe cache access. + + Uses a short-lived meta-lock to protect the lock dict itself. + The per-identifier lock ensures threads resolving different packages + proceed concurrently, while threads resolving the same package + wait for the first to populate the cache. + + The identifier is canonicalized to match the cache keys used by + ``_get_cached_candidates`` and ``clear_cache``. + """ + canonical = canonicalize_name(identifier) + with self._meta_lock: + if canonical not in self._cache_locks: + self._cache_locks[canonical] = threading.Lock() + return self._cache_locks[canonical] + + def _get_cached_candidates(self, identifier: str) -> list[Candidate] | None: + """Get a copy of cached candidates for identifier and provider. + + Returns None if no entry exists in the cache, or a copy of the cached + list (which may be empty). A copy is returned so callers cannot + accidentally corrupt the cache. + + Must be called under the per-identifier lock from _get_identifier_lock. 
+ """ + cls = type(self) + provider_cache = cls.resolver_cache.get(identifier, {}) + candidate_cache = provider_cache.get((cls, self.cache_key)) + if candidate_cache is None: + return None + return list(candidate_cache) + + def _set_cached_candidates( + self, identifier: str, candidates: list[Candidate] + ) -> None: + """Store candidates in the cache for identifier and provider. - The method always returns a list. If the cache did not have an entry - before, a new empty list is stored in the cache and returned to the - caller. The caller can mutate the list in place to update the cache. + Must be called under the per-identifier lock from _get_identifier_lock. """ cls = type(self) provider_cache = cls.resolver_cache.setdefault(identifier, {}) - candidate_cache = provider_cache.setdefault((cls, self.cache_key), []) - return candidate_cache + provider_cache[(cls, self.cache_key)] = list(candidates) def _find_cached_candidates(self, identifier: str) -> Candidates: - """Find candidates with caching""" - cached_candidates: list[Candidate] = [] - if self.use_cache_candidates: + """Find candidates with caching. + + Uses a per-identifier lock so threads resolving different packages + proceed concurrently, while threads resolving the same package + wait for the first to populate the cache. 
+ """ + if not self.use_cache_candidates: + candidates = list(self.find_candidates(identifier)) + logger.debug( + "%s: got %i unfiltered candidates, ignoring cache", + identifier, + len(candidates), + ) + return candidates + + lock = self._get_identifier_lock(identifier) + with lock: cached_candidates = self._get_cached_candidates(identifier) - if cached_candidates: + if cached_candidates is not None: logger.debug( "%s: use %i cached candidates", identifier, len(cached_candidates), ) return cached_candidates - candidates = list(self.find_candidates(identifier)) - if self.use_cache_candidates: - # mutate list object in-place - cached_candidates[:] = candidates + + candidates = list(self.find_candidates(identifier)) + self._set_cached_candidates(identifier, candidates) logger.debug( "%s: cache %i unfiltered candidates", identifier, len(candidates), ) - else: - logger.debug( - "%s: got %i unfiltered candidates, ignoring cache", - identifier, - len(candidates), - ) - return candidates + return candidates def _get_no_match_error_message( self, identifier: str, requirements: RequirementsMap diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 690a1288..6486721f 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,5 +1,7 @@ import datetime import re +import threading +import time import typing import pytest @@ -11,6 +13,7 @@ from fromager import constraints, resolver from fromager.__main__ import main as fromager +from fromager.candidate import Candidate _hydra_core_simple_response = """ @@ -58,7 +61,9 @@ @pytest.fixture(autouse=True) -def reset_cache() -> None: +def reset_cache() -> typing.Generator[None, None, None]: + resolver.BaseProvider.clear_cache() + yield resolver.BaseProvider.clear_cache() @@ -143,8 +148,10 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None: # fill the cache provider = pypi_hydra_resolver.provider assert provider.cache_key == "https://pypi.org/simple/" - req_cache = 
provider._get_cached_candidates(req.name) - assert req_cache == [] + lock = provider._get_identifier_lock(req.name) + with lock: + req_cache = provider._get_cached_candidates(req.name) + assert req_cache is None result = pypi_hydra_resolver.resolve([req]) candidate = result.mapping[req.name] @@ -153,10 +160,9 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None: resolver_cache = resolver.BaseProvider.resolver_cache assert req.name in resolver_cache assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name] - # mutated in place - assert provider._get_cached_candidates(req.name) is req_cache - assert len(provider._get_cached_candidates(req.name)) == 7 - assert len(req_cache) == 7 + # _get_cached_candidates returns a defensive copy, not the same object + with lock: + assert len(provider._get_cached_candidates(req.name)) == 7 def test_provider_cache_key_gitlab(gitlab_decile_resolver: typing.Any) -> None: @@ -1278,3 +1284,221 @@ def test_cli_package_resolver( assert "- PyPI versions: 1.2.2, 1.3.1+local, 1.3.2, 2.0.0a1" in result.stdout assert "- only wheels on PyPI: 1.3.1+local, 2.0.0a1" in result.stdout assert "- missing from Fromager: 1.3.1+local, 2.0.0a1" in result.stdout + + +def _make_candidate(name: str, version: str) -> Candidate: + """Create a minimal Candidate for testing.""" + return Candidate( + name=name, version=Version(version), url="https://example.com", is_sdist=False + ) + + +class _StubProvider(resolver.BaseProvider): + """Minimal BaseProvider subclass for cache tests.""" + + provider_description = "stub" + + @property + def cache_key(self) -> str: + return "stub-key" + + def find_candidates(self, identifier: str) -> list[Candidate]: + return [] + + +class _SlowProvider(resolver.BaseProvider): + """BaseProvider subclass whose find_candidates delegates to a callback. + + The callback receives the identifier and can sleep, record timestamps, + or count calls — whatever the test needs. 
+ """ + + provider_description = "slow" + + def __init__( + self, + callback: typing.Callable[[str], list[Candidate]], + **kwargs: typing.Any, + ) -> None: + super().__init__(**kwargs) + self._callback = callback + + @property + def cache_key(self) -> str: + return "slow-key" + + def find_candidates(self, identifier: str) -> list[Candidate]: + return self._callback(identifier) + + +def test_get_cached_candidates_returns_defensive_copy() -> None: + """Mutating the list returned by _get_cached_candidates must not corrupt the cache.""" + provider = _StubProvider() + identifier = "test-pkg" + + # Seed the cache directly so the test doesn't depend on the aliasing bug + resolver.BaseProvider.resolver_cache[identifier] = { + (type(provider), provider.cache_key): [_make_candidate("test-pkg", "1.0.0")] + } + + # Get candidates and mutate the returned list (hold the lock per the + # documented contract, even though single-threaded) + lock = provider._get_identifier_lock(identifier) + with lock: + first = provider._get_cached_candidates(identifier) + assert first is not None + first.append(_make_candidate("test-pkg", "2.0.0")) + + # The cache should not reflect the caller's mutation + with lock: + second = provider._get_cached_candidates(identifier) + assert second is not None + assert len(second) == 1, ( + "_get_cached_candidates should return a defensive copy, " + "not a direct reference to the internal cache" + ) + assert second[0].version == Version("1.0.0") + + +def test_find_cached_candidates_thread_safe() -> None: + """Concurrent threads must not bypass the cache and call find_candidates multiple times.""" + call_count = 0 + call_count_lock = threading.Lock() + + def slow_find(identifier: str) -> list[Candidate]: + nonlocal call_count + with call_count_lock: + call_count += 1 + time.sleep(0.1) + return [_make_candidate(identifier, "1.0.0")] + + barrier = threading.Barrier(4) + + def resolve_in_thread(provider: _SlowProvider, ident: str) -> None: + 
barrier.wait(timeout=5) + list(provider._find_cached_candidates(ident)) + + providers = [_SlowProvider(callback=slow_find) for _ in range(4)] + threads = [ + threading.Thread( + target=resolve_in_thread, + args=(thread_provider, "shared-pkg"), + name=f"resolver-{i}", + ) + for i, thread_provider in enumerate(providers) + ] + + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not any(t.is_alive() for t in threads), "Threads did not complete in time" + + assert call_count == 1, ( + f"find_candidates() was called {call_count} times; expected 1. " + "Without thread-safe caching, multiple threads bypass the cache " + "and redundantly call find_candidates()." + ) + + +def test_find_cached_candidates_different_packages_concurrent() -> None: + """Threads resolving different packages must not block each other.""" + # Record start and end times so we can prove overlap without tight tolerances + call_spans: dict[str, tuple[float, float]] = {} + call_spans_lock = threading.Lock() + + def timed_find(identifier: str) -> list[Candidate]: + start = time.monotonic() + time.sleep(0.3) + end = time.monotonic() + with call_spans_lock: + call_spans[identifier] = (start, end) + return [_make_candidate(identifier, "1.0.0")] + + barrier = threading.Barrier(2) + + def resolve_in_thread(provider: _SlowProvider, ident: str) -> None: + barrier.wait(timeout=5) + list(provider._find_cached_candidates(ident)) + + providers = [_SlowProvider(callback=timed_find) for _ in range(2)] + threads = [ + threading.Thread( + target=resolve_in_thread, + args=(providers[i], f"pkg-{i}"), + name=f"resolver-{i}", + ) + for i in range(2) + ] + + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + assert not any(t.is_alive() for t in threads), "Threads did not complete in time" + + # Both packages should have been resolved + assert "pkg-0" in call_spans + assert "pkg-1" in call_spans + # Prove concurrency: each call must have started before the other finished. 
+    # If a global lock serialized them, one would start only after the other ended.
+    start_0, end_0 = call_spans["pkg-0"]
+    start_1, end_1 = call_spans["pkg-1"]
+    assert start_0 < end_1 and start_1 < end_0, (
+        "find_candidates for different packages should run concurrently, "
+        "not be serialized by a global lock"
+    )
+
+
+def test_clear_cache_cleans_up_locks() -> None:
+    """clear_cache() must remove per-identifier locks so they don't accumulate."""
+    provider = _StubProvider()
+
+    # Populate the cache and create a per-identifier lock
+    provider._find_cached_candidates("pkg-a")
+    provider._find_cached_candidates("pkg-b")
+    assert "pkg-a" in resolver.BaseProvider._cache_locks
+    assert "pkg-b" in resolver.BaseProvider._cache_locks
+
+    # Clear everything
+    resolver.BaseProvider.clear_cache()
+    assert resolver.BaseProvider._cache_locks == {}
+    assert resolver.BaseProvider.resolver_cache == {}
+
+
+def test_clear_cache_single_identifier_cleans_up_lock() -> None:
+    """clear_cache(identifier) must remove only the lock for that identifier."""
+    provider = _StubProvider()
+
+    provider._find_cached_candidates("pkg-a")
+    provider._find_cached_candidates("pkg-b")
+
+    resolver.BaseProvider.clear_cache("pkg-a")
+    assert "pkg-a" not in resolver.BaseProvider._cache_locks
+    assert "pkg-b" in resolver.BaseProvider._cache_locks
+
+
+def test_empty_candidate_list_is_cached() -> None:
+    """An empty find_candidates result must be cached, not re-fetched."""
+    call_count = 0
+
+    def counting_find(identifier: str) -> list[Candidate]:
+        nonlocal call_count
+        call_count += 1
+        return []
+
+    provider = _SlowProvider(callback=counting_find)
+    provider._find_cached_candidates("empty-pkg")
+    provider._find_cached_candidates("empty-pkg")
+    assert call_count == 1, (
+        f"find_candidates() was called {call_count} times; expected 1. "
+        "Empty candidate lists must be treated as valid cache entries."
+    )
+
+
+def test_find_cached_candidates_cache_disabled() -> None:
+    """With use_resolver_cache=False, results must bypass the cache entirely."""
+    provider = _StubProvider(use_resolver_cache=False)
+    result = list(provider._find_cached_candidates("uncached-pkg"))
+    assert result == []
+    assert "uncached-pkg" not in resolver.BaseProvider.resolver_cache