Fix HYBRID_SHARD failure when world_size < available GPUs #682

RobotSail merged 1 commit into instructlab:main from
When using FSDP with the HYBRID_SHARD sharding strategy, FSDP1 auto-detects num_devices_per_node from torch.cuda.device_count() and then tries to create intra-node process groups of that size. This fails when world_size < num_devices_per_node with:

`ValueError: The arg 'group_size' (8) must not exceed the world size (2)`

This fix detects when HYBRID_SHARD would fail due to this constraint and falls back to FULL_SHARD with a warning, allowing training to proceed on systems where fewer GPUs are in use than are available.

Fixes instructlab#678
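The constraint can be illustrated with a small sketch. This is not torch's internal code; the function name and the simplified grouping logic are illustrative, but the size check mirrors the reported error:

```python
# Illustrative sketch (not torch internals): FSDP1's hybrid sharding carves
# intra-node groups of size num_devices_per_node out of the world_size ranks,
# which is impossible when a group would be larger than the world itself.
def make_intra_node_groups(world_size: int, num_devices_per_node: int) -> list[list[int]]:
    if num_devices_per_node > world_size:
        # Mirrors the error reported above.
        raise ValueError(
            f"The arg 'group_size' ({num_devices_per_node}) must not exceed "
            f"the world size ({world_size})"
        )
    return [
        list(range(start, start + num_devices_per_node))
        for start in range(0, world_size, num_devices_per_node)
    ]

print(make_intra_node_groups(4, 2))  # → [[0, 1], [2, 3]]
# make_intra_node_groups(2, 8) raises:
#   ValueError: The arg 'group_size' (8) must not exceed the world size (2)
```

With world_size=2 and 8 devices per node, no intra-node group of size 8 can be formed, which is exactly the situation the fix guards against.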
## Summary

Fixes FSDP training failure when using the `HYBRID_SHARD` sharding strategy on systems where `world_size` is less than the number of available GPUs per node.

## Problem

When running with `--fsdp_sharding_strategy=HYBRID_SHARD` on a system with 8 GPUs but only using 2 processes (e.g., `nproc_per_node=2`), FSDP fails with `ValueError: The arg 'group_size' (8) must not exceed the world size (2)`. This happens because FSDP1 auto-detects `num_devices_per_node` from `torch.cuda.device_count()` and tries to create intra-node process groups of that size.

## Solution

Detect when `HYBRID_SHARD` would fail due to `world_size < num_devices_per_node` and automatically fall back to `FULL_SHARD` with a warning message.

## Changes

- Added a check in `get_fsdp_config()` to compare `world_size` with `torch.cuda.device_count()`
- Fall back to `FULL_SHARD` when `HYBRID_SHARD` would fail

## Testing

Tested the fix scenario where `world_size=2` and `device_count=8`:

- Before: `ValueError: The arg 'group_size' (8) must not exceed the world size (2)`
- After: falls back to `FULL_SHARD` and a warning is logged

Fixes #678
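The fallback check can be sketched as follows. This is a simplified standalone version, not the actual patch: the helper name and the string-valued strategies are illustrative, and the real change lives in `get_fsdp_config()` using torch's `ShardingStrategy` enum and `torch.cuda.device_count()`:

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical helper mirroring the guard added to get_fsdp_config();
# the function name and string-based strategy values are illustrative.
def resolve_sharding_strategy(requested: str, world_size: int, num_devices_per_node: int) -> str:
    """Fall back to FULL_SHARD when HYBRID_SHARD cannot form its intra-node groups."""
    if requested == "HYBRID_SHARD" and world_size < num_devices_per_node:
        logger.warning(
            "HYBRID_SHARD requires world_size (%d) >= devices per node (%d); "
            "falling back to FULL_SHARD.",
            world_size,
            num_devices_per_node,
        )
        return "FULL_SHARD"
    return requested

# The scenario from the report: 2 processes on an 8-GPU node.
print(resolve_sharding_strategy("HYBRID_SHARD", world_size=2, num_devices_per_node=8))  # → FULL_SHARD
```

Falling back (rather than raising) keeps training runnable on partially used nodes; the warning preserves visibility into the strategy change.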