Merged
2 changes: 1 addition & 1 deletion .flake8
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 100
-extend-ignore = W503
+extend-ignore = W503, E203
1 change: 1 addition & 0 deletions cmd/initializers/dataset/requirements.txt
@@ -1 +1,2 @@
 huggingface-hub>=0.27.0,<0.28
+kubernetes>=27.2.0
5 changes: 5 additions & 0 deletions pkg/initializers/dataset/__main__.py
@@ -3,6 +3,7 @@
 from urllib.parse import urlparse

 import pkg.initializers.utils.utils as utils
+from pkg.initializers.dataset.cache import CacheInitializer
 from pkg.initializers.dataset.huggingface import HuggingFace

 logging.basicConfig(
@@ -27,6 +28,10 @@ def main():
             hf = HuggingFace()
             hf.load_config()
             hf.download_dataset()
+        case utils.CACHE_SCHEME:
+            cache = CacheInitializer()
+            cache.load_config()
+            cache.download_dataset()
         case _:
             logging.error("STORAGE_URI must have the valid dataset provider")
             raise Exception
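Note: a minimal standalone sketch of the scheme dispatch this hunk relies on. The value of utils.CACHE_SCHEME is an assumption here, not taken from this PR.

# Hypothetical illustration; utils.CACHE_SCHEME is assumed to be the literal "cache".
from urllib.parse import urlparse

CACHE_SCHEME = "cache"

storage_uri = "cache://my_schema/my_table"
# main() dispatches on the URI scheme, so this URI would select the CacheInitializer branch.
assert urlparse(storage_uri).scheme == CACHE_SCHEME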
285 changes: 285 additions & 0 deletions pkg/initializers/dataset/cache.py
@@ -0,0 +1,285 @@
import logging
import time

from kubernetes import client, config
from kubernetes.client.rest import ApiException
from kubernetes.dynamic.exceptions import ConflictError

import pkg.initializers.types.types as types
import pkg.initializers.utils.utils as utils

logging.basicConfig(
format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
level=logging.INFO,
)


def get_namespace() -> str:
"""Get the current namespace from the service account token."""
try:
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") as f:
return f.readline().strip()
except FileNotFoundError:
logging.warning(
"Service account namespace file not found, using 'default' namespace"
)
return "default"


class CacheInitializer(utils.DatasetProvider):

def load_config(self):
config_dict = utils.get_config_from_env(types.CacheDatasetInitializer)
self.config = types.CacheDatasetInitializer(**config_dict)

# Parse schema_name and table_name from storage_uri
# Format: cache://<SCHEMA_NAME>/<TABLE_NAME>
uri_path = self.config.storage_uri[len("cache://") :]
parts = uri_path.split("/")
self.schema_name = parts[0]
self.table_name = parts[1]
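
For clarity, a tiny standalone illustration of the URI split performed in load_config above; the schema and table values are made up.

# cache://<SCHEMA_NAME>/<TABLE_NAME> -> ("my_schema", "my_table")
storage_uri = "cache://my_schema/my_table"
uri_path = storage_uri[len("cache://"):]
parts = uri_path.split("/")
assert (parts[0], parts[1]) == ("my_schema", "my_table")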

def download_dataset(self):
"""Bootstrap cache cluster with dataset"""
logging.info(
f"Cache initializer called with storage URI: {self.config.storage_uri}"
)

train_job_name = self.config.train_job_name
cache_image = self.config.cache_image
cluster_size = int(self.config.cluster_size)
iam_role = self.config.iam_role
head_cpu = self.config.head_cpu
head_mem = self.config.head_mem
worker_cpu = self.config.worker_cpu
worker_mem = self.config.worker_mem
namespace = get_namespace()
metadata_loc = self.config.metadata_loc
table_name = self.table_name
schema_name = self.schema_name

# Load Kubernetes configuration
config.load_incluster_config()
Contributor:
I'm not too deep into the design of this, so apologies for the out-of-context comment, but my first reaction is: should all of this be part of the control plane rather than the runtime?

Member:
You are right, ideally we should move this to the operator; we just didn't get a chance to work on it.
@akshaychitneni Maybe as a workaround before building a cache controller, we can use trainer-controller-manager to create the LWS with the appropriate spec (e.g. the cache plugin can be activated when storageURI is set as follows: cache://database/table).

Contributor Author:
Yes, that was the initial plan: to add this as a plugin to the trainer. Since we intend to have the cache leverage its own operator, we haven't pursued that path. I think we can revisit this approach.

api_client = client.ApiClient()
core_v1 = client.CoreV1Api(api_client)
custom_api = client.CustomObjectsApi(api_client)

# Get TrainJob for owner reference
try:
training_job = custom_api.get_namespaced_custom_object(
Contributor:
What gives the TrainJob initializer the permissions to perform those requests to the API server?

Contributor Author:
The runtime should be configured so that the initializer uses a ServiceAccount with the relevant permissions. We plan to document this. (A hedged RBAC sketch follows after the TrainJob lookup below.)

group="trainer.kubeflow.org",
version="v1alpha1",
plural="trainjobs",
namespace=namespace,
name=train_job_name,
)
logging.info(f"TrainJob: {training_job}")

# Create owner reference dictionary
logging.info(
f"Creating owner reference from TrainJob: {training_job['metadata']['name']}"
)

owner_ref_dict = {
"apiVersion": training_job["apiVersion"],
"kind": training_job["kind"],
"name": training_job["metadata"]["name"],
"uid": training_job["metadata"]["uid"],
"controller": True,
"blockOwnerDeletion": True,
}

logging.info(
f"Owner reference created with apiVersion='{training_job['apiVersion']}', "
f"kind='{training_job['kind']}')"
)
except ApiException as e:
logging.error(f"Failed to get TrainJob {train_job_name}: {e}")
return
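
As raised in the review thread above, these API calls only work if the initializer Pod runs under a ServiceAccount with matching RBAC. A minimal sketch, assuming illustrative names; the actual rules are not part of this PR and would be supplied by the runtime configuration.

# Hypothetical RBAC covering the calls made in this file; names are placeholders.
from kubernetes import client

initializer_role = client.V1Role(
    metadata=client.V1ObjectMeta(name="dataset-cache-initializer"),
    rules=[
        # Read the owning TrainJob to build owner references.
        client.V1PolicyRule(
            api_groups=["trainer.kubeflow.org"], resources=["trainjobs"], verbs=["get"]
        ),
        # Create and poll the cache LeaderWorkerSet.
        client.V1PolicyRule(
            api_groups=["leaderworkerset.x-k8s.io"],
            resources=["leaderworkersets"],
            verbs=["create", "get"],
        ),
        # Create the cache ServiceAccount and Service, and delete the SA on cleanup.
        client.V1PolicyRule(
            api_groups=[""],
            resources=["serviceaccounts", "services"],
            verbs=["create", "delete"],
        ),
    ],
)
# Applied with, for example:
# client.RbacAuthorizationV1Api().create_namespaced_role("my-namespace", initializer_role)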

try:
# Create ServiceAccount
service_account = client.V1ServiceAccount(
metadata=client.V1ObjectMeta(
name=f"{train_job_name}-cache",
namespace=namespace,
annotations={
"eks.amazonaws.com/sts-regional-endpoints": "true",
"eks.amazonaws.com/role-arn": iam_role,
},
Comment on lines +108 to +111

Contributor:
Should that be made configurable?

Contributor Author:
Makes sense.

Contributor Author:
Our initial implementation only supports S3 via IAM. I think it is reasonable to make this configurable once we support additional providers.

owner_references=[owner_ref_dict],
)
)

try:
core_v1.create_namespaced_service_account(
namespace=namespace, body=service_account
)
logging.info(f"Created ServiceAccount {service_account.metadata.name}")
except ApiException as e:
if e.status == 409:
logging.info(
f"ServiceAccount {service_account.metadata.name} "
f"already exists, skipping creation"
)
else:
raise e

# Prepare environment variables
env_vars = []
if metadata_loc:
env_vars.append({"name": "METADATA_LOC", "value": metadata_loc})
if table_name:
env_vars.append({"name": "TABLE_NAME", "value": table_name})
if schema_name:
env_vars.append({"name": "SCHEMA_NAME", "value": schema_name})

# Create LeaderWorkerSet
lws_body = {
"apiVersion": "leaderworkerset.x-k8s.io/v1",
"kind": "LeaderWorkerSet",
"metadata": {
"name": f"{train_job_name}-cache",
"namespace": namespace,
"ownerReferences": [owner_ref_dict],
},
"spec": {
"replicas": 1,
"leaderWorkerTemplate": {
"size": cluster_size,
"leaderTemplate": {
"metadata": {
"labels": {"app": f"{train_job_name}-cache-head"}
},
"spec": {
"serviceAccountName": service_account.metadata.name,
"containers": [
{
"name": "head",
"image": cache_image,
"command": ["head"],
"args": ["0.0.0.0", "50051"],
"resources": {
"limits": {
"cpu": head_cpu,
"memory": head_mem,
},
"requests": {
"cpu": head_cpu,
"memory": head_mem,
},
},
"env": env_vars,
"ports": [{"containerPort": 50051}],
}
],
},
},
"workerTemplate": {
"spec": {
"serviceAccountName": f"{train_job_name}-cache",
"containers": [
{
"name": "worker",
"image": cache_image,
"command": ["worker"],
"args": ["0.0.0.0", "50051"],
"resources": {
"limits": {
"cpu": worker_cpu,
"memory": worker_mem,
},
"requests": {
"cpu": worker_cpu,
"memory": worker_mem,
},
},
"env": env_vars,
"ports": [{"containerPort": 50051}],
}
],
}
},
},
},
}

# Create LeaderWorkerSet
custom_api.create_namespaced_custom_object(
group="leaderworkerset.x-k8s.io",
version="v1",
namespace=namespace,
plural="leaderworkersets",
body=lws_body,
)
logging.info(f"Created LeaderWorkerSet {lws_body['metadata']['name']}")

# Create Service
service = client.V1Service(
metadata=client.V1ObjectMeta(
name=f"{train_job_name}-cache-service",
namespace=namespace,
owner_references=[owner_ref_dict],
),
spec=client.V1ServiceSpec(
selector={"app": f"{train_job_name}-cache-head"},
ports=[
client.V1ServicePort(
protocol="TCP", port=50051, target_port=50051
)
],
),
)

try:
core_v1.create_namespaced_service(namespace=namespace, body=service)
logging.info(f"Created Service {service.metadata.name}")
except ApiException as e:
                # CoreV1Api raises ApiException on conflict (HTTP 409), not ConflictError
                if e.status == 409:
logging.info(
f"Service {service.metadata.name} already exists, "
f"skipping creation"
)
else:
raise e

# Wait for LeaderWorkerSet to become ready
            # TODO: refactor to use the watch API (a hedged sketch follows after this file's diff)
while True:
try:
lws = custom_api.get_namespaced_custom_object(
group="leaderworkerset.x-k8s.io",
version="v1",
plural="leaderworkersets",
name=lws_body["metadata"]["name"],
namespace=namespace,
)

conditions = lws.get("status", {}).get("conditions", [])
if any(
c["type"] == "Available" and c["status"] == "True"
for c in conditions
):
logging.info(
f"LeaderWorkerSet {lws_body['metadata']['name']} is ready"
)
break

time.sleep(5)
except ApiException as e:
raise e

except ApiException as e:
logging.error(f"Cache cluster creation failed: {e}")
# Cleanup on failure
try:
core_v1.delete_namespaced_service_account(
name=f"{train_job_name}-cache", namespace=namespace
)
except Exception as cleanup_error:
logging.error(f"Error cleaning up ServiceAccount: {cleanup_error}")
return

logging.info("Cache cluster creation completed")