Conversation
Signed-off-by: Mustafa Eyceoz <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
RobotSail left a comment:
Great work on the GPT-OSS integration! The MXFP4 quantization/dequantization implementation, router parameter freezing, auxiliary loss support, and batch collation refactor represent solid technical execution.
Before we can merge this PR, there are several issues we need to address:
- Data processing may misidentify GPT-OSS models in edge cases
- The batch collator looks like it could be dropping minibatches and failing to accumulate sample/token counts correctly
- The `kernels` dependency is missing from the requirements
- In the training loop, we appear to be reducing floats instead of ints in several places
- Quite a few functions and variables could use more descriptive names
- Several places where the logic can be simplified, especially in the GPT-OSS model saving workflow
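On the float-vs-int reduction point: token and sample counts are exact integers, and accumulating them in float32 silently loses exactness once a count exceeds 2^24. A minimal standard-library illustration (the specific count is hypothetical):

```python
import struct

def roundtrip_f32(x: int) -> float:
    """Round-trip an integer through a 32-bit float, as a float32 reduction would."""
    return struct.unpack("f", struct.pack("f", x))[0]

exact_count = 2**24 + 1           # e.g. a cumulative token count in a long run
as_float32 = roundtrip_f32(exact_count)

print(exact_count, as_float32)    # 16777217 16777216.0 -- the count is off by one
```

Gathering the counts as int64 and converting to float only for the final division avoids this entirely.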
I like the batch collator refactor, but I'm worried that using a running estimate for batch_num_loss_counted_tokens may cause problems. Because we now use a running estimate of the total number of loss-counted tokens in the entire minibatch, the gradient signal the model receives is biased toward the first few microbatches, whose estimates can heavily overshoot or undershoot relative to later microbatches. This could be a big problem when training with a large EBS (3840) and on datasets with rare samples you want to train on (~10k/370k).
We may want to add code in the future that reads the number of loss counted tokens in each batch (validation loss, logging, etc.). In these cases, the value won't be stable until we hit the number of accumulation steps.
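To make the bias concrete, here's a toy sketch (the per-microbatch token counts are made up) comparing a running-mean extrapolation of the minibatch total against the true total: early microbatches can overshoot badly, and the estimate only stabilizes once all accumulation steps have been seen.

```python
# Hypothetical loss-counted tokens per microbatch; the rare-sample case means
# a few microbatches carry far fewer trainable tokens than the rest.
microbatch_tokens = [512, 480, 8, 16, 500, 494]
true_total = sum(microbatch_tokens)  # 2010

estimates = []
seen = 0
for step, tokens in enumerate(microbatch_tokens, start=1):
    seen += tokens
    # Running mean extrapolated to the full minibatch, used to scale the loss.
    estimates.append(seen / step * len(microbatch_tokens))

print(true_total)     # 2010
print(estimates[0])   # 3072.0 -- the first microbatch overshoots by ~50%
print(estimates[-1])  # 2010.0 -- only exact once every microbatch is seen
```

Any loss scaled by `1 / estimate` in the early steps is therefore systematically down-weighted or up-weighted relative to the later steps.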
One way to solve this issue, while also fixing the dual use of the regular and batch samplers during the distributed-sampler fallback, would be to update how we use the MultipackSampler so it behaves like a regular sampler that generates samples ahead of time on a budget, then collates them together at the end. That way, each time you get a batch from the data loader, you already know which tokens will be loaded.
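A sketch of that idea (the function name and budget are hypothetical, not the actual MultipackSampler API): pre-pack sample indices into batches under a token budget, so every batch's token count is known before collation and no minibatch can be silently dropped.

```python
def pack_by_token_budget(sample_lengths, token_budget):
    """Greedily pre-pack sample indices into batches that fit a token budget.

    Because packing happens up front, the loss-counted token total of every
    batch is known before the data loader yields it.
    """
    batches, current, current_tokens = [], [], 0
    for idx, length in enumerate(sample_lengths):
        if current and current_tokens + length > token_budget:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append(current)
    return batches

lengths = [5, 3, 9, 2, 4]
batches = pack_by_token_budget(lengths, token_budget=10)
print(batches)  # [[0, 1], [2], [3, 4]]
# Every sample lands in exactly one batch, so nothing is dropped:
assert sorted(i for b in batches for i in b) == list(range(len(lengths)))
```

Behavior-driven tests for the sampler could then assert exactly these invariants: every sample appears in exactly one batch, and each batch respects the budget for any sample that fits it.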
It would also be good to add some behavior-driven tests for the sampling so we can have confidence in the refactor.
Please address the issues listed above and consider the suggested approach for the batch collator. The core implementation is solid; these changes will make it production-ready.
    test_tokens = ["<|start|>", "<|channel|>", "<|message|>"]
    for token in test_tokens:
        # If any of these tokens can't be encoded, it's not GPT-OSS
        tokenizer.encode(token, add_special_tokens=False)
In what cases would this raise an exception? AFAIK, tokenizer.encode will encode any string into a sequence of tokens.
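That's right: tokenizers with byte-level fallback encode any string without raising, so the try/except around `encode` never fires. A more reliable check (sketch only; `looks_like_gpt_oss` is a hypothetical helper, and with a real tokenizer you would pass in `tokenizer.get_vocab()`) is to test whether the marker tokens exist as entries in the vocabulary:

```python
GPT_OSS_MARKERS = ("<|start|>", "<|channel|>", "<|message|>")

def looks_like_gpt_oss(vocab):
    """vocab: mapping of token string -> id, e.g. tokenizer.get_vocab()."""
    return all(marker in vocab for marker in GPT_OSS_MARKERS)

# Toy vocabularies standing in for real tokenizers:
gpt_oss_vocab = {"<|start|>": 1, "<|channel|>": 2, "<|message|>": 3, "hi": 4}
other_vocab = {"<s>": 1, "</s>": 2, "hi": 3}

print(looks_like_gpt_oss(gpt_oss_vocab))  # True
print(looks_like_gpt_oss(other_vocab))    # False
```

Alternatively, checking that `tokenizer.encode(marker, add_special_tokens=False)` returns exactly one id distinguishes a real special token from a byte-fallback spelling, which encodes to many ids.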
* addition of padded batch packer + simplified train loop
* update tests + linting
Adds support for GPT-OSS. Also cleans up our loss calculation by removing our forward override and using accurate batch stats.