Expose AdamW optimizer parameters in training API #674
Conversation
Walkthrough
This PR adds configurable AdamW optimizer hyperparameters across the training stack: new TrainingArgs fields, new CLI flags in the training launcher, and extended optimizer setup to accept and apply weight_decay, betas, and eps (with defaults).
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~15 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py (2 hunks)
- src/instructlab/training/main_ds.py (3 hunks)
- src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: pylint
🔇 Additional comments (7)
src/instructlab/training/config.py (2)
9-9: LGTM: Import addition is correct. The `Tuple` import is necessary for the type annotation of the `adamw_betas` field added below.
213-224: Default `beta2=0.95` for AdamW is appropriate for LLM training. The configuration uses `beta2=0.95`, which differs from PyTorch's default of `0.999` but aligns with standard practices for LLM training. Llama 3 and Llama 3.2 training configurations officially use this same value for improved training stability. No changes needed.
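For orientation, a minimal sketch of what fields like these could look like; the field names and defaults come from this review, while the descriptions and exact `Field()` layout are assumptions rather than the PR's actual config.py code:

```python
# Minimal sketch, not the PR's code: field names and defaults follow this
# review; the descriptions and exact Field() layout are assumptions.
from typing import Tuple

from pydantic import BaseModel, Field

class TrainingArgs(BaseModel):
    adamw_weight_decay: float = Field(
        default=0.0, description="Weight decay (L2 penalty) applied by AdamW."
    )
    adamw_betas: Tuple[float, float] = Field(
        default=(0.9, 0.95), description="Coefficients for AdamW's running averages."
    )
    adamw_eps: float = Field(
        default=1e-8, description="Term added to the denominator for numerical stability."
    )
```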
src/instructlab/training/main_ds.py (3)

422-424: LGTM: AdamW parameters correctly wired to optimizer setup. The new AdamW hyperparameters are properly extracted from the parsed arguments and passed to `setup_optimizer()`. The tuple unpacking of beta1 and beta2 is correct.
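A hedged sketch of what that wiring can look like; the real `setup_optimizer()` in model.py is replaced by a local stand-in so the sketch stays self-contained, and all argument names are assumptions based on this review:

```python
# Hypothetical wiring of the parsed CLI arguments into the optimizer; not the
# PR's code at lines 422-424.
import argparse

import torch

def build_optimizer_from_args(args: argparse.Namespace, model: torch.nn.Module):
    # Stand-in for the real setup_optimizer() call in main_ds.py.
    return torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adamw_beta1, args.adamw_beta2),  # betas built from two CLI flags
        weight_decay=args.adamw_weight_decay,
        eps=args.adamw_eps,
    )
```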
532-535: LGTM: Command builder correctly propagates AdamW hyperparameters. The command construction properly extracts `beta1` and `beta2` from the `adamw_betas` tuple and formats all four AdamW parameters as CLI arguments for the subprocess.
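A rough sketch of how such flag propagation can look; the flag spellings and attribute names are assumptions modeled on the `TrainingArgs` field names, not the literal code at lines 532-535:

```python
# Hypothetical fragment of the command builder; names are assumed.
def append_adamw_flags(command: list, train_args) -> list:
    beta1, beta2 = train_args.adamw_betas  # decompose the tuple into two flags
    command.extend([
        f"--adamw_weight_decay={train_args.adamw_weight_decay}",
        f"--adamw_beta1={beta1}",
        f"--adamw_beta2={beta2}",
        f"--adamw_eps={train_args.adamw_eps}",
    ])
    return command
```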
827-850: LGTM: CLI arguments properly defined with consistent defaults. The new CLI arguments are well-documented with help text and defaults that match the `TrainingArgs` configuration in `config.py`.
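As a point of reference, hedged argparse definitions in the spirit of what the review describes; the exact flag names, help strings, and defaults in main_ds.py may differ:

```python
# Hypothetical argparse definitions; names, help text, and defaults are assumed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adamw_weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 penalty) for the AdamW optimizer.")
parser.add_argument("--adamw_beta1", type=float, default=0.9,
                    help="First exponential moving-average coefficient for AdamW.")
parser.add_argument("--adamw_beta2", type=float, default=0.95,
                    help="Second exponential moving-average coefficient for AdamW.")
parser.add_argument("--adamw_eps", type=float, default=1e-8,
                    help="Epsilon added to the denominator for numerical stability.")
```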
src/instructlab/training/model.py (2)

515-516: LGTM: Function signature properly extended. The new `weight_decay` and `eps` parameters are correctly added with appropriate defaults that match the `TrainingArgs` configuration.
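A hedged illustration of such an extended signature; only `weight_decay` and `eps` (and their defaults) come from this review, the remaining parameters and the body are stand-ins:

```python
# Illustrative stand-in for the extended setup_optimizer(); not the PR's code.
from typing import Tuple

import torch

def setup_optimizer(
    model: torch.nn.Module,
    learning_rate: float = 1e-5,
    betas: Tuple[float, float] = (0.9, 0.95),
    weight_decay: float = 0.0,   # new parameter
    eps: float = 1e-8,           # new parameter
) -> torch.optim.Optimizer:
    return torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=betas,
        weight_decay=weight_decay,
        eps=eps,
    )
```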
526-527: LGTM: Docstring updated to document new parameters. The documentation correctly describes the new `weight_decay` and `eps` parameters.
Add support for configuring weight_decay, betas, and eps parameters for the AdamW optimizer through TrainingArgs, allowing users to tune these hyperparameters when calling run_training().

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>

force-pushed from 32fc10d to 770b8c2
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/instructlab/training/config.py (1)
213-224: Consider adding field validators for optimizer parameters. The new TrainingArgs fields are well-structured with clear descriptions and consistent defaults. However, consider adding validators to ensure:

- `adamw_betas` values are in the range (0, 1)
- `adamw_eps` is positive
- `adamw_weight_decay` is non-negative

Additionally, the same concern applies here: the default `adamw_betas=(0.9, 0.95)` uses beta2=0.95 instead of the PyTorch standard of 0.999.

Example validator:
```python
from pydantic import field_validator

@field_validator('adamw_betas')
def validate_betas(cls, v):
    if not (0 < v[0] < 1 and 0 < v[1] < 1):
        raise ValueError('Beta values must be in range (0, 1)')
    return v

@field_validator('adamw_eps')
def validate_eps(cls, v):
    if v <= 0:
        raise ValueError('Epsilon must be positive')
    return v

@field_validator('adamw_weight_decay')
def validate_weight_decay(cls, v):
    if v < 0:
        raise ValueError('Weight decay must be non-negative')
    return v
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py (2 hunks)
- src/instructlab/training/main_ds.py (3 hunks)
- src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: pylint
- GitHub Check: Summary
🔇 Additional comments (6)
src/instructlab/training/model.py (2)
509-527: LGTM! Past review comment addressed. The function signature now includes `weight_decay` and `eps` parameters, enabling AdamW optimizer configuration. The defaults (`weight_decay=0.0`, `eps=1e-8`) are conservative and appropriate.
563-571: Optimizer factory correctly passes consistent parameters across all three optimizer types. The uniform application of `weight_decay` and `eps` to AdamW, FusedAdam, and DeepSpeedCPUAdam via `functools.partial` is valid; all three optimizers support these parameters with compatible signatures.
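A compact sketch of the pattern being praised: one `functools.partial` carrying `weight_decay` and `eps` regardless of which optimizer class is selected. The selection logic and names below are assumptions, not the code at lines 563-571; the DeepSpeed classes are left commented out so the sketch runs with plain PyTorch:

```python
# Illustrative optimizer factory via functools.partial; not model.py's code.
import functools

import torch

def make_optimizer_partial(name, lr, betas, weight_decay, eps):
    optimizer_classes = {
        "adamw": torch.optim.AdamW,
        # "fused_adam": deepspeed.ops.adam.FusedAdam,
        # "cpu_adam": deepspeed.ops.adam.DeepSpeedCPUAdam,
    }
    # All three classes accept lr, betas, weight_decay, and eps, so the same
    # keyword arguments can be bound once and reused for whichever is chosen.
    return functools.partial(
        optimizer_classes[name],
        lr=lr,
        betas=betas,
        weight_decay=weight_decay,
        eps=eps,
    )

# Bind the hyperparameters first, supply the model parameters later.
factory = make_optimizer_partial("adamw", 1e-5, (0.9, 0.95), 0.0, 1e-8)
optimizer = factory(torch.nn.Linear(4, 4).parameters())
```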
src/instructlab/training/main_ds.py (3)

417-425: LGTM! AdamW parameters correctly wired. The optimizer setup correctly passes the AdamW hyperparameters, with betas properly constructed from the individual beta1 and beta2 CLI arguments.
532-536: LGTM! CLI arguments correctly constructed. The torchrun command builder correctly includes all AdamW parameters, properly decomposing the betas tuple into individual beta1 and beta2 arguments.
827-850: The `adamw_*` parameters use non-standard defaults that reflect language-model training practices rather than PyTorch defaults. The beta2=0.95 choice (vs. PyTorch's 0.999) and weight_decay=0.0 align with modern LLM training best practices used in frameworks like Composer and recent research. Consider adding documentation explaining why these LLM-optimized defaults were chosen, if this choice isn't already documented elsewhere in the codebase.

src/instructlab/training/config.py (1)
9-9: LGTM! Necessary import for tuple type. Added the `Tuple` import to support the `adamw_betas` field type annotation.
Summary
- New `TrainingArgs` fields: `adamw_weight_decay`, `adamw_betas`, `adamw_eps`, configurable when calling `run_training()`

Changes
- `config.py`: Added 3 new fields to `TrainingArgs`
- `model.py`: Updated `setup_optimizer()` to accept `weight_decay` and `eps` parameters
- `main_ds.py`: Added CLI arguments and wired them through the command builder

Usage
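The original usage snippet did not survive extraction; below is a hedged, self-contained illustration of what the three new knobs control. It uses torch directly rather than the library's `run_training()` pipeline, and the values shown are the PR's defaults, not a recommendation:

```python
# Hedged illustration only: the adamw_* names refer to the new TrainingArgs
# fields, but this snippet exercises torch.optim.AdamW directly.
import torch

adamw_kwargs = {
    "weight_decay": 0.0,   # TrainingArgs.adamw_weight_decay
    "betas": (0.9, 0.95),  # TrainingArgs.adamw_betas
    "eps": 1e-8,           # TrainingArgs.adamw_eps
}

model = torch.nn.Linear(16, 16)  # toy stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, **adamw_kwargs)
print(optimizer.defaults["betas"], optimizer.defaults["eps"])
```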
Test plan
- `TrainingArgs` imports and instantiates correctly with new fields
- `setup_optimizer()` signature has new parameters
- `--help` shows new arguments

🤖 Generated with Claude Code