diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
index ee186dee4..5a5698b39 100644
--- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
@@ -585,8 +585,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded


-class Apriel2Attention(nn.Module):
-    """Apriel2 attention layer with rotary embeddings and GQA support."""
+class Apriel2Attention(nn.Module, AttentionLayerBase):
+    """Apriel2 attention layer with rotary embeddings and GQA support.
+
+    Inherits from AttentionLayerBase to ensure vLLM uses our get_kv_cache_spec()
+    which returns the unified block size needed for hybrid models.
+    """

     def __init__(
         self,
@@ -598,6 +602,7 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.prefix = prefix
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
@@ -678,6 +683,16 @@ def get_layer_bias(layer_name: str) -> bool:
             prefix=f"{prefix}.attn",
         )

+        # Override the internal Attention's get_kv_cache_spec to use our unified block size.
+        # The internal Attention stays registered in static_forward_context (needed for forward
+        # pass lookup), but when vLLM collects cache specs, it will get our unified block size.
+        wrapper_self = self  # Capture for closure
+        self.attn.get_kv_cache_spec = lambda vllm_config: wrapper_self.get_kv_cache_spec(vllm_config)
+
+    def get_attn_backend(self) -> type[AttentionBackend]:
+        """Delegate to the internal Attention's backend."""
+        return self.attn.get_attn_backend()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1810,7 +1825,7 @@ def forward(
         beta = self.b_proj(hidden_states)[0].float().sigmoid()
         g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]

-        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        g1 = fused_kda_gate(g1, self.A_log.float(), self.head_dim, g_bias=self.dt_bias)

         beta = beta.unsqueeze(0)
         g1 = g1.unsqueeze(0)
@@ -2957,7 +2972,38 @@ def _patch_worker_for_placement_switching():
     def _get_layer_placements(self) -> dict[int, str]:
         return self.get_model().get_layer_placements()

+    def _clear_kv_cache(self) -> None:
+        """Clear all KV cache tensors to prevent stale data after placement switch.
+
+        When mixer placement changes (e.g., layer 0 switches from KDA to attention),
+        the KV cache may contain data written by a different mixer type. Since different
+        mixers use incompatible cache formats, we must clear the cache to prevent NaN
+        errors from reading corrupted data.
+ """ + model_runner = getattr(self, "model_runner", None) + if model_runner is None: + return + + kv_caches = getattr(model_runner, "kv_caches", []) + for cache_item in kv_caches: + if cache_item is None: + continue + # KV cache items can be either: + # - torch.Tensor for attention layers + # - list[torch.Tensor] for state-based layers (KDA, Mamba) + if isinstance(cache_item, list): + for tensor in cache_item: + if tensor is not None: + tensor.zero_() + else: + cache_item.zero_() + + logger.info("Cleared KV cache tensors for placement switch") + def _set_layer_placements(self, placement: list[str]) -> dict[int, str]: + # Clear KV cache BEFORE changing placement to prevent reading stale data + # written by a different mixer type (which could cause NaN errors) + _clear_kv_cache(self) return self.get_model().set_layer_placements(placement) def _get_mixer_names(self) -> tuple[str, ...]: @@ -2966,6 +3012,7 @@ def _get_mixer_names(self) -> tuple[str, ...]: Worker.get_layer_placements = _get_layer_placements Worker.set_layer_placements = _set_layer_placements Worker.get_mixer_names = _get_mixer_names + Worker.clear_kv_cache = _clear_kv_cache _patch_worker_for_placement_switching()