|
16 | 16 | import logging
|
17 | 17 | import multiprocessing
|
18 | 18 | import time
|
19 |
| -from typing import Any, Callable, Dict, List, Optional, Union |
| 19 | +from typing import Any, Callable, Dict, List, Optional, TypedDict, Union |
20 | 20 |
|
21 | 21 | import grpc
|
22 | 22 | import kubeflow.katib.katib_api_pb2 as katib_api_pb2
|
|
42 | 42 |
|
43 | 43 | logger = logging.getLogger(__name__)
|
44 | 44 |
|
class TuneStoragePerTrialType(TypedDict):
    """One storage entry accepted by the `storage_per_trial` argument of `tune()`.

    Class syntax is used so the runtime type name always matches the public
    alias `TuneStoragePerTrialType` (the functional form previously declared
    the mismatching name "TuneStoragePerTrial", which static type checkers
    reject).
    """

    # Either a ready-made Kubernetes volume, or a dict describing one
    # (with at least the "name" and "type" keys).
    volume: Union[client.V1Volume, Dict[str, Any]]
    # Either the path (str) where the volume is mounted in the Trial
    # container, or a fully specified Kubernetes volume mount.
    mount_path: Union[str, client.V1VolumeMount]
| 52 | + |
45 | 53 |
|
46 | 54 | class KatibClient(object):
|
47 | 55 | def __init__(
|
@@ -198,6 +206,7 @@ def tune(
|
198 | 206 | env_per_trial: Optional[
|
199 | 207 | Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
|
200 | 208 | ] = None,
|
| 209 | + storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None, |
201 | 210 | algorithm_name: str = "random",
|
202 | 211 | algorithm_settings: Union[
|
203 | 212 | dict, List[models.V1beta1AlgorithmSetting], None
|
@@ -288,6 +297,21 @@ class name in this argument.
|
288 | 297 | https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
|
289 | 298 | or a kubernetes.client.models.V1EnvFromSource (documented here:
|
290 | 299 | https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
|
| 300 | + storage_per_trial: List of storage configurations for each trial container. |
| 301 | + Each element in the list should be a dictionary with two keys: |
| 302 | + - volume: Either a kubernetes.client.V1Volume object or a dictionary |
| 303 | + containing volume configuration with required fields: |
| 304 | + - name: Name of the volume |
| 305 | + - type: One of "pvc", "secret", "config_map", or "empty_dir" |
| 306 | + Additional fields based on volume type: |
| 307 | + - For pvc: claim_name, read_only (optional) |
| 308 | + - For secret: secret_name, items (optional), default_mode (optional), |
| 309 | + optional (optional) |
| 310 | + - For config_map: config_map_name, items (optional), default_mode |
| 311 | + (optional), optional (optional) |
| 312 | + - For empty_dir: medium (optional), size_limit (optional) |
| 313 | + - mount_path: Either a kubernetes.client.V1VolumeMount object or a string |
| 314 | + specifying the path where the volume should be mounted in the container |
291 | 315 | algorithm_name: Search algorithm for the HyperParameter tuning.
|
292 | 316 | algorithm_settings: Settings for the search algorithm given.
|
293 | 317 | For available fields, check this doc:
|
@@ -503,6 +527,134 @@ class name in this argument.
|
503 | 527 | container_spec.env = env if env else None
|
504 | 528 | container_spec.env_from = env_from if env_from else None
|
505 | 529 |
|
# Translate `storage_per_trial` into the Kubernetes volumes attached to the
# Trial pod (`volumes`) and the mounts added to the Trial container
# (`volume_mounts`). Each entry contributes exactly one volume and one mount.
volumes: List[client.V1Volume] = []
volume_mounts: List[client.V1VolumeMount] = []
if storage_per_trial:
    # Accept a single storage dict as shorthand for a one-element list.
    if isinstance(storage_per_trial, dict):
        storage_per_trial = [storage_per_trial]
    for storage in storage_per_trial:
        # Library code must not write to stdout; keep the trace at debug level.
        logger.debug("storage_per_trial entry: %s", storage)

        # `.get` so that a missing key falls through to the ValueError
        # branches below instead of surfacing as a bare KeyError.
        volume_config = storage.get("volume")
        mount_config = storage.get("mount_path")

        if isinstance(volume_config, client.V1Volume):
            volume = volume_config
        elif isinstance(volume_config, dict):
            volume_name = volume_config.get("name")
            volume_type = volume_config.get("type")

            if not volume_name:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'name' key"
                )
            if not volume_type:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'type' key"
                )

            if volume_type == "pvc":
                volume_claim_name = volume_config.get("claim_name")
                if not volume_claim_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'claim_name' key for type pvc"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                        claim_name=volume_claim_name,
                        read_only=volume_config.get("read_only", False),
                    ),
                )
            elif volume_type == "secret":
                secret_name = volume_config.get("secret_name")
                # Validate up front, like 'claim_name' for pvc, rather than
                # letting the API server reject the Experiment later.
                if not secret_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'secret_name' key for type secret"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    secret=client.V1SecretVolumeSource(
                        secret_name=secret_name,
                        items=volume_config.get("items"),
                        default_mode=volume_config.get("default_mode"),
                        optional=volume_config.get("optional", False),
                    ),
                )
            elif volume_type == "config_map":
                config_map_name = volume_config.get("config_map_name")
                if not config_map_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'config_map_name' key for type config_map"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    config_map=client.V1ConfigMapVolumeSource(
                        name=config_map_name,
                        # None (not []) when unset, consistent with the
                        # secret branch; both mean "project every key".
                        items=volume_config.get("items"),
                        default_mode=volume_config.get("default_mode"),
                        optional=volume_config.get("optional", False),
                    ),
                )
            elif volume_type == "empty_dir":
                volume = client.V1Volume(
                    name=volume_name,
                    empty_dir=client.V1EmptyDirVolumeSource(
                        medium=volume_config.get("medium"),
                        size_limit=volume_config.get("size_limit"),
                    ),
                )
            else:
                raise ValueError(
                    f"storage_per_trial['volume'] has unsupported type "
                    f"'{volume_type}'; it must be one of 'pvc', 'secret', "
                    f"'config_map' or 'empty_dir'"
                )
        else:
            raise ValueError(
                "storage_per_trial['volume'] must be a client.V1Volume or a dict"
            )

        volumes.append(volume)

        if isinstance(mount_config, client.V1VolumeMount):
            volume_mounts.append(mount_config)
        elif isinstance(mount_config, str):
            # Take the name from the resolved volume itself so this also
            # works when the caller passed a ready-made V1Volume (no
            # `volume_name` local exists on that path).
            volume_mounts.append(
                client.V1VolumeMount(name=volume.name, mount_path=mount_config)
            )
        else:
            raise ValueError(
                "storage_per_trial['mount_path'] must be a "
                "client.V1VolumeMount or a str"
            )
