
fix: Autoregressive KV-cache RoPE initialization and HF kwargs support#4

Open
aniruddhaadak80 wants to merge 3 commits into kyegomez:main from aniruddhaadak80:fix-rope-generation

Conversation

@aniruddhaadak80

This PR fixes a critical mathematical bug in the handling of Rotary Positional Embeddings (RoPE) during autoregressive generation, and properly wires up Hugging Face's past_key_values KV caching!

1. Fix Autoregressive RoPE Frequency Indexing

  • Bug: During cached generation, each decoding step passes a single token into the model (because of kv_cache). The line freqs_cis[: x.shape[1]] in apply_rope() therefore sliced freqs_cis[:1], permanently encoding position 0 into the query and key tensors of every iteratively generated token, so positional relationships were completely lost during generation.
  • Fix: Added start_pos tracking to the forward() loop and used it to slice freqs_cis by absolute position. Removed the re-slicing inside apply_rope(), since freqs_cis is now sliced correctly by the caller.
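To make the bug concrete, here is a minimal, self-contained sketch (the helper below is a standard RoPE frequency table, not the repo's exact code) contrasting the buggy `[:1]` slice with the `start_pos`-offset slice:

```python
import torch

def precompute_freqs_cis(dim: int, max_len: int, theta: float = 10000.0):
    # Standard RoPE table of complex rotations, shape (max_len, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_len).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)

freqs_cis = precompute_freqs_cis(dim=8, max_len=16)

# Buggy decode step: with a KV cache, x.shape[1] == 1, so the slice below
# always selects position 0 no matter how many tokens were generated.
buggy = freqs_cis[:1]

# Fixed decode step: offset by the number of tokens already in the cache.
start_pos, T = 5, 1
fixed = freqs_cis[start_pos:start_pos + T]

assert torch.equal(buggy, freqs_cis[0:1])  # stuck at position 0
assert not torch.equal(buggy, fixed)       # the fix advances with the cache
```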

2. Properly wire HF past_key_values

  • Addition: OpenMythosForCausalLM now infers start_pos from the cached sequence length whenever past_key_values is provided and threads it through the model, so Hugging Face ecosystem users calling .generate(...) automatically get correct KV caching and rotary position offsets.
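As a concrete sketch of this inference, assuming a per-layer dict cache whose "k" tensors have shape (batch, seq_len, heads, head_dim) (the layout and key names here are illustrative assumptions, not the repo's exact structures):

```python
import torch

# Hypothetical per-layer cache; shapes are illustrative only.
past_key_values = {
    "layer_0": {"k": torch.zeros(1, 7, 4, 16), "v": torch.zeros(1, 7, 4, 16)},
}

# start_pos is simply the number of tokens already cached.
start_pos = 0
if past_key_values:
    first = next(iter(past_key_values.values()))
    start_pos = first["k"].shape[1]

assert start_pos == 7  # the next token is rotated with position 7
```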

Copilot AI review requested due to automatic review settings April 19, 2026 06:32

Copilot AI left a comment


Pull request overview

This PR aims to fix RoPE position indexing during autoregressive decoding with KV-cache, and add Hugging Face past_key_values support via a PreTrainedModel wrapper.

Changes:

  • Fix RoPE application during incremental decoding by introducing start_pos and slicing RoPE frequencies accordingly.
  • Switch attention implementations to F.scaled_dot_product_attention.
  • Add HF-compatible MythosConfig/OpenMythosForCausalLM, package exports, and a starter training script.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.

File Description
open_mythos/main.py Adds start_pos RoPE slicing, swaps attention to SDPA, introduces HF config/model wrapper.
open_mythos/__init__.py Exposes OpenMythos, MythosConfig, and OpenMythosForCausalLM from the package root.
train.py Adds a minimal training step example plus a MoE load-balancing loss helper.
pr_body.md Adds a PR body summary of architectural improvements.


Comment thread open_mythos/main.py
Comment on lines 986 to 990
        x = self.embed(input_ids)
        freqs_cis = (
            self.freqs_cis_mla if self.cfg.attn_type == "mla" else self.freqs_cis
-       )[:T]
+       )[start_pos:start_pos + T]
        mask = self._causal_mask(T, device) if T > 1 else None

Copilot AI Apr 19, 2026


The new start_pos RoPE offsetting is critical for correctness during incremental decoding, but the current test suite only checks a single cached forward pass. Please add a regression test that runs multi-step decoding (e.g., 3–5 tokens) and verifies cached token-by-token logits match an uncached full-sequence forward at each step (both GQA and MLA).
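For illustration, here is what such a regression check can look like against a toy single-head cached-attention model. This is not the repo's OpenMythos (the model, its cache layout, and the kv_cache kwarg are stand-ins); it only demonstrates the cached-vs-uncached comparison the reviewer asks for:

```python
import torch
import torch.nn as nn

class TinyCausalLM(nn.Module):
    """Toy single-head causal LM with a dict KV cache, used only to
    demonstrate the cached-vs-uncached regression check."""

    def __init__(self, vocab: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, vocab, bias=False)

    def forward(self, input_ids, kv_cache=None):
        x = self.embed(input_ids)                      # (B, T, D)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if kv_cache is not None:                       # append to the cache
            k = torch.cat([kv_cache.get("k", k[:, :0]), k], dim=1)
            v = torch.cat([kv_cache.get("v", v[:, :0]), v], dim=1)
            kv_cache["k"], kv_cache["v"] = k, v
        T, S = q.shape[1], k.shape[1]
        attn = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
        causal = torch.ones(T, S, dtype=torch.bool).tril(S - T)
        attn = attn.masked_fill(~causal, float("-inf")).softmax(dim=-1)
        return self.out(attn @ v)

torch.manual_seed(0)
model = TinyCausalLM().eval()
ids = torch.randint(0, 32, (1, 5))
with torch.no_grad():
    full = model(ids)                                  # uncached full pass
    cache = {}
    steps = [model(ids[:, t:t + 1], kv_cache=cache) for t in range(ids.shape[1])]
    stepwise = torch.cat(steps, dim=1)

# Cached token-by-token decoding must reproduce the uncached logits.
torch.testing.assert_close(stepwise, full, atol=1e-5, rtol=1e-4)
```

The same structure applies to the real model: run it under both GQA and MLA configurations and compare logits at every decoding step, not just after the first cached call.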

Copilot uses AI. Check for mistakes.
Comment thread open_mythos/__init__.py
Comment on lines +1 to +2
from .main import OpenMythos, MythosConfig, OpenMythosForCausalLM


Copilot AI Apr 19, 2026


OpenMythosForCausalLM is exported unconditionally, but it’s not usable without transformers installed (and will currently fail at instantiation). Consider conditionally importing/adding it to __all__ only when transformers is available, or provide a stub that raises a clear ImportError on use.

Suggested change
-from .main import OpenMythos, MythosConfig, OpenMythosForCausalLM
+from .main import OpenMythos, MythosConfig
+
+try:
+    from .main import OpenMythosForCausalLM
+except ImportError as exc:
+    _open_mythos_for_causal_lm_import_error = exc
+
+    class OpenMythosForCausalLM:
+        def __init__(self, *args, **kwargs):
+            raise ImportError(
+                "OpenMythosForCausalLM requires the optional dependency "
+                "'transformers' to be installed."
+            ) from _open_mythos_for_causal_lm_import_error

Comment thread pr_body.md
Comment on lines +1 to +8
This PR adds several high-confidence improvements to the OpenMythos architecture:

- **Flash Attention**: Upgraded GQAttention and MLAttention to use PyTorch native `F.scaled_dot_product_attention` for better memory complexity and speed.
- **Hugging Face compatibility**: Wrapped models with `PreTrainedModel` and `PretrainedConfig`.
- **Core Exports**: Added init file to the open_mythos directory for simpler imports.
- **Training loop and MoE Loss**: Added an initial training script providing load balancing logic for MoE models.

These changes provide immediate value and make the model faster and easier to use with existing AI ecosystems.

Copilot AI Apr 19, 2026


The PR description in metadata focuses on the RoPE/KV-cache fix, but this PR body file lists additional major changes (Flash Attention via SDPA, HF wrapper/config, exports, training script) and doesn’t mention the RoPE regression. Please align the PR description/body so reviewers and release notes reflect the actual scope of changes.

Comment thread train.py
        num_experts: Total number of routed experts
    """
    if router_logits is None or top_k_indices is None:
        return torch.tensor(0.0, device=router_logits.device)

Copilot AI Apr 19, 2026


The None-guard will still crash because it tries to read router_logits.device when router_logits is None. Return a zero tensor on a known device (e.g., top_k_indices.device if available) or default to torch.device("cpu") when both are None.

Suggested change
-    return torch.tensor(0.0, device=router_logits.device)
+    device = (
+        router_logits.device if router_logits is not None
+        else top_k_indices.device if top_k_indices is not None
+        else torch.device("cpu")
+    )
+    return torch.tensor(0.0, device=device)

Comment thread open_mythos/main.py
Comment on lines +18 to +21
    PretrainedConfig = object
    PreTrainedModel = object
    CausalLMOutputWithPast = None


Copilot AI Apr 19, 2026


The except ImportError fallback sets PreTrainedModel = object, but OpenMythosForCausalLM.__init__ calls super().__init__(config), which will raise TypeError if transformers isn’t installed. Consider conditionally defining/exporting OpenMythosForCausalLM only when transformers is available, or raise a clear ImportError in __init__ when the dependency is missing.

Suggested change
-    PretrainedConfig = object
-    PreTrainedModel = object
-    CausalLMOutputWithPast = None
+    class PretrainedConfig:
+        def __init__(self, **kwargs):
+            for key, value in kwargs.items():
+                setattr(self, key, value)
+
+    class PreTrainedModel(nn.Module):
+        config_class = PretrainedConfig
+
+        def __init__(self, config=None, *args, **kwargs):
+            super().__init__()
+            self.config = config
+
+    @dataclass
+    class CausalLMOutputWithPast:
+        loss: Optional[torch.Tensor] = None
+        logits: Optional[torch.Tensor] = None
+        past_key_values: Optional[tuple] = None
+        hidden_states: Optional[tuple] = None
+        attentions: Optional[tuple] = None

Comment thread open_mythos/main.py
Comment on lines +270 to +274
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        is_causal=False
    )

Copilot AI Apr 19, 2026


scaled_dot_product_attention is called with an explicit additive attn_mask and is_causal=False. If the mask is purely causal (triangular) and Q/K have equal sequence lengths (the common training/prompt case), using is_causal=True and omitting the float mask can enable faster SDPA backends (Flash/mem‑efficient) and avoid dtype upcasting from the current float32 -inf mask.
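The equivalence the reviewer describes can be verified directly: for a strictly causal additive mask with equal Q/K lengths, dropping the mask in favor of is_causal=True gives identical output while allowing fused backends.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 2, 6, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Explicit float causal mask with is_causal=False (the pattern in this PR).
upper = torch.ones(T, T, dtype=torch.bool).triu(1)
mask = torch.zeros(T, T).masked_fill(upper, float("-inf"))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)

# Equivalent call that lets SDPA select a fused causal kernel.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

assert torch.allclose(out_masked, out_causal, atol=1e-5)
```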

Comment thread open_mythos/main.py
Comment on lines +1067 to +1078
# Determine start_pos automatically if KV cache is in use
start_pos = 0
if past_key_values is not None:
    # We find the shape of the past values to determine length
    first_key = next(iter(past_key_values.keys()))
    if "k" in past_key_values[first_key]:
        start_pos = past_key_values[first_key]["k"].shape[1]
    elif "c_kv" in past_key_values[first_key]:
        start_pos = past_key_values[first_key]["c_kv"].shape[1]

# We reuse past_key_values directly as kwargs kv_cache
kv_cache = past_key_values if past_key_values is not None else {}

Copilot AI Apr 19, 2026


past_key_values can be an empty cache object (e.g., {}) in some calling patterns; next(iter(past_key_values.keys())) will raise StopIteration. Consider treating an empty dict the same as None, and/or computing start_pos more defensively (e.g., from the max cached sequence length across entries).

Suggested change
-# Determine start_pos automatically if KV cache is in use
-start_pos = 0
-if past_key_values is not None:
-    # We find the shape of the past values to determine length
-    first_key = next(iter(past_key_values.keys()))
-    if "k" in past_key_values[first_key]:
-        start_pos = past_key_values[first_key]["k"].shape[1]
-    elif "c_kv" in past_key_values[first_key]:
-        start_pos = past_key_values[first_key]["c_kv"].shape[1]
-
-# We reuse past_key_values directly as kwargs kv_cache
-kv_cache = past_key_values if past_key_values is not None else {}
+# Determine start_pos automatically if KV cache is in use.
+# Treat an empty cache the same as no cache, and compute the cached
+# sequence length defensively from all available entries.
+start_pos = 0
+kv_cache = past_key_values if past_key_values else {}
+if kv_cache:
+    cache_lengths = []
+    for layer_cache in kv_cache.values():
+        if not isinstance(layer_cache, dict):
+            continue
+        if "k" in layer_cache:
+            cache_lengths.append(layer_cache["k"].shape[1])
+        elif "c_kv" in layer_cache:
+            cache_lengths.append(layer_cache["c_kv"].shape[1])
+    if cache_lengths:
+        start_pos = max(cache_lengths)

Comment thread open_mythos/main.py
Comment on lines +1059 to +1066
def forward(
    self,
    input_ids: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    n_loops: Optional[int] = None,
    past_key_values: Optional[dict] = None,
    **kwargs
) -> CausalLMOutputWithPast:

Copilot AI Apr 19, 2026


For full Hugging Face .generate() compatibility, returning a custom dict as past_key_values typically requires implementing _reorder_cache (beam search) and often use_cache handling. As-is, greedy decoding may work, but beam search / sampling with beams is likely to break when GenerationMixin tries to reorder caches.
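A minimal sketch of the missing piece, written here as a free function over the dict cache this PR uses (in Hugging Face models, _reorder_cache is a method on the model class; this standalone version only illustrates the re-indexing):

```python
import torch

def _reorder_cache(past_key_values: dict, beam_idx: torch.LongTensor) -> dict:
    # Beam search permutes the batch dimension, so every cached tensor
    # must be re-indexed with beam_idx along dim 0.
    return {
        layer: {name: t.index_select(0, beam_idx) for name, t in entry.items()}
        for layer, entry in past_key_values.items()
    }

# Toy check: two beams swap places.
cache = {"layer_0": {"k": torch.tensor([[1.0], [2.0]])}}
reordered = _reorder_cache(cache, torch.tensor([1, 0]))
assert reordered["layer_0"]["k"].tolist() == [[2.0], [1.0]]
```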

Comment thread open_mythos/main.py
Comment on lines +27 to 28
class MythosConfig(PretrainedConfig):
"""

Copilot AI Apr 19, 2026


After switching MythosConfig away from @dataclass, the from dataclasses import dataclass import at the top of this file appears to be unused and should be removed to avoid lint/static-analysis failures.

Comment thread open_mythos/main.py
Comment on lines +410 to +413
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        is_causal=False

Copilot AI Apr 19, 2026


Same SDPA concern as above: passing a float attn_mask with is_causal=False can force a slower kernel and/or dtype promotion. If the mask is strictly causal and Q/K lengths match, prefer is_causal=True and no explicit mask; otherwise ensure the mask dtype matches q/k to avoid unintended upcasting.

Suggested change
-    out = F.scaled_dot_product_attention(
-        q, k, v,
-        attn_mask=mask,
-        is_causal=False
+    use_is_causal = mask is None and T == S
+    attn_mask = None if mask is None else mask.to(dtype=q.dtype)
+    out = F.scaled_dot_product_attention(
+        q, k, v,
+        attn_mask=attn_mask,
+        is_causal=use_is_causal

@qodo-ai-reviewer

Hi, OpenMythosForCausalLM.forward() assumes past_key_values is a dict (it uses .keys() and string keys), but HF generate() uses a tuple/structured cache; this will crash and prevent HF caching from working.

Severity: action required | Category: correctness

How to fix: Implement HF cache conversion

Agent prompt to fix - you can give this to your LLM of choice:

Issue description

OpenMythosForCausalLM.forward() treats past_key_values as the model’s internal kv_cache dict. Hugging Face generation expects/returns a different cache structure; the current implementation will crash or silently break caching.

Issue Context

Internally, attention layers mutate a dict keyed by string cache keys. HF generate() expects past_key_values to be in HF’s cache format and will feed it back on subsequent steps.

Fix Focus Areas

  • open_mythos/main.py[1048-1098]

What to implement

  • Accept HF past_key_values in the format HF passes (likely tuple-like per layer) and convert it into the internal kv_cache dict before calling self.model(...).
  • Convert the internal kv_cache back into HF’s expected past_key_values structure in the returned CausalLMOutputWithPast.
  • Alternatively, if you don’t want to support HF’s native cache format, explicitly disable caching in HF (use_cache=False) and raise a clear error if past_key_values is provided.
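A rough sketch of the conversion described above, assuming the internal cache is {layer_name: {"k": ..., "v": ...}} and targeting HF's legacy tuple-of-(key, value) format (both layouts are assumptions for illustration):

```python
import torch

def to_hf_cache(kv_cache: dict) -> tuple:
    # Internal dict cache -> HF legacy tuple of (key, value) per layer.
    return tuple((entry["k"], entry["v"]) for entry in kv_cache.values())

def from_hf_cache(hf_cache: tuple) -> dict:
    # HF legacy tuple cache -> internal dict cache.
    return {f"layer_{i}": {"k": k, "v": v} for i, (k, v) in enumerate(hf_cache)}

cache = {"layer_0": {"k": torch.zeros(1, 3, 4), "v": torch.ones(1, 3, 4)}}
roundtrip = from_hf_cache(to_hf_cache(cache))
assert torch.equal(roundtrip["layer_0"]["v"], cache["layer_0"]["v"])
```

Note that an MLA-style "c_kv" entry has no natural (key, value) pair, which is one reason the fallback of use_cache=False plus a clear error may be the simpler route for that attention type.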

We noticed a couple of other issues in this PR as well - happy to share if helpful.


Found by Qodo. Free code review for open-source maintainers.

@qodo-ai-reviewer

Hi, apply_rope() no longer slices freqs_cis to the current sequence length T, so callers that pass precomputed (max_len, dim//2) freqs (as the tests currently do) will hit a tensor shape/broadcast error in xc * freqs_cis.

Severity: action required | Category: correctness

How to fix: Restore slicing or update callers

Agent prompt to fix - you can give this to your LLM of choice:

Issue description

apply_rope() stopped slicing freqs_cis to T, but multiple repo call sites pass full precomputed RoPE freqs. This causes shape mismatch at runtime.

Issue Context

You changed the contract so freqs_cis must be pre-sliced by the caller. Internal model code now slices in OpenMythos.forward, but tests and standalone attention usage still pass full freqs.

Fix Focus Areas

  • open_mythos/main.py[175-193]
  • test_main.py[97-114]
  • test_main.py[240-260]

Suggested fix

Pick one:

  1. Make apply_rope() robust again by slicing: freqs_cis = freqs_cis[:x.shape[1]] (and keep OpenMythos.forward slicing too, or remove one for consistency), OR
  2. Update all external callers/tests to pass freqs[:T] (and update any docs/examples accordingly). Ensure attention unit tests slice freqs to T in setup.
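Option 1 can be sketched as follows; this is an illustrative RoPE implementation, not the repo's exact apply_rope, showing how a defensive slice keeps full-table callers working while leaving correctly pre-sliced (start_pos-offset) inputs untouched:

```python
import torch

def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (B, T, H, D); freqs_cis: complex, shape (>=T, D // 2).
    T = x.shape[1]
    if freqs_cis.shape[0] > T:
        # Defensive re-slice for callers passing the full table; a
        # pre-sliced input of length exactly T passes through unchanged.
        freqs_cis = freqs_cis[:T]
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotated = xc * freqs_cis.view(1, T, 1, -1)
    return torch.view_as_real(rotated).flatten(-2).type_as(x)

B, T, H, D = 1, 4, 2, 8
x = torch.randn(B, T, H, D)
full_freqs = torch.polar(torch.ones(16, D // 2), torch.randn(16, D // 2))

# Full table and pre-sliced table now produce identical rotations.
assert torch.allclose(apply_rope(x, full_freqs), apply_rope(x, full_freqs[:T]))
```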

Found by Qodo code review. FYI, Qodo is free for open-source.
