4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -29,7 +29,7 @@ Specific programmatically generated files listed in the `exclude` field in [.pre
To check formatting:

```shell
-make verify
+make verify
```

## Testing
@@ -73,4 +73,4 @@ For any significant features or enhancement for Kubeflow SDK project we follow the
[Kubeflow Enhancement Proposal process](https://github.com/kubeflow/community/tree/master/proposals).

If you want to submit a significant change to the Kubeflow Trainer, please submit a new KEP under
-[./docs/proposals](./docs/proposals/) directory.
+[./docs/proposals](./docs/proposals/) directory.
2 changes: 1 addition & 1 deletion docs/proposals/2-trainer-local-execution/README.md
@@ -27,7 +27,7 @@ The proposed local execution mode will allow engineers to quickly test their mod

## Proposal

-The local execution mode will allow users to run training jobs in container runtime environment on their local machines, mimicking the larger Kubeflow setup but without requiring Kubernetes.
+The local execution mode will allow users to run training jobs in container runtime environment on their local machines, mimicking the larger Kubeflow setup but without requiring Kubernetes.

![Architecture Diagram](high-level-arch.svg)

2 changes: 2 additions & 0 deletions kubeflow/trainer/__init__.py
@@ -27,6 +27,7 @@
from kubeflow.trainer.types.types import (
    BuiltinTrainer,
    CustomTrainer,
+    DataCacheInitializer,
    DataFormat,
    DataType,
    HuggingFaceDatasetInitializer,
@@ -44,6 +45,7 @@
__all__ = [
    "BuiltinTrainer",
    "CustomTrainer",
+    "DataCacheInitializer",
    "DataFormat",
    "DATASET_PATH",
    "DataType",
56 changes: 52 additions & 4 deletions kubeflow/trainer/types/types.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

from kubeflow.trainer.constants import constants

@@ -258,10 +258,58 @@ class TrainJob:
# TODO (andreyvelich): Discuss how to keep these configurations is sync with pkg.initializers.types
@dataclass
class HuggingFaceDatasetInitializer:
+    """Configuration for downloading datasets from HuggingFace Hub."""
+
    storage_uri: str
    access_token: Optional[str] = None


+@dataclass
+class DataCacheInitializer:
+    """Configuration for a distributed data caching system for training workloads.
+
+    Args:
+        storage_uri (`str`): The URI for the cached data in the format
+            'cache://<SCHEMA_NAME>/<TABLE_NAME>'. This specifies the location
+            where the data cache will be stored and accessed.
+        metadata_loc (`str`): The metadata file path of an Iceberg table.
+        num_data_nodes (`int`): The number of data nodes in the distributed cache
+            system. Must be greater than 1.
+        head_cpu (`Optional[str]`): The CPU resources to allocate for the cache head node.
+        head_mem (`Optional[str]`): The memory resources to allocate for the cache head node.
+        worker_cpu (`Optional[str]`): The CPU resources to allocate for each cache worker node.
+        worker_mem (`Optional[str]`): The memory resources to allocate for each cache worker node.
+        iam_role (`Optional[str]`): The IAM role to use for accessing the metadata_loc file.
+    """
+
+    storage_uri: str
+    metadata_loc: str
+    num_data_nodes: int
+    head_cpu: Optional[str] = None
+    head_mem: Optional[str] = None
+    worker_cpu: Optional[str] = None
+    worker_mem: Optional[str] = None
+    iam_role: Optional[str] = None
+
+    def __post_init__(self):
+        """Validate DataCacheInitializer parameters."""
+        if self.num_data_nodes <= 1:
+            raise ValueError(f"num_data_nodes must be greater than 1, got {self.num_data_nodes}")
+
+        # Validate the storage_uri format
+        if not self.storage_uri.startswith("cache://"):
+            raise ValueError(f"storage_uri must start with 'cache://', got {self.storage_uri}")
+
+        uri_path = self.storage_uri[len("cache://") :]
+        parts = uri_path.split("/")
+
+        if len(parts) != 2:
+            raise ValueError(
+                f"storage_uri must be in format "
+                f"'cache://<SCHEMA_NAME>/<TABLE_NAME>', got {self.storage_uri}"
+            )
+
+
# Configuration for the HuggingFace model initializer.
@dataclass
class HuggingFaceModelInitializer:
@@ -274,11 +322,11 @@ class Initializer:
"""Initializer defines configurations for dataset and pre-trained model initialization

Args:
dataset (`Optional[HuggingFaceDatasetInitializer]`): The configuration for one of the
supported dataset initializers.
dataset (`Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]]`):
The configuration for one of the supported dataset initializers.
model (`Optional[HuggingFaceModelInitializer]`): The configuration for one of the
supported model initializers.
"""

dataset: Optional[HuggingFaceDatasetInitializer] = None
dataset: Optional[Union[HuggingFaceDatasetInitializer, DataCacheInitializer]] = None
model: Optional[HuggingFaceModelInitializer] = None
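
For context, a minimal sketch of how a user might construct the new initializer after this change. The class and field names follow the diff above; the schema, table, and bucket names are hypothetical, and the `cache://<SCHEMA_NAME>/<TABLE_NAME>` format plus the `num_data_nodes > 1` rule are the ones enforced by `__post_init__`:

```python
from kubeflow.trainer.types.types import DataCacheInitializer, Initializer

# Hypothetical cache location and Iceberg metadata file.
dataset = DataCacheInitializer(
    storage_uri="cache://analytics/events",
    metadata_loc="s3://my-bucket/warehouse/metadata.json",
    num_data_nodes=3,  # must be greater than 1
    worker_mem="2Gi",  # optional resource fields may be omitted
)

# A malformed URI such as "cache://analytics" would raise ValueError,
# since exactly one schema segment and one table segment are required.
initializer = Initializer(dataset=dataset)
```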
115 changes: 115 additions & 0 deletions kubeflow/trainer/types/types_test.py
@@ -0,0 +1,115 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase
from kubeflow.trainer.types import types


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid datacacheinitializer creation",
            expected_status=SUCCESS,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_output=None,
        ),
        TestCase(
            name="invalid num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="zero num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 0,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="negative num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": -1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri without cache:// prefix raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "invalid://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri format raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri with too many parts raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table/extra",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
    ],
)
def test_data_cache_initializer(test_case: TestCase):
    """Test DataCacheInitializer creation and validation."""
    print("Executing test:", test_case.name)

    try:
        initializer = types.DataCacheInitializer(
            storage_uri=test_case.config["storage_uri"],
            num_data_nodes=test_case.config["num_data_nodes"],
            metadata_loc=test_case.config["metadata_loc"],
        )

        assert test_case.expected_status == SUCCESS
        # Only check the fields that were passed in config, not auto-generated ones
        for key in test_case.config:
            assert getattr(initializer, key) == test_case.config[key]

    except Exception as e:
        assert test_case.expected_status == FAILED
        assert type(e) is test_case.expected_error
    print("test execution complete")
70 changes: 47 additions & 23 deletions kubeflow/trainer/utils/utils.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from dataclasses import fields
import inspect
import os
import textwrap
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
from urllib.parse import urlparse

from kubeflow_trainer_api import models
@@ -563,34 +564,57 @@ def get_args_from_dataset_preprocess_config(


def get_dataset_initializer(
-    dataset: Optional[types.HuggingFaceDatasetInitializer] = None,
+    dataset: Optional[
+        Union[types.HuggingFaceDatasetInitializer, types.DataCacheInitializer]
+    ] = None,
) -> Optional[models.TrainerV1alpha1DatasetInitializer]:
"""
Get the TrainJob dataset initializer from the given config.
"""
-    if not isinstance(dataset, types.HuggingFaceDatasetInitializer):
-        return None
+    if isinstance(dataset, types.HuggingFaceDatasetInitializer):
+        dataset_initializer = models.TrainerV1alpha1DatasetInitializer(
+            storageUri=(
+                dataset.storage_uri
+                if dataset.storage_uri.startswith("hf://")
+                else "hf://" + dataset.storage_uri
+            ),
+            env=(
+                [
+                    models.IoK8sApiCoreV1EnvVar(
+                        name=constants.INITIALIZER_ENV_ACCESS_TOKEN,
+                        value=dataset.access_token,
+                    ),
+                ]
+                if dataset.access_token
+                else None
+            ),
+        )
+        return dataset_initializer
+    elif isinstance(dataset, types.DataCacheInitializer):
+        # Build env vars from optional model fields
+        envs = []
+
+        # Add CLUSTER_SIZE env var from num_data_nodes required field
+        envs.append(
+            models.IoK8sApiCoreV1EnvVar(name="CLUSTER_SIZE", value=str(dataset.num_data_nodes + 1))
+        )
+
-    # TODO (andreyvelich): Support more parameters.
-    dataset_initializer = models.TrainerV1alpha1DatasetInitializer(
-        storageUri=(
-            dataset.storage_uri
-            if dataset.storage_uri.startswith("hf://")
-            else "hf://" + dataset.storage_uri
-        ),
-        env=(
-            [
-                models.IoK8sApiCoreV1EnvVar(
-                    name=constants.INITIALIZER_ENV_ACCESS_TOKEN,
-                    value=dataset.access_token,
-                ),
-            ]
-            if dataset.access_token
-            else None
-        ),
-    )
+        # Add METADATA_LOC env var from metadata_loc required field
+        envs.append(models.IoK8sApiCoreV1EnvVar(name="METADATA_LOC", value=dataset.metadata_loc))
+
-    return dataset_initializer
+        # Add env vars from optional fields (skip required fields)
+        required_fields = {"storage_uri", "metadata_loc", "num_data_nodes"}
+        for f in fields(dataset):
+            if f.name not in required_fields:
+                value = getattr(dataset, f.name)
+                if value is not None:
+                    envs.append(models.IoK8sApiCoreV1EnvVar(name=f.name.upper(), value=value))
+
+        return models.TrainerV1alpha1DatasetInitializer(
+            storageUri=dataset.storage_uri, env=envs if envs else None
+        )
+    else:
+        return None


def get_model_initializer(
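
To make the new dispatch concrete, here is a rough sketch of the resulting behavior. It relies only on names visible in the diff above; the bucket, schema, and table values are made up:

```python
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils

cache = types.DataCacheInitializer(
    storage_uri="cache://analytics/events",       # hypothetical <SCHEMA_NAME>/<TABLE_NAME>
    metadata_loc="s3://my-bucket/metadata.json",  # hypothetical Iceberg metadata file
    num_data_nodes=3,
    worker_mem="2Gi",
)

dataset_initializer = utils.get_dataset_initializer(cache)
env = {e.name: e.value for e in dataset_initializer.env}

assert env["CLUSTER_SIZE"] == "4"  # num_data_nodes + 1, counting the head node
assert env["METADATA_LOC"] == "s3://my-bucket/metadata.json"
assert env["WORKER_MEM"] == "2Gi"  # optional fields become upper-cased env var names

# Anything that is not a recognized initializer type falls through to None.
assert utils.get_dataset_initializer(None) is None
```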
33 changes: 33 additions & 0 deletions kubeflow/trainer/utils/utils_test.py
@@ -255,3 +255,36 @@ def test_get_command_using_train_func(test_case: TestCase):
    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")


+def test_get_dataset_initializer():
+    """Test get_dataset_initializer uses DataCacheInitializer optional fields as env vars."""
+    datacache_initializer = types.DataCacheInitializer(
+        storage_uri="cache://test_schema/test_table",
+        num_data_nodes=3,
+        metadata_loc="s3://bucket/metadata",
+        head_cpu="1",
+        head_mem="1Gi",
+        worker_cpu="2",
+        worker_mem="2Gi",
+        iam_role="arn:aws:iam::123456789012:role/test-role",
+    )
+
+    dataset_initializer = utils.get_dataset_initializer(datacache_initializer)
+
+    assert dataset_initializer is not None
+    assert dataset_initializer.env is not None
+    env_dict = {env_var.name: env_var.value for env_var in dataset_initializer.env}
+
+    # Check CLUSTER_SIZE is present from num_data_nodes
+    assert env_dict["CLUSTER_SIZE"] == "4"
+
+    # Check METADATA_LOC is present from metadata_loc
+    assert env_dict["METADATA_LOC"] == "s3://bucket/metadata"
+
+    # Check all optional fields are present as uppercase env vars
+    assert env_dict["HEAD_CPU"] == "1"
+    assert env_dict["HEAD_MEM"] == "1Gi"
+    assert env_dict["WORKER_CPU"] == "2"
+    assert env_dict["WORKER_MEM"] == "2Gi"
+    assert env_dict["IAM_ROLE"] == "arn:aws:iam::123456789012:role/test-role"