| 626 | + |
| 627 | + # Create Trial specification. |
| 628 | + trial_spec = client.V1Job( |
| 629 | + api_version="batch/v1", |
| 630 | + kind="Job", |
| 631 | + spec=client.V1JobSpec( |
| 632 | + template=client.V1PodTemplateSpec( |
| 633 | + metadata=models.V1ObjectMeta( |
| 634 | + annotations={"sidecar.istio.io/inject": "false"} |
| 635 | + ), |
| 636 | + spec=client.V1PodSpec( |
| 637 | + restart_policy="Never", |
| 638 | + containers=[ |
| 639 | + client.V1Container( |
| 640 | + name=constants.DEFAULT_PRIMARY_CONTAINER_NAME, |
| 641 | + image=base_image, |
| 642 | + command=["bash", "-c"], |
| 643 | + args=[exec_script], |
| 644 | + env=env if env else None, |
| 645 | + env_from=env_from if env_from else None, |
| 646 | + resources=resources_per_trial, |
| 647 | + volume_mounts=( |
| 648 | + volume_mounts if volume_mounts else None |
| 649 | + ), |
| 650 | + ) |
| 651 | + ], |
| 652 | + volumes=volumes if volumes else None, |
| 653 | + ), |
| 654 | + ) |
| 655 | + ), |
| 656 | + ) |
| 657 | + |
506 | 658 | # Trial uses PyTorchJob for distributed training if TrainerResources is set.
|
507 | 659 | if isinstance(resources_per_trial, TrainerResources):
|
508 | 660 | trial_template = utils.get_trial_template_with_pytorchjob(
|
@@ -584,7 +736,7 @@ class name in this argument.
|
584 | 736 | f"It must also start and end with an alphanumeric character."
|
585 | 737 | )
|
586 | 738 | elif hasattr(e, "status") and e.status == 409:
|
587 |
| - print(f"PVC '{name}' already exists in namespace " f"{namespace}.") |
| 739 | + print(f"PVC '{name}' already exists in namespace {namespace}.") |
588 | 740 | else:
|
589 | 741 | raise RuntimeError(f"failed to create PVC. Error: {e}")
|
590 | 742 |
|
|
0 commit comments