8000 [SDK] Use torchrun to create PyTorchJob from function (kubeflow/train… · szaher/sdk@c6f7a83 · GitHub
[go: up one dir, main page]

Skip to content

Commit c6f7a83

Browse files
andreyvelich authored and szaher committed
[SDK] Use torchrun to create PyTorchJob from function (kubeflow/trainer#2276)
* [SDK] Use torchrun to create PyTorchJob from function Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> * Update PyTorchJob SDK example Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> * Add consts for entrypoint Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> * Add check for num procs per worker Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> --------- Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
1 parent 122f285 commit c6f7a83

File tree

4 files changed

+236
-129
lines changed

4 files changed

+236
-129
lines changed

python/kubeflow/training/api/training_client.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def train(
128128
namespace: Namespace for the PyTorchJob. By default namespace is taken from
129129
`TrainingClient` object.
130130
num_workers: Number of PyTorchJob workers.
131-
num_procs_per_worker: Number of processes per PyTorchJob worker for `torchrun` CLI.
132-
You can use this parameter if you want to use more than 1 GPU per PyTorchJob worker.
131+
num_procs_per_worker: Number of processes per PyTorchJob worker for `torchrun` CLI. You
132+
should use this parameter if you want to use more than 1 GPU per PyTorchJob worker.
133133
resources_per_worker: A parameter that lets you specify how much
134134
resources each PyTorchJob worker container should have. You can either specify a
135135
kubernetes.client.V1ResourceRequirements object (documented here:
@@ -322,7 +322,8 @@ def create_job(
322322
base_image: Optional[str] = None,
323323
train_func: Optional[Callable] = None,
324324
parameters: Optional[Dict[str, Any]] = None,
325-
num_workers: Optional[int] = None,
325+
num_workers: Optional[int] = 1,
326+
num_procs_per_worker: Optional[Union[int, str]] = None,
326327
resources_per_worker: Union[dict, models.V1ResourceRequirements, None] = None,
327328
num_chief_replicas: Optional[int] = None,
328329
num_ps_replicas: Optional[int] = None,
@@ -355,6 +356,9 @@ def create_job(
355356
set, Base Image must support `bash` CLI to execute the training script.
356357
parameters: Dict of input parameters that training function might receive.
357358
num_workers: Number of Worker replicas for the Job.
359+
num_procs_per_worker: Number of processes per PyTorchJob worker for `torchrun` CLI. You
360+
should use this parameter if you want to use more than 1 GPU per PyTorchJob worker.
361+
Set to "auto" to automatically use available GPU/CPU PyTorch resources.
358362
resources_per_worker: A parameter that lets you specify how much
359363
resources each Worker container should have. You can either specify a
360364
kubernetes.client.V1ResourceRequirements object (documented here:
@@ -393,7 +397,8 @@ def create_job(
393397
if job is not None:
394398
for key, value in locals().items():
395399
if (
396-
key not in ["self", "job", "namespace", "pip_index_url"]
400+
key
401+
not in ["self", "job", "namespace", "pip_index_url", "num_workers"]
397402
and value is not None
398403
):
399404
raise ValueError(
@@ -419,19 +424,44 @@ def create_job(
419424
"Job name must be set to configure Job from function or image"
420425
)
421426

427+
# Check if at least one Worker is set.
428+
# TODO (andreyvelich): Remove this check once we have CEL validation.
429+
# Ref: https://github.com/kubeflow/training-operator/issues/1708
430+
if num_workers is None or num_workers < 1:
431+
raise ValueError(f"At least one Worker for {job_kind} must be set")
432+
422433
# Assign the default base image.
423434
# TODO (andreyvelich): Add base image for other Job kinds.
424435
if base_image is None:
425436
base_image = constants.JOB_PARAMETERS[job_kind]["base_image"]
426437

438+
# By default we don't set command and args for the training container.
439+
command, args = None, None
440+
441+
# If training function is set get the command and args.
442+
if train_func is not None:
443+
# Use `torchrun` for distributed PyTorch training, otherwise use `python`
444+
if job_kind == constants.PYTORCHJOB_KIND and (
445+
num_workers > 1 or num_procs_per_worker is not None
446+
):
447+
entrypoint = constants.ENTRYPOINT_TORCH
448+
else:
449+
entrypoint = constants.ENTRYPOINT_PYTHON
450+
451+
command, args = utils.get_command_using_train_func(
452+
train_func=train_func,
453+
entrypoint=entrypoint,
454+
train_func_parameters=parameters,
455+
packages_to_install=packages_to_install,
456+
pip_index_url=pip_index_url,
457+
)
458+
427459
# Get Training Container template.
428460
container_spec = utils.get_container_spec(
429461
name=constants.JOB_PARAMETERS[job_kind]["container"],
430462
base_image=base_image,
431-
train_func=train_func,
432-
train_func_parameters=parameters,
433-
packages_to_install=packages_to_install,
434-
pip_index_url=pip_index_url,
463+
command=command,
464+
args=args,
435465
resources=resources_per_worker,
436466
)
437467

@@ -443,6 +473,10 @@ def create_job(
443473
# Configure template for different Jobs.
444474
# TODO (andreyvelich): Add support for other kinds (e.g. MPIJob).
445475
if job_kind == constants.TFJOB_KIND:
476+
if num_procs_per_worker is not None:
477+
raise ValueError(
478+
f"num_procs_per_worker can't be set for {constants.TFJOB_KIND}"
479+
)
446480
job = utils.get_tfjob_template(
447481
name=name,
448482
namespace=namespace,
@@ -451,12 +485,18 @@ def create_job(
451485
num_chief_replicas=num_chief_replicas,
452486
num_ps_replicas=num_ps_replicas,
453487
)
454-
elif job_kind == constants.PYTORCHJOB_KIND and num_workers:
488+
elif job_kind == constants.PYTORCHJOB_KIND:
489+
if num_chief_replicas is not None or num_ps_replicas is not None:
490+
raise ValueError(
491+
"num_chief_replicas and num_ps_replicas can't be set for "
492+
f"{constants.PYTORCHJOB_KIND}"
493+
)
455494
job = utils.get_pytorchjob_template(
456495
name=name,
457496
namespace=namespace,
458497
worker_pod_template_spec=pod_template_spec,
459498
num_workers=num_workers,
499+
num_procs_per_worker=num_procs_per_worker,
460500
)
461501
else:
462502
raise ValueError(

0 commit comments

Comments
 (0)
0