diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index a45f192c28b..144e2b02040 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -364,6 +364,8 @@ KUBERNETES_ANNOTATIONS = from_conf("KUBERNETES_ANNOTATIONS", "") # Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd) KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia") +# Default GPU type to use by K8S jobs created by Metaflow +KUBERNETES_GPU_TYPE = from_conf("KUBERNETES_GPU_TYPE", "gpu") # Default container image for K8S KUBERNETES_CONTAINER_IMAGE = from_conf( "KUBERNETES_CONTAINER_IMAGE", DEFAULT_CONTAINER_IMAGE diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index 304fa9f3bd9..970453dda0e 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -440,8 +440,8 @@ def _to_job(self, node): limits={ **qos_limits, **{ - "%s.com/gpu".lower() - % k8s_deco.attributes["gpu_vendor"]: str(k8s_deco.attributes["gpu"]) + ("%s.com/%s".lower() + % (k8s_deco.attributes["gpu_vendor"], k8s_deco.attributes["gpu_type"])): str(k8s_deco.attributes["gpu"]) for k in [0] # Don't set GPU limits if gpu isn't specified. if k8s_deco.attributes["gpu"] is not None diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index af4a510cb5b..a37b146bf15 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2033,6 +2033,7 @@ def _container_templates(self): disk=str(resources["disk"]), gpu=resources["gpu"], gpu_vendor=str(resources["gpu_vendor"]), + gpu_type=str(resources["gpu_type"]), tolerations=resources["tolerations"], use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, @@ -2248,10 +2249,13 @@ def _container_templates(self): limits={ **qos_limits, **{ - "%s.com/gpu".lower() - % resources["gpu_vendor"]: str( - resources["gpu"] - ) + ( + "%s.com/%s".lower() + % ( + resources["gpu_vendor"], + resources["gpu_type"], + ) + ): str(resources["gpu"]) for k in [0] if resources["gpu"] is not None }, diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 069c63ef211..c655a41bb2a 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -17,6 +17,7 @@ KUBERNETES_DISK, KUBERNETES_FETCH_EC2_METADATA, KUBERNETES_GPU_VENDOR, + KUBERNETES_GPU_TYPE, KUBERNETES_IMAGE_PULL_POLICY, KUBERNETES_MEMORY, KUBERNETES_LABELS, @@ -90,6 +91,8 @@ class KubernetesDecorator(StepDecorator): the scheduled node should not have GPUs. gpu_vendor : str, default KUBERNETES_GPU_VENDOR The vendor of the GPUs to be used for this step. + gpu_type : str , optional, default KUBERNETES_GPU_TYPE + The type of the GPUs to be used for this step. tolerations : List[str], default [] The default is extracted from METAFLOW_KUBERNETES_TOLERATIONS. Kubernetes tolerations to use when launching pod in Kubernetes. @@ -145,6 +148,7 @@ class KubernetesDecorator(StepDecorator): "namespace": None, "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, + "gpu_type": None, "tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"}, # {"key": "foo", "operator": "Equal", "value": "bar"}] "labels": None, # e.g. {"test-label": "value", "another-label":"value2"} @@ -181,6 +185,8 @@ def init(self): self.attributes["gpu_vendor"] = KUBERNETES_GPU_VENDOR if not self.attributes["node_selector"] and KUBERNETES_NODE_SELECTOR: self.attributes["node_selector"] = KUBERNETES_NODE_SELECTOR + if not self.attributes["gpu_type"]: + self.attributes["gpu_type"] = KUBERNETES_GPU_TYPE if not self.attributes["tolerations"] and KUBERNETES_TOLERATIONS: self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS) if ( diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index dc115c5758e..c56cc406972 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -171,10 +171,8 @@ def create_job_spec(self): limits={ **qos_limits, **{ - "%s.com/gpu".lower() - % self._kwargs["gpu_vendor"]: str( - self._kwargs["gpu"] - ) + ("%s.com/%s".lower() + % (self._kwargs["gpu_vendor"], self._kwargs["gpu_type"])): str(self._kwargs["gpu"]) for k in [0] # Don't set GPU limits if gpu isn't specified. if self._kwargs["gpu"] is not None diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index 3f9eda389f2..0e33bdf7404 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -670,10 +670,9 @@ def dump(self): limits={ **qos_limits, **{ - "%s.com/gpu".lower() - % self._kwargs["gpu_vendor"]: str( - self._kwargs["gpu"] - ) + ("%s.com/%s".lower() + % (self._kwargs["gpu_vendor"], self._kwargs["gpu_type"])): str( + self._kwargs["gpu"]) for k in [0] # Don't set GPU limits if gpu isn't specified. if self._kwargs["gpu"] is not None