fix: Autoregressive KV-cache RoPE initialization and HF kwargs support #4
aniruddhaadak80 wants to merge 3 commits into kyegomez:main from
Conversation
Pull request overview
This PR aims to fix RoPE position indexing during autoregressive decoding with KV-cache, and add Hugging Face past_key_values support via a PreTrainedModel wrapper.
Changes:
- Fix RoPE application during incremental decoding by introducing `start_pos` and slicing RoPE frequencies accordingly.
- Switch attention implementations to `F.scaled_dot_product_attention`.
- Add HF-compatible `MythosConfig`/`OpenMythosForCausalLM`, package exports, and a starter training script.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| open_mythos/main.py | Adds start_pos RoPE slicing, swaps attention to SDPA, introduces HF config/model wrapper. |
| open_mythos/__init__.py | Exposes OpenMythos, MythosConfig, and OpenMythosForCausalLM from the package root. |
| train.py | Adds a minimal training step example plus a MoE load-balancing loss helper. |
| pr_body.md | Adds a PR body summary of architectural improvements. |
```diff
  x = self.embed(input_ids)
  freqs_cis = (
      self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
- )[:T]
+ )[start_pos:start_pos + T]
  mask = self._causal_mask(T, device) if T > 1 else None
```
The new start_pos RoPE offsetting is critical for correctness during incremental decoding, but the current test suite only checks a single cached forward pass. Please add a regression test that runs multi-step decoding (e.g., 3–5 tokens) and verifies cached token-by-token logits match an uncached full-sequence forward at each step (both GQA and MLA).
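A torch-free sketch of the position bookkeeping such a regression test would exercise (function names here are illustrative, not from the repo): token-by-token decoding with `start_pos` slicing must visit exactly the positions a full-sequence forward sees.

```python
def full_sequence_positions(freqs, seq_len):
    # Uncached forward: positions 0..seq_len-1 in one shot.
    return freqs[:seq_len]

def incremental_positions(freqs, seq_len):
    # Cached decoding: one token per step, offset by the cached length.
    out = []
    for start_pos in range(seq_len):
        out += freqs[start_pos:start_pos + 1]  # the fixed start_pos slicing
    return out

freqs = list(range(32))
assert incremental_positions(freqs, 5) == full_sequence_positions(freqs, 5)
print("cached and uncached positions agree")
```

A real test would compare full logits tensors at each step, for both the GQA and MLA attention types, not just position indices.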
```python
from .main import OpenMythos, MythosConfig, OpenMythosForCausalLM
```
OpenMythosForCausalLM is exported unconditionally, but it’s not usable without transformers installed (and will currently fail at instantiation). Consider conditionally importing/adding it to __all__ only when transformers is available, or provide a stub that raises a clear ImportError on use.
```diff
- from .main import OpenMythos, MythosConfig, OpenMythosForCausalLM
+ from .main import OpenMythos, MythosConfig
+ try:
+     from .main import OpenMythosForCausalLM
+ except ImportError as exc:
+     _open_mythos_for_causal_lm_import_error = exc
+     class OpenMythosForCausalLM:
+         def __init__(self, *args, **kwargs):
+             raise ImportError(
+                 "OpenMythosForCausalLM requires the optional dependency "
+                 "'transformers' to be installed."
+             ) from _open_mythos_for_causal_lm_import_error
```
This PR adds several high-confidence improvements to the OpenMythos architecture:

- **Flash Attention**: Upgraded GQAttention and MLAttention to use PyTorch native `F.scaled_dot_product_attention` for better memory complexity and speed.
- **Hugging Face compatibility**: Wrapped models with `PreTrainedModel` and `PretrainedConfig`.
- **Core Exports**: Added init file to the open_mythos directory for simpler imports.
- **Training loop and MoE Loss**: Added an initial training script providing load-balancing logic for MoE models.

These changes provide immediate value and make the model faster and easier to use with existing AI ecosystems.
The PR description in metadata focuses on the RoPE/KV-cache fix, but this PR body file lists additional major changes (Flash Attention via SDPA, HF wrapper/config, exports, training script) and doesn’t mention the RoPE regression. Please align the PR description/body so reviewers and release notes reflect the actual scope of changes.
```python
        num_experts: Total number of routed experts
    """
    if router_logits is None or top_k_indices is None:
        return torch.tensor(0.0, device=router_logits.device)
```
The None-guard will still crash because it tries to read router_logits.device when router_logits is None. Return a zero tensor on a known device (e.g., top_k_indices.device if available) or default to torch.device("cpu") when both are None.
```diff
- return torch.tensor(0.0, device=router_logits.device)
+ device = (
+     router_logits.device if router_logits is not None
+     else top_k_indices.device if top_k_indices is not None
+     else torch.device("cpu")
+ )
+ return torch.tensor(0.0, device=device)
```
```python
PretrainedConfig = object
PreTrainedModel = object
CausalLMOutputWithPast = None
```
The except ImportError fallback sets PreTrainedModel = object, but OpenMythosForCausalLM.__init__ calls super().__init__(config), which will raise TypeError if transformers isn’t installed. Consider conditionally defining/exporting OpenMythosForCausalLM only when transformers is available, or raise a clear ImportError in __init__ when the dependency is missing.
```diff
- PretrainedConfig = object
- PreTrainedModel = object
- CausalLMOutputWithPast = None
+ class PretrainedConfig:
+     def __init__(self, **kwargs):
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+ class PreTrainedModel(nn.Module):
+     config_class = PretrainedConfig
+     def __init__(self, config=None, *args, **kwargs):
+         super().__init__()
+         self.config = config
+ @dataclass
+ class CausalLMOutputWithPast:
+     loss: Optional[torch.Tensor] = None
+     logits: Optional[torch.Tensor] = None
+     past_key_values: Optional[tuple] = None
+     hidden_states: Optional[tuple] = None
+     attentions: Optional[tuple] = None
```
```python
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=mask,
    is_causal=False
)
```
scaled_dot_product_attention is called with an explicit additive attn_mask and is_causal=False. If the mask is purely causal (triangular) and Q/K have equal sequence lengths (the common training/prompt case), using is_causal=True and omitting the float mask can enable faster SDPA backends (Flash/mem‑efficient) and avoid dtype upcasting from the current float32 -inf mask.
```python
# Determine start_pos automatically if KV cache is in use
start_pos = 0
if past_key_values is not None:
    # We find the shape of the past values to determine length
    first_key = next(iter(past_key_values.keys()))
    if "k" in past_key_values[first_key]:
        start_pos = past_key_values[first_key]["k"].shape[1]
    elif "c_kv" in past_key_values[first_key]:
        start_pos = past_key_values[first_key]["c_kv"].shape[1]

# We reuse past_key_values directly as kwargs kv_cache
kv_cache = past_key_values if past_key_values is not None else {}
```
past_key_values can be an empty cache object (e.g., {}) in some calling patterns; next(iter(past_key_values.keys())) will raise StopIteration. Consider treating an empty dict the same as None, and/or computing start_pos more defensively (e.g., from the max cached sequence length across entries).
```diff
- # Determine start_pos automatically if KV cache is in use
- start_pos = 0
- if past_key_values is not None:
-     # We find the shape of the past values to determine length
-     first_key = next(iter(past_key_values.keys()))
-     if "k" in past_key_values[first_key]:
-         start_pos = past_key_values[first_key]["k"].shape[1]
-     elif "c_kv" in past_key_values[first_key]:
-         start_pos = past_key_values[first_key]["c_kv"].shape[1]
- # We reuse past_key_values directly as kwargs kv_cache
- kv_cache = past_key_values if past_key_values is not None else {}
+ # Determine start_pos automatically if KV cache is in use.
+ # Treat an empty cache the same as no cache, and compute the cached
+ # sequence length defensively from all available entries.
+ start_pos = 0
+ kv_cache = past_key_values if past_key_values else {}
+ if kv_cache:
+     cache_lengths = []
+     for layer_cache in kv_cache.values():
+         if not isinstance(layer_cache, dict):
+             continue
+         if "k" in layer_cache:
+             cache_lengths.append(layer_cache["k"].shape[1])
+         elif "c_kv" in layer_cache:
+             cache_lengths.append(layer_cache["c_kv"].shape[1])
+     if cache_lengths:
+         start_pos = max(cache_lengths)
```
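The defensive pattern suggested above can be sanity-checked in isolation with a pure-Python sketch (`FakeTensor` is a hypothetical stand-in exposing only `.shape`, so no torch is needed):

```python
class FakeTensor:
    """Stand-in for a cached tensor; only .shape is consulted."""
    def __init__(self, shape):
        self.shape = shape

def infer_start_pos(past_key_values):
    # Empty dict and None behave the same; length is the max across entries.
    if not past_key_values:
        return 0
    lengths = []
    for layer_cache in past_key_values.values():
        if not isinstance(layer_cache, dict):
            continue
        for key in ("k", "c_kv"):
            if key in layer_cache:
                lengths.append(layer_cache[key].shape[1])
                break
    return max(lengths, default=0)

print(infer_start_pos(None))                                        # 0
print(infer_start_pos({}))                                          # 0, no StopIteration
print(infer_start_pos({"layer0": {"k": FakeTensor((1, 7, 64))}}))   # 7
```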
```python
def forward(
    self,
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    n_loops: Optional[int] = None,
    past_key_values: Optional[dict] = None,
    **kwargs
) -> CausalLMOutputWithPast:
```
For full Hugging Face .generate() compatibility, returning a custom dict as past_key_values typically requires implementing _reorder_cache (beam search) and often use_cache handling. As-is, greedy decoding may work, but beam search / sampling with beams is likely to break when GenerationMixin tries to reorder caches.
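A hypothetical `_reorder_cache` for this dict-shaped cache could follow the sketch below (pure Python: plain lists stand in for the batch dimension so it runs without torch; real tensors would use `cached.index_select(0, beam_idx)`):

```python
def reorder_cache(past_key_values, beam_idx):
    # beam_idx[i] = index of the old beam that the i-th new beam continues.
    # Each layer entry maps names like "k" / "c_kv" to batch-first caches.
    return {
        layer: {
            name: [cached[i] for i in beam_idx]  # gather along the batch dim
            for name, cached in entry.items()
        }
        for layer, entry in past_key_values.items()
    }

cache = {"layer0": {"k": ["beam_a", "beam_b", "beam_c"]}}
print(reorder_cache(cache, [2, 0, 0]))
# {'layer0': {'k': ['beam_c', 'beam_a', 'beam_a']}}
```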
```python
class MythosConfig(PretrainedConfig):
    """
```
After switching MythosConfig away from @dataclass, the from dataclasses import dataclass import at the top of this file appears to be unused and should be removed to avoid lint/static-analysis failures.
```python
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=mask,
    is_causal=False
```
Same SDPA concern as above: passing a float attn_mask with is_causal=False can force a slower kernel and/or dtype promotion. If the mask is strictly causal and Q/K lengths match, prefer is_causal=True and no explicit mask; otherwise ensure the mask dtype matches q/k to avoid unintended upcasting.
```diff
- out = F.scaled_dot_product_attention(
-     q, k, v,
-     attn_mask=mask,
-     is_causal=False
+ use_is_causal = mask is None and T == S
+ attn_mask = None if mask is None else mask.to(dtype=q.dtype)
+ out = F.scaled_dot_product_attention(
+     q, k, v,
+     attn_mask=attn_mask,
+     is_causal=use_is_causal
```
Severity: action required | Category: correctness
How to fix: Implement HF cache conversion.
We noticed a couple of other issues in this PR as well - happy to share if helpful. (Found by Qodo code review.)
Severity: action required | Category: correctness
How to fix: Restore slicing or update callers.
(Found by Qodo code review.)
This PR fixes a critical mathematical bug in the handling of Rotary Positional Embeddings (RoPE) during autoregressive generation, and properly wires up Hugging Face's `past_key_values` KV caching!

1. Fix Autoregressive RoPE Frequency Indexing

During autoregressive generation, only 1 token at a time enters the system (because of `kv_cache`). The line `freqs_cis[: x.shape[1]]` in `apply_rope()` inadvertently sliced `freqs_cis[: 1]`, permanently encoding position 0 into the query and key tensors of every single iteratively generated token. The result was that positional relationships were completely lost during generation. Fix: added `start_pos` tracking to the `forward()` loop and passed it correctly to `freqs_cis`; modified `apply_rope` to avoid re-slicing because `freqs_cis` is now correctly sliced externally.

2. Properly wire HF `past_key_values`

Computed `start_pos` implicitly when `OpenMythosForCausalLM` uses `past_key_values` and fed this through, so Hugging Face ecosystem users (calling `.generate(...)`) automatically get correctly cached sequences and rotary shifts.
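The frequency-indexing bug described in point 1 can be reproduced with a pure-Python sketch (the `freqs` list is an illustrative stand-in for `freqs_cis`; each entry stands for the rotary phase at that position):

```python
freqs = list(range(16))

def rope_slice_buggy(freqs, seq_len):
    return freqs[:seq_len]                       # old code: always starts at position 0

def rope_slice_fixed(freqs, seq_len, start_pos):
    return freqs[start_pos:start_pos + seq_len]  # this PR: offset by cached length

# Decoding one new token with 5 tokens already in the KV cache:
print(rope_slice_buggy(freqs, 1))      # [0] -> position 0 reused for every token
print(rope_slice_fixed(freqs, 1, 5))   # [5] -> correct absolute position
```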