fix(torchrun): Omit empty arguments and correct nproc_per_node type#661
RobotSail merged 20 commits into instructlab:main
-    nproc_per_node: int
+    nproc_per_node: str
Did you mean to make this change?
# build args for this file. Ignore empty or unset values except int values
for key, value in train_args.model_dump(exclude_none=True).items():
    # avoid ignoring int attrs with value = 0
    if not isinstance(value, int) and (not value or value == ""):
How would this handle booleans?
I have updated this one to only check for string types.
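The boolean question can be checked directly. In Python, `bool` is a subclass of `int`, so the original condition never skips `True` or `False`, while it does skip other falsy non-int values such as empty lists. The sketch below compares the condition from the diff with a string-only check; `should_skip_strings_only` and the sample keys `some_flag` and `extra_list` are hypothetical illustrations, not the code that was merged.

```python
def should_skip_original(value):
    # Condition as written in the diff under review.
    return not isinstance(value, int) and (not value or value == "")


def should_skip_strings_only(value):
    # Hypothetical revision: skip a value only when it is an empty string.
    return isinstance(value, str) and value == ""


# Sample values; "some_flag" and "extra_list" are made-up keys for illustration.
samples = {
    "rdzv_endpoint": "",   # empty string
    "node_rank": 0,        # int zero, must be kept
    "some_flag": False,    # bool is a subclass of int
    "extra_list": [],      # falsy, but neither int nor str
}

for key, value in samples.items():
    print(key, should_skip_original(value), should_skip_strings_only(value))
```

With the string-only check, only the empty `rdzv_endpoint` is dropped; every other value, including `False` and the empty list, is passed through for the caller to handle.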
    # avoid ignoring int attrs with value = 0
    if not isinstance(value, int) and (not value or value == ""):
        continue
    command.append(f"--{key}={value}")
Have you verified that all of our CLI arguments are perfectly 1:1 with the variable names we're using here?
I have updated this one to only process torchrun args and leave the scripts args as they're not perfectly 1:1 mapped.
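A minimal sketch of that narrower approach, assuming the torchrun arguments arrive as a plain dict (for example, the result of `model_dump()`): only those keys are turned into `--key=value` flags, unset or empty-string values are omitted, and `0` is kept. The helper name `build_torchrun_flags` is hypothetical and is not taken from the PR.

```python
def build_torchrun_flags(torchrun_args):
    """Turn a dict of torchrun arguments into a list of --key=value flags.

    Hypothetical helper: None and empty strings are omitted so that torchrun
    can fall back to its environment variables; 0 is a valid value and is kept.
    """
    flags = []
    for key, value in torchrun_args.items():
        if value is None:
            continue
        if isinstance(value, str) and value == "":
            continue
        flags.append(f"--{key}={value}")
    return flags


torchrun_args = {
    "nnodes": 1,
    "node_rank": 0,
    "rdzv_id": 0,
    "rdzv_endpoint": "",
    "nproc_per_node": "auto",
}
command = ["torchrun"] + build_torchrun_flags(torchrun_args)
print(command)
```

The training script's own arguments would still be appended separately, since their CLI names do not map 1:1 to the variable names.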
    # this will tell the model construct to ignore
    # extra arguments that aren't part of this model
    class Config:
        extra = "ignore"
@szaher Do you know when this would be the case? If our goal here is to dynamically build the torchrun command using the defined interface, this seems like it now opens the floor up for users to pass invalid arguments through torchrun. This means that any incorrect interface usage wouldn't be detected until runtime.
In fact, this drops any additional arguments that are provided and keeps only the torchrun ones:
torchrun_defaults = {
'nnodes': 1, 'node_rank': 0, 'rdzv_id': 0, 'rdzv_endpoint': '',
'nproc_per_node': 2, "fake_arg": "what"
}
y = TorchrunArgs(**torchrun_defaults)
print(y)
TorchrunArgs(nproc_per_node=2, nnodes=1, node_rank=0, rdzv_id=0, rdzv_endpoint='')
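The trade-off raised in this thread (silently dropping unknown keys versus rejecting them up front) can be illustrated without pydantic. The sketch below uses a standard-library dataclass as a stand-in for the model; `TorchrunArgsSketch`, `from_dict_ignoring_extras`, and `from_dict_strict` are hypothetical names, and this is not how the project implements it.

```python
from dataclasses import dataclass, fields


@dataclass
class TorchrunArgsSketch:
    # Stand-in for the pydantic model discussed above (assumed field set).
    nproc_per_node: str
    nnodes: int
    node_rank: int
    rdzv_id: int
    rdzv_endpoint: str


def from_dict_ignoring_extras(raw):
    # Mirrors the "ignore" behaviour: unknown keys are dropped silently.
    known = {f.name for f in fields(TorchrunArgsSketch)}
    return TorchrunArgsSketch(**{k: v for k, v in raw.items() if k in known})


def from_dict_strict(raw):
    # Alternative: fail early when a key is not part of the interface.
    known = {f.name for f in fields(TorchrunArgsSketch)}
    unknown = sorted(set(raw) - known)
    if unknown:
        raise ValueError(f"unexpected torchrun arguments: {unknown}")
    return TorchrunArgsSketch(**raw)


raw = {
    "nnodes": 1, "node_rank": 0, "rdzv_id": 0, "rdzv_endpoint": "",
    "nproc_per_node": "2", "fake_arg": "what",
}
print(from_dict_ignoring_extras(raw))  # fake_arg is dropped
```

Calling `from_dict_strict(raw)` with the same input raises `ValueError`, which surfaces a misspelled or unsupported argument before any command is built.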
RobotSail
left a comment
A few comments but otherwise great work! LGTM !!
@szaher Looks like you will also need to rebase this PR.
The command generation logic is updated to dynamically build the torchrun command, excluding arguments that are empty or None. This prevents them from overriding environment variables, ensuring that torchrun can correctly inherit its configuration. An exception is made for integer arguments where 0 is a valid value.

Additionally, the nproc_per_node argument type has been changed from int to str to support special values accepted by PyTorch, such as 'auto', 'gpu', and 'cpu'.

Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88

Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Oleg Silkin <[email protected]>
ffff971 to 27ff594
LGTM, will merge once tests pass.
…oint are provided Signed-off-by: Oleg Silkin <[email protected]>
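Because `nproc_per_node` is now a string, a caller may want to normalize it before building the command. The sketch below is a hypothetical helper, not part of this PR: it accepts a positive integer (as `int` or `str`) or one of the three special values named in the description. torchrun itself may accept other values, so treat the allowed set as an assumption.

```python
# Special values named in the PR description; torchrun may accept others.
SPECIAL_VALUES = {"auto", "gpu", "cpu"}


def normalize_nproc_per_node(value):
    """Return nproc_per_node as a string, or raise ValueError if unsupported."""
    text = str(value).strip()
    if text in SPECIAL_VALUES:
        return text
    try:
        count = int(text)
    except ValueError:
        raise ValueError(f"unsupported nproc_per_node value: {value!r}") from None
    if count < 1:
        raise ValueError(f"nproc_per_node must be at least 1, got {count}")
    return str(count)


print(normalize_nproc_per_node(2))
print(normalize_nproc_per_node("auto"))
```

Keeping the result as a string means the same field can carry both `"2"` and `"gpu"` through to the `--nproc_per_node=...` flag without a type change.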