diff --git a/.flake8 b/.flake8
index 3df6c7579f..6e703f7c4c 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 100
-extend-ignore = W503
+extend-ignore = W503, E203
diff --git a/cmd/initializers/dataset/requirements.txt b/cmd/initializers/dataset/requirements.txt
index fa32ab4069..e5cb789750 100644
--- a/cmd/initializers/dataset/requirements.txt
+++ b/cmd/initializers/dataset/requirements.txt
@@ -1 +1,2 @@
 huggingface-hub>=0.27.0,<0.28
+kubernetes>=27.2.0
diff --git a/pkg/initializers/dataset/__main__.py b/pkg/initializers/dataset/__main__.py
index ecaf21e972..f36c3cdb12 100644
--- a/pkg/initializers/dataset/__main__.py
+++ b/pkg/initializers/dataset/__main__.py
@@ -3,6 +3,7 @@
 from urllib.parse import urlparse
 
 import pkg.initializers.utils.utils as utils
+from pkg.initializers.dataset.cache import CacheInitializer
 from pkg.initializers.dataset.huggingface import HuggingFace
 
 logging.basicConfig(
@@ -27,6 +28,10 @@ def main():
             hf = HuggingFace()
             hf.load_config()
             hf.download_dataset()
+        case utils.CACHE_SCHEME:
+            cache = CacheInitializer()
+            cache.load_config()
+            cache.download_dataset()
         case _:
             logging.error("STORAGE_URI must have the valid dataset provider")
             raise Exception
diff --git a/pkg/initializers/dataset/cache.py b/pkg/initializers/dataset/cache.py
new file mode 100644
index 0000000000..90202545b1
--- /dev/null
+++ b/pkg/initializers/dataset/cache.py
@@ -0,0 +1,285 @@
+import logging
+import time
+
+from kubernetes import client, config
+from kubernetes.client.rest import ApiException
+
+import pkg.initializers.types.types as types
+import pkg.initializers.utils.utils as utils
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+    level=logging.INFO,
+)
+
+
+def get_namespace() -> str:
+    """Get the current namespace from the service account token."""
+    try:
+        with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") as f:
+            return f.readline().strip()
+    except FileNotFoundError:
+        logging.warning(
+            "Service account namespace file not found, using 'default' namespace"
+        )
+        return "default"
+
+
+class CacheInitializer(utils.DatasetProvider):
+
+    def load_config(self):
+        config_dict = utils.get_config_from_env(types.CacheDatasetInitializer)
+        self.config = types.CacheDatasetInitializer(**config_dict)
+
+        # Parse schema_name and table_name from storage_uri.
+        # Format: cache://<schema_name>/<table_name>
+        uri_path = self.config.storage_uri[len("cache://") :]
+        parts = uri_path.split("/")
+        if len(parts) < 2:
+            raise Exception(
+                "STORAGE_URI must be in the format "
+                f"cache://<schema_name>/<table_name>, got {self.config.storage_uri}"
+            )
+        self.schema_name = parts[0]
+        self.table_name = parts[1]
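+        # For illustration (hypothetical values): a storage_uri of
+        # "cache://sales/orders" parses to schema_name="sales" and
+        # table_name="orders".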
+
+    def download_dataset(self):
+        """Bootstrap cache cluster with dataset"""
+        logging.info(
+            f"Cache initializer called with storage URI: {self.config.storage_uri}"
+        )
+
+        train_job_name = self.config.train_job_name
+        cache_image = self.config.cache_image
+        cluster_size = int(self.config.cluster_size)
+        iam_role = self.config.iam_role
+        head_cpu = self.config.head_cpu
+        head_mem = self.config.head_mem
+        worker_cpu = self.config.worker_cpu
+        worker_mem = self.config.worker_mem
+        namespace = get_namespace()
+        metadata_loc = self.config.metadata_loc
+        table_name = self.table_name
+        schema_name = self.schema_name
+
+        # Load the in-cluster Kubernetes configuration
+        config.load_incluster_config()
+
+        api_client = client.ApiClient()
+        core_v1 = client.CoreV1Api(api_client)
+        custom_api = client.CustomObjectsApi(api_client)
+
+        # Get the TrainJob for the owner reference
+        try:
+            training_job = custom_api.get_namespaced_custom_object(
+                group="trainer.kubeflow.org",
+                version="v1alpha1",
+                plural="trainjobs",
+                namespace=namespace,
+                name=train_job_name,
+            )
+            logging.info(f"TrainJob: {training_job}")
+
+            # Create the owner reference dictionary so that every resource
+            # created below is garbage-collected when the TrainJob is deleted
+            logging.info(
+                f"Creating owner reference from TrainJob: {training_job['metadata']['name']}"
+            )
+
+            owner_ref_dict = {
+                "apiVersion": training_job["apiVersion"],
+                "kind": training_job["kind"],
+                "name": training_job["metadata"]["name"],
+                "uid": training_job["metadata"]["uid"],
+                "controller": True,
+                "blockOwnerDeletion": True,
+            }
+
+            logging.info(
+                f"Owner reference created with apiVersion='{training_job['apiVersion']}', "
+                f"kind='{training_job['kind']}'"
+            )
+        except ApiException as e:
+            logging.error(f"Failed to get TrainJob {train_job_name}: {e}")
+            return
+
+        try:
+            # Create ServiceAccount
+            service_account = client.V1ServiceAccount(
+                metadata=client.V1ObjectMeta(
+                    name=f"{train_job_name}-cache",
+                    namespace=namespace,
+                    annotations={
+                        "eks.amazonaws.com/sts-regional-endpoints": "true",
+                        "eks.amazonaws.com/role-arn": iam_role,
+                    },
+                    owner_references=[owner_ref_dict],
+                )
+            )
+
+            try:
+                core_v1.create_namespaced_service_account(
+                    namespace=namespace, body=service_account
+                )
+                logging.info(f"Created ServiceAccount {service_account.metadata.name}")
+            except ApiException as e:
+                if e.status == 409:
+                    logging.info(
+                        f"ServiceAccount {service_account.metadata.name} "
+                        f"already exists, skipping creation"
+                    )
+                else:
+                    raise e
+
+            # Prepare environment variables for the cache containers
+            env_vars = []
+            if metadata_loc:
+                env_vars.append({"name": "METADATA_LOC", "value": metadata_loc})
+            if table_name:
+                env_vars.append({"name": "TABLE_NAME", "value": table_name})
+            if schema_name:
+                env_vars.append({"name": "SCHEMA_NAME", "value": schema_name})
+
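+            # Note: a LeaderWorkerSet group's `size` includes the leader, so a
+            # cluster_size of 3 yields one head pod plus two workers.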
+            # Create LeaderWorkerSet
+            lws_body = {
+                "apiVersion": "leaderworkerset.x-k8s.io/v1",
+                "kind": "LeaderWorkerSet",
+                "metadata": {
+                    "name": f"{train_job_name}-cache",
+                    "namespace": namespace,
+                    "ownerReferences": [owner_ref_dict],
+                },
+                "spec": {
+                    "replicas": 1,
+                    "leaderWorkerTemplate": {
+                        "size": cluster_size,
+                        "leaderTemplate": {
+                            "metadata": {
+                                "labels": {"app": f"{train_job_name}-cache-head"}
+                            },
+                            "spec": {
+                                "serviceAccountName": service_account.metadata.name,
+                                "containers": [
+                                    {
+                                        "name": "head",
+                                        "image": cache_image,
+                                        "command": ["head"],
+                                        "args": ["0.0.0.0", "50051"],
+                                        "resources": {
+                                            "limits": {
+                                                "cpu": head_cpu,
+                                                "memory": head_mem,
+                                            },
+                                            "requests": {
+                                                "cpu": head_cpu,
+                                                "memory": head_mem,
+                                            },
+                                        },
+                                        "env": env_vars,
+                                        "ports": [{"containerPort": 50051}],
+                                    }
+                                ],
+                            },
+                        },
+                        "workerTemplate": {
+                            "spec": {
+                                "serviceAccountName": service_account.metadata.name,
+                                "containers": [
+                                    {
+                                        "name": "worker",
+                                        "image": cache_image,
+                                        "command": ["worker"],
+                                        "args": ["0.0.0.0", "50051"],
+                                        "resources": {
+                                            "limits": {
+                                                "cpu": worker_cpu,
+                                                "memory": worker_mem,
+                                            },
+                                            "requests": {
+                                                "cpu": worker_cpu,
+                                                "memory": worker_mem,
+                                            },
+                                        },
+                                        "env": env_vars,
+                                        "ports": [{"containerPort": 50051}],
+                                    }
+                                ],
+                            }
+                        },
+                    },
+                },
+            }
+
+            # Submit the LeaderWorkerSet
+            custom_api.create_namespaced_custom_object(
+                group="leaderworkerset.x-k8s.io",
+                version="v1",
+                namespace=namespace,
+                plural="leaderworkersets",
+                body=lws_body,
+            )
+            logging.info(f"Created LeaderWorkerSet {lws_body['metadata']['name']}")
+
+            # Create the Service that gives the head pod a stable endpoint
+            # ({train_job_name}-cache-service, port 50051)
+            service = client.V1Service(
+                metadata=client.V1ObjectMeta(
+                    name=f"{train_job_name}-cache-service",
+                    namespace=namespace,
+                    owner_references=[owner_ref_dict],
+                ),
+                spec=client.V1ServiceSpec(
+                    selector={"app": f"{train_job_name}-cache-head"},
+                    ports=[
+                        client.V1ServicePort(
+                            protocol="TCP", port=50051, target_port=50051
+                        )
+                    ],
+                ),
+            )
+
+            try:
+                core_v1.create_namespaced_service(namespace=namespace, body=service)
+                logging.info(f"Created Service {service.metadata.name}")
+            except ApiException as e:
+                if e.status == 409:
+                    logging.info(
+                        f"Service {service.metadata.name} already exists, "
+                        f"skipping creation"
+                    )
+                else:
+                    raise e
+
+            # Wait for the LeaderWorkerSet to become ready. Note that this loop
+            # has no timeout, so a cluster that never becomes Available blocks
+            # the initializer indefinitely.
+            # TODO: refactor to use the watch API
+            while True:
+                lws = custom_api.get_namespaced_custom_object(
+                    group="leaderworkerset.x-k8s.io",
+                    version="v1",
+                    plural="leaderworkersets",
+                    name=lws_body["metadata"]["name"],
+                    namespace=namespace,
+                )
+
+                conditions = lws.get("status", {}).get("conditions", [])
+                if any(
+                    c["type"] == "Available" and c["status"] == "True"
+                    for c in conditions
+                ):
+                    logging.info(
+                        f"LeaderWorkerSet {lws_body['metadata']['name']} is ready"
+                    )
+                    break
+
+                time.sleep(5)
+
+        except ApiException as e:
+            logging.error(f"Cache cluster creation failed: {e}")
+            # Cleanup on failure; the remaining resources are garbage-collected
+            # through their TrainJob owner references
+            try:
+                core_v1.delete_namespaced_service_account(
+                    name=f"{train_job_name}-cache", namespace=namespace
+                )
+            except Exception as cleanup_error:
+                logging.error(f"Error cleaning up ServiceAccount: {cleanup_error}")
+            return
+
+        logging.info("Cache cluster creation completed")
diff --git a/pkg/initializers/dataset/cache_test.py b/pkg/initializers/dataset/cache_test.py
new file mode 100644
index 0000000000..983e035da5
--- /dev/null
+++ b/pkg/initializers/dataset/cache_test.py
@@ -0,0 +1,221 @@
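+# These tests can be run directly with pytest, e.g.:
+#   pytest pkg/initializers/dataset/cache_test.py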
"iam_role": "arn:aws:iam::123456789012:role/partial-role", + "head_cpu": "2", + "worker_cpu": "4", + "metadata_loc": "s3://partial-bucket/metadata", + }, + { + "storage_uri": "cache://partial_schema/partial_table", + "train_job_name": "partial-job", + "cache_image": "partial-image:latest", + "cluster_size": "3", + "metadata_loc": "s3://partial-bucket/metadata", + "iam_role": "arn:aws:iam::123456789012:role/partial-role", + "head_cpu": "2", + "head_mem": "1Gi", + "worker_cpu": "4", + "worker_mem": "2Gi", + }, + ), + ], +) +def test_load_config(test_name, test_config, expected): + """Test config loading with different configurations""" + print(f"Running test: {test_name}") + + cache_initializer_instance = CacheInitializer() + + with patch.object(utils, "get_config_from_env", return_value=test_config): + cache_initializer_instance.load_config() + assert cache_initializer_instance.config.__dict__ == expected + + print("Test execution completed") + + +@pytest.mark.parametrize( + "test_name, test_case", + [ + ( + "Full configuration with all substitutions", + { + "config": { + "storage_uri": "cache://test_schema/test_table", + "train_job_name": "full-job", + "cache_image": "custom-cache:v1.0", + "cluster_size": "5", + "metadata_loc": "s3://test-bucket/metadata", + "iam_role": "arn:aws:iam::123456789012:role/test-role", + "head_cpu": "4", + "head_mem": "8Gi", + "worker_cpu": "8", + "worker_mem": "16Gi", + }, + "expected_train_job_name": "full-job", + }, + ), + ( + "Default values with minimal configuration", + { + "config": { + "storage_uri": "cache://minimal_test_schema/minimal_test_table", + "train_job_name": "minimal-job", + "cache_image": "test-image:latest", + "iam_role": "arn:aws:iam::123456789012:role/test-role", + "metadata_loc": "s3://minimal-test-bucket/metadata", + }, + "expected_train_job_name": "minimal-job", + }, + ), + ( + "Mixed configuration with some defaults", + { + "config": { + "storage_uri": "cache://mixed_schema/mixed_table", + "train_job_name": "mixed-job", + "cache_image": "mixed-image:v2.0", + "iam_role": "arn:aws:iam::987654321098:role/mixed-role", + "head_cpu": "6", + "worker_mem": "32Gi", + "metadata_loc": "s3://mixed-bucket/data", + }, + "expected_train_job_name": "mixed-job", + }, + ), + ( + "Minimal config uses defaults for optional fields", + { + "config": { + "storage_uri": "cache://required_schema/required_table", + "train_job_name": "required-job", + "cache_image": "test-image:required", + "iam_role": "arn:aws:iam::123456789012:role/required", + "metadata_loc": "s3://required-bucket/metadata", + }, + "expected_train_job_name": "required-job", + }, + ), + ], +) +def test_download_dataset(test_name, test_case): + """Test cache cluster creation with different configurations""" + + print(f"Running test: {test_name}") + + cache_initializer_instance = CacheInitializer() + + # Use proper load_config instead of mocking config directly + with patch.object(utils, "get_config_from_env", return_value=test_case["config"]): + cache_initializer_instance.load_config() + + with patch( + "pkg.initializers.dataset.cache.get_namespace", return_value="test-namespace" + ), patch("pkg.initializers.dataset.cache.config") as mock_config, patch( + "pkg.initializers.dataset.cache.client" + ) as mock_client: + + # Setup mocks for Kubernetes client + mock_api_client = MagicMock() + mock_core_v1 = MagicMock() + mock_custom_api = MagicMock() + + mock_client.ApiClient.return_value = mock_api_client + mock_client.CoreV1Api.return_value = mock_core_v1 + mock_client.CustomObjectsApi.return_value 
+@dataclass
+class CacheDatasetInitializer:
+    storage_uri: str
+    train_job_name: str
+    cache_image: str
+    iam_role: str
+    metadata_loc: str
+    cluster_size: str = "3"
+    head_cpu: str = "1"
+    head_mem: str = "1Gi"
+    worker_cpu: str = "2"
+    worker_mem: str = "2Gi"
diff --git a/pkg/initializers/utils/utils.py b/pkg/initializers/utils/utils.py
index c8afa873e7..c177afaab1 100644
--- a/pkg/initializers/utils/utils.py
+++ b/pkg/initializers/utils/utils.py
@@ -5,6 +5,7 @@
 STORAGE_URI_ENV = "STORAGE_URI"
 HF_SCHEME = "hf"
+CACHE_SCHEME = "cache"
 
 # The default path to the users' workspace.
 # TODO (andreyvelich): Discuss how to keep this path is sync with Kubeflow SDK constants.