diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 17397ed2fa..7ee74a0830 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -40,6 +40,8 @@
       title: Integrate a library
     - local: guides/webhooks
       title: Webhooks
+    - local: guides/jobs
+      title: Jobs
 - title: 'Conceptual guides'
   sections:
     - local: concepts/git_vs_http
@@ -92,3 +94,5 @@
       title: Strict dataclasses
     - local: package_reference/oauth
       title: OAuth
+    - local: package_reference/jobs
+      title: Jobs
diff --git a/docs/source/en/guides/cli.md b/docs/source/en/guides/cli.md
index 481cbf4c79..a7d2844dc4 100644
--- a/docs/source/en/guides/cli.md
+++ b/docs/source/en/guides/cli.md
@@ -604,3 +604,147 @@ Copy-and-paste the text below in your GitHub issue.
 - HF_HUB_ETAG_TIMEOUT: 10
 - HF_HUB_DOWNLOAD_TIMEOUT: 10
 ```
+
+## huggingface-cli jobs
+
+Run compute jobs on Hugging Face infrastructure with a familiar Docker-like interface.
+
+`huggingface-cli jobs` is a command-line tool that lets you run anything on Hugging Face's infrastructure (including GPUs and TPUs!) with simple commands. Think `docker run`, but for running code on A100s.
+
+```bash
+# Directly run Python code
+>>> huggingface-cli jobs run python:3.12 python -c "print('Hello from the cloud!')"
+
+# Use GPUs without any setup
+>>> huggingface-cli jobs run --flavor a10g-small pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel \
+... python -c "import torch; print(torch.cuda.get_device_name())"
+
+# Run in an organization account
+>>> huggingface-cli jobs run --namespace my-org-name python:3.12 python -c "print('Running in an org account')"
+
+# Run from Hugging Face Spaces
+>>> huggingface-cli jobs run hf.co/spaces/lhoestq/duckdb duckdb -c "select 'hello world'"
+
+# Run a Python script with `uv` (experimental)
+>>> huggingface-cli jobs uv run my_script.py
+```
+
+### ✨ Key Features
+
+- 🐳 **Docker-like CLI**: Familiar commands (`run`, `ps`, `logs`, `inspect`) to run and manage jobs
+- 🔥 **Any Hardware**: From CPUs to A100 GPUs and TPU pods - switch with a simple flag
+- 📦 **Run Anything**: Use Docker images, HF Spaces, or your custom containers
+- 🔐 **Simple Auth**: Just use your HF token
+- 📊 **Live Monitoring**: Stream logs in real time, just like running locally
+- 💰 **Pay-as-you-go**: Only pay for the seconds you use
+
+### Quick Start
+
+#### 1. Run your first job
+
+```bash
+# Run a simple Python script
+>>> huggingface-cli jobs run python:3.12 python -c "print('Hello from HF compute!')"
+```
+
+This command runs the job and streams the logs. You can pass `--detach` to run the Job in the background and only print the Job ID.
+
+#### 2. Check job status
+
+```bash
+# List your running jobs
+>>> huggingface-cli jobs ps
+
+# Inspect the status of a job
+>>> huggingface-cli jobs inspect <job-id>
+
+# View logs from a job
+>>> huggingface-cli jobs logs <job-id>
+
+# Cancel a job
+>>> huggingface-cli jobs cancel <job-id>
+```
+
+#### 3. Run on GPU
+
+You can also run jobs on GPUs or TPUs with the `--flavor` option. For example, to run a PyTorch job on an A10G GPU:
+
+```bash
+# Use an A10G GPU to check PyTorch CUDA
+>>> huggingface-cli jobs run --flavor a10g-small pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel \
+... python -c "import torch; print(f'This code ran with the following GPU: {torch.cuda.get_device_name()}')"
+```
+
+Running this will show the following output!
+
+```bash
+This code ran with the following GPU: NVIDIA A10G
+```
+
+That's it! You're now running code on Hugging Face's infrastructure.
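+
+You can also launch a job without streaming its logs. A minimal sketch of that workflow, based on the commands above (`<job-id>` and `<username>` are placeholders for the values printed by `run`):
+
+```bash
+# Start the job in the background: only the Job ID and URL are printed
+>>> huggingface-cli jobs run --detach python:3.12 python -c "print('Hello in the background!')"
+Job started with ID: <job-id>
+View at: https://huggingface.co/jobs/<username>/<job-id>
+
+# Later, stream its logs or check its status with that ID
+>>> huggingface-cli jobs logs <job-id>
+>>> huggingface-cli jobs inspect <job-id>
+```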
+
+### Common Use Cases
+
+- **Model Training**: Fine-tune or train models on GPUs (T4, A10G, A100) without managing infrastructure
+- **Synthetic Data Generation**: Generate large-scale datasets using LLMs on powerful hardware
+- **Data Processing**: Process massive datasets with high-CPU configurations for parallel workloads
+- **Batch Inference**: Run offline inference on thousands of samples using optimized GPU setups
+- **Experiments & Benchmarks**: Run ML experiments on consistent hardware for reproducible results
+- **Development & Debugging**: Test GPU code without local CUDA setup
+
+### Pass environment variables and secrets
+
+You can pass environment variables and secrets to your job with the `-e`/`--env`, `--env-file`, `-s`/`--secrets`, and `--secrets-file` options:
+
+```bash
+# Pass environment variables
+>>> huggingface-cli jobs run -e FOO=foo -e BAR=bar python:3.12 python -c "import os; print(os.environ['FOO'], os.environ['BAR'])"
+```
+
+```bash
+# Pass environment variables from a local .env file
+>>> huggingface-cli jobs run --env-file .env python:3.12 python -c "import os; print(os.environ['FOO'], os.environ['BAR'])"
+```
+
+```bash
+# Pass secrets - they will be encrypted server side
+>>> huggingface-cli jobs run -s MY_SECRET=psswrd python:3.12 python -c "import os; print(os.environ['MY_SECRET'])"
+```
+
+```bash
+# Pass secrets from a local .env.secrets file - they will be encrypted server side
+>>> huggingface-cli jobs run --secrets-file .env.secrets python:3.12 python -c "import os; print(os.environ['MY_SECRET'])"
+```
+
+### Hardware
+
+Available `--flavor` options:
+
+- CPU: `cpu-basic`, `cpu-upgrade`
+- GPU: `t4-small`, `t4-medium`, `l4x1`, `l4x4`, `a10g-small`, `a10g-large`, `a10g-largex2`, `a10g-largex4`, `a100-large`
+- TPU: `v5e-1x1`, `v5e-2x2`, `v5e-2x4`
+
+(updated in 07/2025 from Hugging Face [suggested_hardware docs](https://huggingface.co/docs/hub/en/spaces-config-reference))
+
+### UV Scripts (Experimental)
+
+Run UV scripts (Python scripts with inline dependencies) on HF infrastructure:
+
+```bash
+# Run a UV script (creates temporary repo)
+>>> huggingface-cli jobs uv run my_script.py
+
+# Run with persistent repo
+>>> huggingface-cli jobs uv run my_script.py --repo my-uv-scripts
+
+# Run with GPU
+>>> huggingface-cli jobs uv run ml_training.py --flavor t4-small
+
+# Pass arguments to script
+>>> huggingface-cli jobs uv run process.py input.csv output.parquet --repo data-scripts
+
+# Run a script directly from a URL
+>>> huggingface-cli jobs uv run https://huggingface.co/datasets/username/scripts/resolve/main/example.py
+```
+
+UV scripts are Python scripts that include their dependencies directly in the file using a special comment syntax. This makes them perfect for self-contained tasks that don't require complex project setups. Learn more about UV scripts in the [UV documentation](https://docs.astral.sh/uv/guides/scripts/).
diff --git a/docs/source/en/guides/jobs.md b/docs/source/en/guides/jobs.md
new file mode 100644
index 0000000000..a4ce0b4cb3
--- /dev/null
+++ b/docs/source/en/guides/jobs.md
@@ -0,0 +1,220 @@
+
+# Run and manage Jobs
+
+The Hugging Face Hub provides compute for AI and data workflows via Jobs.
+A Job runs on Hugging Face infrastructure and is defined by a command to run (e.g. a Python command), a Docker image from Hugging Face Spaces or Docker Hub, and a hardware flavor (CPU, GPU, TPU). This guide will show you how to interact with Jobs on the Hub, in particular how to:
+
+- Run a job.
+- Check job status.
+- Select the hardware.
+- Configure environment variables and secrets.
+- Run UV scripts.
+
+If you want to run and manage a job on the Hub, your machine must be logged in. If you are not, please refer to
+[this section](../quick-start#authentication). In the rest of this guide, we will assume that your machine is logged in.
+
+## Run a Job
+
+Run compute Jobs defined with a command and a Docker image on Hugging Face infrastructure (including GPUs and TPUs).
+
+You can only manage Jobs that you own (under your username namespace) or Jobs from organizations in which you have write permissions.
+This feature is pay-as-you-go: you only pay for the seconds you use.
+
+[`run_job`] lets you run any command on Hugging Face's infrastructure:
+
+```python
+# Directly run Python code
+>>> from huggingface_hub import run_job
+>>> run_job(
+...     image="python:3.12",
+...     command=["python", "-c", "print('Hello from the cloud!')"],
+... )
+
+# Use GPUs without any setup
+>>> run_job(
+...     image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
+...     command=["python", "-c", "import torch; print(torch.cuda.get_device_name())"],
+...     flavor="a10g-small",
+... )
+
+# Run in an organization account
+>>> run_job(
+...     image="python:3.12",
+...     command=["python", "-c", "print('Running in an org account')"],
+...     namespace="my-org-name",
+... )
+
+# Run from Hugging Face Spaces
+>>> run_job(
+...     image="hf.co/spaces/lhoestq/duckdb",
+...     command=["duckdb", "-c", "select 'hello world'"],
+... )
+
+# Run a Python script with `uv` (experimental)
+>>> from huggingface_hub import run_uv_job
+>>> run_uv_job("my_script.py")
+```
+
+<Tip>
+
+Use [huggingface-cli jobs](./cli#huggingface-cli-jobs) to run jobs in the command line.
+
+</Tip>
+
+[`run_job`] returns a [`JobInfo`] with the URL of the Job on Hugging Face, where you can see the Job status and the logs.
+Save the Job ID from [`JobInfo`] to manage the job:
+
+```python
+>>> from huggingface_hub import run_job
+>>> job = run_job(
+...     image="python:3.12",
+...     command=["python", "-c", "print('Hello from the cloud!')"]
+... )
+>>> job.url
+'https://huggingface.co/jobs/lhoestq/687f911eaea852de79c4a50a'
+>>> job.id
+'687f911eaea852de79c4a50a'
+```
+
+Jobs run in the background. The next section shows how to check a job's status with [`inspect_job`] and view its logs with [`fetch_job_logs`].
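+
+Before moving on, note that [`run_job`] also accepts a `timeout` argument to cap a job's runtime: an int/float in seconds, or a string with an s (seconds), m (minutes), h (hours) or d (days) suffix. A minimal sketch (the sleep command is just an illustration):
+
+```python
+>>> from huggingface_hub import run_job
+>>> run_job(
+...     image="python:3.12",
+...     command=["python", "-c", "import time; time.sleep(3600)"],
+...     timeout="5m",  # the Job is stopped after 5 minutes
+... )
+```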
+
+## Check Job status
+
+```python
+# List your jobs
+>>> from huggingface_hub import list_jobs
+>>> jobs = list_jobs()
+>>> jobs[0]
+JobInfo(id='687f911eaea852de79c4a50a', created_at=datetime.datetime(2025, 7, 22, 13, 24, 46, 909000, tzinfo=datetime.timezone.utc), docker_image='python:3.12', space_id=None, command=['python', '-c', "print('Hello from the cloud!')"], arguments=[], environment={}, secrets={}, flavor='cpu-basic', status=JobStatus(stage='COMPLETED', message=None), owner=JobOwner(id='5e9ecfc04957053f60648a3e', name='lhoestq'), endpoint='https://huggingface.co', url='https://huggingface.co/jobs/lhoestq/687f911eaea852de79c4a50a')
+
+# List your running jobs
+>>> running_jobs = [job for job in list_jobs() if job.status.stage == "RUNNING"]
+
+# Inspect the status of a job
+>>> from huggingface_hub import inspect_job
+>>> inspect_job(job_id=job_id)
+JobInfo(id='687f911eaea852de79c4a50a', created_at=datetime.datetime(2025, 7, 22, 13, 24, 46, 909000, tzinfo=datetime.timezone.utc), docker_image='python:3.12', space_id=None, command=['python', '-c', "print('Hello from the cloud!')"], arguments=[], environment={}, secrets={}, flavor='cpu-basic', status=JobStatus(stage='COMPLETED', message=None), owner=JobOwner(id='5e9ecfc04957053f60648a3e', name='lhoestq'), endpoint='https://huggingface.co', url='https://huggingface.co/jobs/lhoestq/687f911eaea852de79c4a50a')
+
+# View logs from a job
+>>> from huggingface_hub import fetch_job_logs
+>>> for log in fetch_job_logs(job_id=job_id):
+...     print(log)
+Hello from the cloud!
+
+# Cancel a job
+>>> from huggingface_hub import cancel_job
+>>> cancel_job(job_id=job_id)
+```
+
+Check the status of multiple jobs to know when they're all finished, using a loop and [`inspect_job`]:
+
+```python
+# Run multiple jobs in parallel and wait for them to complete
+>>> import time
+>>> from huggingface_hub import inspect_job, run_job
+>>> jobs = [run_job(image=image, command=command) for command in commands]
+>>> for job in jobs:
+...     while inspect_job(job_id=job.id).status.stage not in ("COMPLETED", "ERROR", "CANCELED", "DELETED"):
+...         time.sleep(10)
+```
+
+## Select the hardware
+
+There are numerous cases where running Jobs on GPUs is useful:
+
+- **Model Training**: Fine-tune or train models on GPUs (T4, A10G, A100) without managing infrastructure
+- **Synthetic Data Generation**: Generate large-scale datasets using LLMs on powerful hardware
+- **Data Processing**: Process massive datasets with high-CPU configurations for parallel workloads
+- **Batch Inference**: Run offline inference on thousands of samples using optimized GPU setups
+- **Experiments & Benchmarks**: Run ML experiments on consistent hardware for reproducible results
+- **Development & Debugging**: Test GPU code without local CUDA setup
+
+Run jobs on GPUs or TPUs with the `flavor` argument. For example, to run a PyTorch job on an A10G GPU:
+
+```python
+# Use an A10G GPU to check PyTorch CUDA
+>>> from huggingface_hub import run_job
+>>> run_job(
+...     image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
+...     command=["python", "-c", "import torch; print(f'This code ran with the following GPU: {torch.cuda.get_device_name()}')"],
+...     flavor="a10g-small",
+... )
+```
+
+Running this will show the following output!
+
+```bash
+This code ran with the following GPU: NVIDIA A10G
+```
+
+Use this to run a fine-tuning script like [trl/scripts/sft.py](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) with UV:
+
+```python
+>>> from huggingface_hub import run_uv_job
+>>> run_uv_job(
+...     "sft.py",
script_args=["--model_name_or_path", "Qwen/Qwen2-0.5B", ...], +... dependencies=["trl"], +... env={"HF_TOKEN": ...}, +... flavor="a10g-small", +... ) +``` + +Available `flavor` options: + +- CPU: `cpu-basic`, `cpu-upgrade` +- GPU: `t4-small`, `t4-medium`, `l4x1`, `l4x4`, `a10g-small`, `a10g-large`, `a10g-largex2`, `a10g-largex4`,`a100-large` +- TPU: `v5e-1x1`, `v5e-2x2`, `v5e-2x4` + +(updated in 07/2025 from Hugging Face [suggested_hardware docs](https://huggingface.co/docs/hub/en/spaces-config-reference)) + +That's it! You're now running code on Hugging Face's infrastructure. + +## Pass Environment variables and Secrets + +You can pass environment variables to your job using `env` and `secrets`: + +```python +# Pass environment variables +>>> from huggingface_hub import run_job +>>> run_job( +... image="python:3.12", +... command=["python", "-c", "import os; print(os.environ['FOO'], os.environ['BAR'])"], +... env={"FOO": "foo", "BAR": "bar"}, +... ) +``` + + +```python +# Pass secrets - they will be encrypted server side +>>> from huggingface_hub import run_job +>>> run_job( +... image="python:3.12", +... command=["python", "-c", "import os; print(os.environ['MY_SECRET'])"], +... secrets={"MY_SECRET": "psswrd"}, +... ) +``` + + +### UV Scripts (Experimental) + +Run UV scripts (Python scripts with inline dependencies) on HF infrastructure: + +```python +# Run a UV script (creates temporary repo) +>>> from huggingface_hub import run_uv_job +>>> run_uv_job("my_script.py") + +# Run with GPU +>>> run_uv_job("ml_training.py", flavor="gpu-t4-small") + +# Run with dependencies +>>> run_uv_job("inference.py", dependencies=["transformers", "torch"]) + +# Run a script directly from a URL +>>> run_uv_job("https://huggingface.co/datasets/username/scripts/resolve/main/example.py") +``` + +UV scripts are Python scripts that include their dependencies directly in the file using a special comment syntax. This makes them perfect for self-contained tasks that don't require complex project setups. Learn more about UV scripts in the [UV documentation](https://docs.astral.sh/uv/guides/scripts/). diff --git a/docs/source/en/guides/overview.md b/docs/source/en/guides/overview.md index fd0c8c417f..84a846a997 100644 --- a/docs/source/en/guides/overview.md +++ b/docs/source/en/guides/overview.md @@ -127,5 +127,14 @@ Take a look at these guides to learn how to use huggingface_hub to solve real-wo
+
+<a href="./jobs">
+  <div>
+    Jobs
+  </div>
+  <p>
+    How to run and manage compute Jobs on Hugging Face infrastructure and select the hardware?
+  </p>
+</a>
+ diff --git a/docs/source/en/package_reference/jobs.md b/docs/source/en/package_reference/jobs.md new file mode 100644 index 0000000000..eca90bc3ad --- /dev/null +++ b/docs/source/en/package_reference/jobs.md @@ -0,0 +1,33 @@ + + +# Jobs + +Check the [`HfApi`] documentation page for the reference of methods to manage your Jobs on the Hub. + +- Run a Job: [`run_job`] +- Fetch logs: [`fetch_job_logs`] +- Inspect Job: [`inspect_job`] +- List Jobs: [`list_jobs`] +- Cancel Job: [`cancel_job`] +- Run a UV Job: [`run_uv_job`] + +## Data structures + +### JobInfo + +[[autodoc]] JobInfo + +### JobOwner + +[[autodoc]] JobOwner + + +### JobStage + +[[autodoc]] JobStage + +### JobStatus + +[[autodoc]] JobStatus diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 49cd48ad41..c58b3e4aca 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -62,6 +62,12 @@ "InferenceEndpointTimeoutError", "InferenceEndpointType", ], + "_jobs_api": [ + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", + ], "_login": [ "auth_list", "auth_switch", @@ -165,6 +171,7 @@ "add_space_variable", "auth_check", "cancel_access_request", + "cancel_job", "change_discussion_status", "comment_discussion", "create_branch", @@ -194,6 +201,7 @@ "duplicate_space", "edit_discussion_comment", "enable_webhook", + "fetch_job_logs", "file_exists", "get_collection", "get_dataset_tags", @@ -210,11 +218,13 @@ "get_user_overview", "get_webhook", "grant_access", + "inspect_job", "list_accepted_access_requests", "list_collections", "list_datasets", "list_inference_catalog", "list_inference_endpoints", + "list_jobs", "list_lfs_files", "list_liked_repos", "list_models", @@ -251,6 +261,8 @@ "resume_inference_endpoint", "revision_exists", "run_as_future", + "run_job", + "run_uv_job", "scale_to_zero_inference_endpoint", "set_space_sleep_time", "space_info", @@ -656,6 +668,10 @@ "InferenceEndpointTimeoutError", "InferenceEndpointType", "InferenceTimeoutError", + "JobInfo", + "JobOwner", + "JobStage", + "JobStatus", "KerasModelHubMixin", "MCPClient", "ModelCard", @@ -792,6 +808,7 @@ "auth_switch", "cached_assets_path", "cancel_access_request", + "cancel_job", "change_discussion_status", "comment_discussion", "configure_http_backend", @@ -825,6 +842,7 @@ "enable_webhook", "export_entries_as_dduf", "export_folder_as_dduf", + "fetch_job_logs", "file_exists", "from_pretrained_fastai", "from_pretrained_keras", @@ -851,12 +869,14 @@ "grant_access", "hf_hub_download", "hf_hub_url", + "inspect_job", "interpreter_login", "list_accepted_access_requests", "list_collections", "list_datasets", "list_inference_catalog", "list_inference_endpoints", + "list_jobs", "list_lfs_files", "list_liked_repos", "list_models", @@ -907,6 +927,8 @@ "resume_inference_endpoint", "revision_exists", "run_as_future", + "run_job", + "run_uv_job", "save_pretrained_keras", "save_torch_model", "save_torch_state_dict", @@ -1044,6 +1066,12 @@ def __dir__(): InferenceEndpointTimeoutError, # noqa: F401 InferenceEndpointType, # noqa: F401 ) + from ._jobs_api import ( + JobInfo, # noqa: F401 + JobOwner, # noqa: F401 + JobStage, # noqa: F401 + JobStatus, # noqa: F401 + ) from ._login import ( auth_list, # noqa: F401 auth_switch, # noqa: F401 @@ -1143,6 +1171,7 @@ def __dir__(): add_space_variable, # noqa: F401 auth_check, # noqa: F401 cancel_access_request, # noqa: F401 + cancel_job, # noqa: F401 change_discussion_status, # noqa: F401 comment_discussion, # noqa: F401 create_branch, # noqa: F401 @@ -1172,6 +1201,7 @@ def __dir__(): 
         duplicate_space,  # noqa: F401
         edit_discussion_comment,  # noqa: F401
         enable_webhook,  # noqa: F401
+        fetch_job_logs,  # noqa: F401
         file_exists,  # noqa: F401
         get_collection,  # noqa: F401
         get_dataset_tags,  # noqa: F401
@@ -1188,11 +1218,13 @@ def __dir__():
         get_user_overview,  # noqa: F401
         get_webhook,  # noqa: F401
         grant_access,  # noqa: F401
+        inspect_job,  # noqa: F401
         list_accepted_access_requests,  # noqa: F401
         list_collections,  # noqa: F401
         list_datasets,  # noqa: F401
         list_inference_catalog,  # noqa: F401
         list_inference_endpoints,  # noqa: F401
+        list_jobs,  # noqa: F401
         list_lfs_files,  # noqa: F401
         list_liked_repos,  # noqa: F401
         list_models,  # noqa: F401
@@ -1229,6 +1261,8 @@ def __dir__():
         resume_inference_endpoint,  # noqa: F401
         revision_exists,  # noqa: F401
         run_as_future,  # noqa: F401
+        run_job,  # noqa: F401
+        run_uv_job,  # noqa: F401
         scale_to_zero_inference_endpoint,  # noqa: F401
         set_space_sleep_time,  # noqa: F401
         space_info,  # noqa: F401
diff --git a/src/huggingface_hub/_jobs_api.py b/src/huggingface_hub/_jobs_api.py
new file mode 100644
index 0000000000..cdfed4f9dd
--- /dev/null
+++ b/src/huggingface_hub/_jobs_api.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2025-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+from huggingface_hub import constants
+from huggingface_hub._space_api import SpaceHardware
+from huggingface_hub.utils._datetime import parse_datetime
+
+
+class JobStage(str, Enum):
+    """
+    Enumeration of the possible stages of a Job on the Hub.
+
+    Value can be compared to a string:
+    ```py
+    assert JobStage.COMPLETED == "COMPLETED"
+    ```
+
+    Taken from https://github.com/huggingface/moon-landing/blob/main/server/job_types/JobInfo.ts#L61 (private url).
+    """
+
+    # Copied from moon-landing > server > lib > Job.ts
+    COMPLETED = "COMPLETED"
+    CANCELED = "CANCELED"
+    ERROR = "ERROR"
+    DELETED = "DELETED"
+    RUNNING = "RUNNING"
+
+
+@dataclass
+class JobStatus:
+    stage: JobStage
+    message: Optional[str]
+
+    def __init__(self, **kwargs) -> None:
+        self.stage = kwargs["stage"]
+        self.message = kwargs.get("message")
+
+
+@dataclass
+class JobOwner:
+    id: str
+    name: str
+
+
+@dataclass
+class JobInfo:
+    """
+    Contains information about a Job.
+
+    Args:
+        id (`str`):
+            Job ID.
+        created_at (`datetime` or `None`):
+            When the Job was created.
+        docker_image (`str` or `None`):
+            The Docker image from Docker Hub used for the Job.
+            Can be None if space_id is present instead.
+        space_id (`str` or `None`):
+            The Docker image from Hugging Face Spaces used for the Job.
+            Can be None if docker_image is present instead.
+        command (`List[str]` or `None`):
+            Command of the Job, e.g. `["python", "-c", "print('hello world')"]`
+        arguments (`List[str]` or `None`):
+            Arguments passed to the command
+        environment (`Dict[str, Any]` or `None`):
+            Environment variables of the Job as a dictionary.
+        secrets (`Dict[str, Any]` or `None`):
+            Secret environment variables of the Job (encrypted).
+        flavor (`str` or `None`):
+            Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values.
+            E.g. `"cpu-basic"`.
+        status (`JobStatus` or `None`):
+            Status of the Job, e.g. `JobStatus(stage="RUNNING", message=None)`
+            See [`JobStage`] for possible stage values.
+        owner (`JobOwner` or `None`):
+            Owner of the Job, e.g. `JobOwner(id="5e9ecfc04957053f60648a3e", name="lhoestq")`
+
+    Example:
+
+    ```python
+    >>> from huggingface_hub import run_job
+    >>> job = run_job(
+    ...     image="python:3.12",
+    ...     command=["python", "-c", "print('Hello from the cloud!')"]
+    ... )
+    >>> job
+    JobInfo(id='687fb701029421ae5549d998', created_at=datetime.datetime(2025, 7, 22, 16, 6, 25, 79000, tzinfo=datetime.timezone.utc), docker_image='python:3.12', space_id=None, command=['python', '-c', "print('Hello from the cloud!')"], arguments=[], environment={}, secrets={}, flavor='cpu-basic', status=JobStatus(stage='RUNNING', message=None), owner=JobOwner(id='5e9ecfc04957053f60648a3e', name='lhoestq'), endpoint='https://huggingface.co', url='https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998')
+    >>> job.id
+    '687fb701029421ae5549d998'
+    >>> job.url
+    'https://huggingface.co/jobs/lhoestq/687fb701029421ae5549d998'
+    >>> job.status.stage
+    'RUNNING'
+    ```
+    """
+
+    id: str
+    created_at: Optional[datetime]
+    docker_image: Optional[str]
+    space_id: Optional[str]
+    command: Optional[List[str]]
+    arguments: Optional[List[str]]
+    environment: Optional[Dict[str, Any]]
+    secrets: Optional[Dict[str, Any]]
+    flavor: Optional[SpaceHardware]
+    status: Optional[JobStatus]
+    owner: Optional[JobOwner]
+
+    # Inferred fields
+    endpoint: str
+    url: str
+
+    def __init__(self, **kwargs) -> None:
+        self.id = kwargs["id"]
+        created_at = kwargs.get("createdAt") or kwargs.get("created_at")
+        self.created_at = parse_datetime(created_at) if created_at else None
+        self.docker_image = kwargs.get("dockerImage") or kwargs.get("docker_image")
+        self.space_id = kwargs.get("spaceId") or kwargs.get("space_id")
+        self.owner = JobOwner(**(kwargs["owner"] if isinstance(kwargs.get("owner"), dict) else {}))
+        self.command = kwargs.get("command")
+        self.arguments = kwargs.get("arguments")
+        self.environment = kwargs.get("environment")
+        self.secrets = kwargs.get("secrets")
+        self.flavor = kwargs.get("flavor")
+        self.status = JobStatus(**(kwargs["status"] if isinstance(kwargs.get("status"), dict) else {}))
+
+        # Inferred fields
+        self.endpoint = kwargs.get("endpoint", constants.ENDPOINT)
+        self.url = f"{self.endpoint}/jobs/{self.owner.name}/{self.id}"
diff --git a/src/huggingface_hub/commands/huggingface_cli.py b/src/huggingface_hub/commands/huggingface_cli.py
index 4e30f305c2..35b4395229 100644
--- a/src/huggingface_hub/commands/huggingface_cli.py
+++ b/src/huggingface_hub/commands/huggingface_cli.py
@@ -17,6 +17,7 @@
 from huggingface_hub.commands.delete_cache import DeleteCacheCommand
 from huggingface_hub.commands.download import DownloadCommand
 from huggingface_hub.commands.env import EnvironmentCommand
+from huggingface_hub.commands.jobs import JobsCommands
 from huggingface_hub.commands.lfs import LfsCommands
 from huggingface_hub.commands.repo import RepoCommands
 from huggingface_hub.commands.repo_files import RepoFilesCommand
@@ -44,6 +45,7 @@ def main():
     DeleteCacheCommand.register_subcommand(commands_parser)
     TagCommands.register_subcommand(commands_parser)
     VersionCommand.register_subcommand(commands_parser)
+
JobsCommands.register_subcommand(commands_parser) # Experimental UploadLargeFolderCommand.register_subcommand(commands_parser) diff --git a/src/huggingface_hub/commands/jobs.py b/src/huggingface_hub/commands/jobs.py new file mode 100644 index 0000000000..9509458e9c --- /dev/null +++ b/src/huggingface_hub/commands/jobs.py @@ -0,0 +1,510 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with jobs on the Hugging Face Hub. + +Usage: + # run a job + huggingface-cli jobs run image command +""" + +import json +import os +import re +from argparse import Namespace, _SubParsersAction +from dataclasses import asdict +from pathlib import Path +from typing import Dict, List, Optional, Union + +import requests + +from huggingface_hub import HfApi, SpaceHardware +from huggingface_hub.utils import logging +from huggingface_hub.utils._dotenv import load_dotenv + +from . import BaseHuggingfaceCLICommand + + +logger = logging.get_logger(__name__) + +SUGGESTED_FLAVORS = [item.value for item in SpaceHardware if item.value != "zero-a10g"] + + +class JobsCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + jobs_parser = parser.add_parser("jobs", help="Commands to interact with your huggingface.co jobs.") + jobs_subparsers = jobs_parser.add_subparsers(help="huggingface.co jobs related commands") + + # Register commands + InspectCommand.register_subcommand(jobs_subparsers) + LogsCommand.register_subcommand(jobs_subparsers) + PsCommand.register_subcommand(jobs_subparsers) + RunCommand.register_subcommand(jobs_subparsers) + CancelCommand.register_subcommand(jobs_subparsers) + UvCommand.register_subcommand(jobs_subparsers) + + +class RunCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("run", help="Run a Job") + run_parser.add_argument("image", type=str, help="The Docker image to use.") + run_parser.add_argument("-e", "--env", action="append", help="Set environment variables.") + run_parser.add_argument("-s", "--secrets", action="append", help="Set secret environment variables.") + run_parser.add_argument("--env-file", type=str, help="Read in a file of environment variables.") + run_parser.add_argument("--secrets-file", type=str, help="Read in a file of secret environment variables.") + run_parser.add_argument( + "--flavor", + type=str, + help=f"Flavor for the hardware, as in HF Spaces. Defaults to `cpu-basic`. Possible values: {', '.join(SUGGESTED_FLAVORS)}.", + ) + run_parser.add_argument( + "--timeout", + type=str, + help="Max duration: int/float with s (seconds, default), m (minutes), h (hours) or d (days).", + ) + run_parser.add_argument( + "-d", + "--detach", + action="store_true", + help="Run the Job in the background and print the Job ID.", + ) + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the Job will be created. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + run_parser.add_argument("command", nargs="...", help="The command to run.") + run_parser.set_defaults(func=RunCommand) + + def __init__(self, args: Namespace) -> None: + self.image: str = args.image + self.command: List[str] = args.command + self.env: dict[str, Optional[str]] = {} + if args.env_file: + self.env.update(load_dotenv(Path(args.env_file).read_text())) + for env_value in args.env or []: + self.env.update(load_dotenv(env_value)) + self.secrets: dict[str, Optional[str]] = {} + if args.secrets_file: + self.secrets.update(load_dotenv(Path(args.secrets_file).read_text())) + for secret in args.secrets or []: + self.secrets.update(load_dotenv(secret)) + self.flavor: Optional[SpaceHardware] = args.flavor + self.timeout: Optional[str] = args.timeout + self.detach: bool = args.detach + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + job = api.run_job( + image=self.image, + command=self.command, + env=self.env, + secrets=self.secrets, + flavor=self.flavor, + timeout=self.timeout, + namespace=self.namespace, + ) + # Always print the job ID to the user + print(f"Job started with ID: {job.id}") + print(f"View at: {job.url}") + + if self.detach: + return + + # Now let's stream the logs + for log in api.fetch_job_logs(job_id=job.id): + print(log) + + +class LogsCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("logs", help="Fetch the logs of a Job") + run_parser.add_argument("job_id", type=str, help="Job ID") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.set_defaults(func=LogsCommand) + + def __init__(self, args: Namespace) -> None: + self.job_id: str = args.job_id + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + for log in api.fetch_job_logs(job_id=self.job_id, namespace=self.namespace): + print(log) + + +def _tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + terminal_width = max(os.get_terminal_size().columns, len(headers) * 12) + while len(headers) + sum(col_widths) > terminal_width: + col_to_minimize = col_widths.index(max(col_widths)) + col_widths[col_to_minimize] //= 2 + if len(headers) + sum(col_widths) <= terminal_width: + col_widths[col_to_minimize] = terminal_width - sum(col_widths) - len(headers) + col_widths[col_to_minimize] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + row_format_args = [ + str(x)[: col_width - 3] + "..." 
if len(str(x)) > col_width else str(x) + for x, col_width in zip(row, col_widths) + ] + lines.append(row_format.format(*row_format_args)) + return "\n".join(lines) + + +class PsCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("ps", help="List Jobs") + run_parser.add_argument( + "-a", + "--all", + action="store_true", + help="Show all Jobs (default shows just running)", + ) + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace from where it lists the jobs. Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + # Add Docker-style filtering argument + run_parser.add_argument( + "-f", + "--filter", + action="append", + default=[], + help="Filter output based on conditions provided (format: key=value)", + ) + # Add option to format output + run_parser.add_argument( + "--format", + type=str, + help="Format output using a custom template", + ) + run_parser.set_defaults(func=PsCommand) + + def __init__(self, args: Namespace) -> None: + self.all: bool = args.all + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self.format: Optional[str] = args.format + self.filters: Dict[str, str] = {} + + # Parse filter arguments (key=value pairs) + for f in args.filter: + if "=" in f: + key, value = f.split("=", 1) + self.filters[key.lower()] = value + else: + print(f"Warning: Ignoring invalid filter format '{f}'. Use key=value format.") + + def run(self) -> None: + """ + Fetch and display job information for the current user. + Uses Docker-style filtering with -f/--filter flag and key=value pairs. 
+ """ + try: + api = HfApi(token=self.token) + + # Fetch jobs data + jobs = api.list_jobs(namespace=self.namespace) + + # Define table headers + table_headers = ["JOB ID", "IMAGE/SPACE", "COMMAND", "CREATED", "STATUS"] + + # Process jobs data + rows = [] + + for job in jobs: + # Extract job data for filtering + status = job.status.stage if job.status else "UNKNOWN" + + # Skip job if not all jobs should be shown and status doesn't match criteria + if not self.all and status not in ("RUNNING", "UPDATING"): + continue + + # Extract job ID + job_id = job.id + + # Extract image or space information + image_or_space = job.docker_image or "N/A" + + # Extract and format command + command = job.command or [] + command_str = " ".join(command) if command else "N/A" + + # Extract creation time + created_at = job.created_at or "N/A" + + # Create a dict with all job properties for filtering + job_properties = { + "id": job_id, + "image": image_or_space, + "status": status.lower(), + "command": command_str, + } + + # Check if job matches all filters + if not self._matches_filters(job_properties): + continue + + # Create row + rows.append([job_id, image_or_space, command_str, created_at, status]) + + # Handle empty results + if not rows: + filters_msg = "" + if self.filters: + filters_msg = f" matching filters: {', '.join([f'{k}={v}' for k, v in self.filters.items()])}" + + print(f"No jobs found{filters_msg}") + return + + # Apply custom format if provided or use default tabular format + self._print_output(rows, table_headers) + + except requests.RequestException as e: + print(f"Error fetching jobs data: {e}") + except (KeyError, ValueError, TypeError) as e: + print(f"Error processing jobs data: {e}") + except Exception as e: + print(f"Unexpected error - {type(e).__name__}: {e}") + + def _matches_filters(self, job_properties: Dict[str, str]) -> bool: + """Check if job matches all specified filters.""" + for key, pattern in self.filters.items(): + # Check if property exists + if key not in job_properties: + return False + + # Support pattern matching with wildcards + if "*" in pattern or "?" in pattern: + # Convert glob pattern to regex + regex_pattern = pattern.replace("*", ".*").replace("?", ".") + if not re.search(f"^{regex_pattern}$", job_properties[key], re.IGNORECASE): + return False + # Simple substring matching + elif pattern.lower() not in job_properties[key].lower(): + return False + + return True + + def _print_output(self, rows, headers): + """Print output according to the chosen format.""" + if self.format: + # Custom template formatting (simplified) + template = self.format + for row in rows: + line = template + for i, field in enumerate(["id", "image", "command", "created", "status"]): + placeholder = f"{{{{.{field}}}}}" + if placeholder in line: + line = line.replace(placeholder, str(row[i])) + print(line) + else: + # Default tabular format + print( + _tabulate( + rows, + headers=headers, + ) + ) + + +class InspectCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("inspect", help="Display detailed information on one or more Jobs") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.add_argument("job_ids", nargs="...", help="The jobs to inspect") + run_parser.set_defaults(func=InspectCommand) + + def __init__(self, args: Namespace) -> None: + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self.job_ids: List[str] = args.job_ids + + def run(self) -> None: + api = HfApi(token=self.token) + jobs = [api.inspect_job(job_id=job_id, namespace=self.namespace) for job_id in self.job_ids] + print(json.dumps([asdict(job) for job in jobs], indent=4, default=str)) + + +class CancelCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction) -> None: + run_parser = parser.add_parser("cancel", help="Cancel a Job") + run_parser.add_argument("job_id", type=str, help="Job ID") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the job is running. Defaults to the current user's namespace.", + ) + run_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + run_parser.set_defaults(func=CancelCommand) + + def __init__(self, args: Namespace) -> None: + self.job_id: str = args.job_id + self.namespace = args.namespace + self.token: Optional[str] = args.token + + def run(self) -> None: + api = HfApi(token=self.token) + api.cancel_job(job_id=self.job_id, namespace=self.namespace) + + +class UvCommand(BaseHuggingfaceCLICommand): + """Run UV scripts on Hugging Face infrastructure.""" + + @staticmethod + def register_subcommand(parser): + """Register UV run subcommand.""" + uv_parser = parser.add_parser( + "uv", + help="Run UV scripts (Python with inline dependencies) on HF infrastructure", + ) + + subparsers = uv_parser.add_subparsers(dest="uv_command", help="UV commands", required=True) + + # Run command only + run_parser = subparsers.add_parser( + "run", + help="Run a UV script (local file or URL) on HF infrastructure", + ) + run_parser.add_argument("script", help="UV script to run (local file or URL)") + run_parser.add_argument("script_args", nargs="...", help="Arguments for the script", default=[]) + run_parser.add_argument("--image", type=str, help="Use a custom Docker image with `uv` installed.") + run_parser.add_argument( + "--repo", + help="Repository name for the script (creates ephemeral if not specified)", + ) + run_parser.add_argument( + "--flavor", + type=str, + help=f"Flavor for the hardware, as in HF Spaces. Defaults to `cpu-basic`. Possible values: {', '.join(SUGGESTED_FLAVORS)}.", + ) + run_parser.add_argument("-e", "--env", action="append", help="Environment variables") + run_parser.add_argument("-s", "--secrets", action="append", help="Secret environment variables") + run_parser.add_argument("--env-file", type=str, help="Read in a file of environment variables.") + run_parser.add_argument( + "--secrets-file", + type=str, + help="Read in a file of secret environment variables.", + ) + run_parser.add_argument("--timeout", type=str, help="Max duration (e.g., 30s, 5m, 1h)") + run_parser.add_argument("-d", "--detach", action="store_true", help="Run in background") + run_parser.add_argument( + "--namespace", + type=str, + help="The namespace where the Job will be created. 
Defaults to the current user's namespace.", + ) + run_parser.add_argument("--token", type=str, help="HF token") + # UV options + run_parser.add_argument("--with", action="append", help="Run with the given packages installed", dest="with_") + run_parser.add_argument( + "-p", "--python", type=str, help="The Python interpreter to use for the run environment" + ) + run_parser.set_defaults(func=UvCommand) + + def __init__(self, args: Namespace) -> None: + """Initialize the command with parsed arguments.""" + self.script = args.script + self.script_args = args.script_args + self.dependencies = args.with_ + self.python = args.python + self.image = args.image + self.env: dict[str, Optional[str]] = {} + if args.env_file: + self.env.update(load_dotenv(Path(args.env_file).read_text())) + for env_value in args.env or []: + self.env.update(load_dotenv(env_value)) + self.secrets: dict[str, Optional[str]] = {} + if args.secrets_file: + self.secrets.update(load_dotenv(Path(args.secrets_file).read_text())) + for secret in args.secrets or []: + self.secrets.update(load_dotenv(secret)) + self.flavor: Optional[SpaceHardware] = args.flavor + self.timeout: Optional[str] = args.timeout + self.detach: bool = args.detach + self.namespace: Optional[str] = args.namespace + self.token: Optional[str] = args.token + self._repo = args.repo + + def run(self) -> None: + """Execute UV command.""" + logging.set_verbosity(logging.INFO) + api = HfApi(token=self.token) + job = api.run_uv_job( + script=self.script, + script_args=self.script_args, + dependencies=self.dependencies, + python=self.python, + image=self.image, + env=self.env, + secrets=self.secrets, + flavor=self.flavor, + timeout=self.timeout, + namespace=self.namespace, + _repo=self._repo, + ) + + # Always print the job ID to the user + print(f"Job started with ID: {job.id}") + print(f"View at: {job.url}") + + if self.detach: + return + + # Now let's stream the logs + for log in api.fetch_job_logs(job_id=job.id): + print(log) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e2a32dd14a..4e4090e864 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -19,6 +19,7 @@ import json import re import struct +import time import warnings from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor @@ -27,6 +28,7 @@ from functools import wraps from itertools import islice from pathlib import Path +from textwrap import dedent from typing import ( TYPE_CHECKING, Any, @@ -65,6 +67,7 @@ _warn_on_overwriting_operations, ) from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType +from ._jobs_api import JobInfo from ._space_api import SpaceHardware, SpaceRuntime, SpaceStorage, SpaceVariable from ._upload_large_folder import upload_large_folder_internal from .community import ( @@ -9940,6 +9943,506 @@ def auth_check( r = get_session().get(path, headers=headers) hf_raise_for_status(r) + def run_job( + self, + *, + image: str, + command: List[str], + env: Optional[Dict[str, Any]] = None, + secrets: Optional[Dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Run compute Jobs on Hugging Face infrastructure. + + Args: + image (`str`): + The Docker image to use. + Examples: `"ubuntu"`, `"python:3.12"`, `"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"`. 
+                Example with an image from a Space: `"hf.co/spaces/lhoestq/duckdb"`.
+
+            command (`List[str]`):
+                The command to run. Example: `["echo", "hello"]`.
+
+            env (`Dict[str, Any]`, *optional*):
+                Defines the environment variables for the Job.
+
+            secrets (`Dict[str, Any]`, *optional*):
+                Defines the secret environment variables for the Job.
+
+            flavor (`str`, *optional*):
+                Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values.
+                Defaults to `"cpu-basic"`.
+
+            timeout (`Union[int, float, str]`, *optional*):
+                Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days).
+                Example: `300` or `"5m"` for 5 minutes.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job will be created. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+
+        Example:
+            Run your first Job:
+
+            ```python
+            >>> from huggingface_hub import run_job
+            >>> run_job(image="python:3.12", command=["python", "-c", "print('Hello from HF compute!')"])
+            ```
+
+            Run a GPU Job:
+
+            ```python
+            >>> from huggingface_hub import run_job
+            >>> image = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
+            >>> command = ["python", "-c", "import torch; print(f'This code ran with the following GPU: {torch.cuda.get_device_name()}')"]
+            >>> run_job(image=image, command=command, flavor="a10g-small")
+            ```
+
+        """
+        if flavor is None:
+            flavor = SpaceHardware.CPU_BASIC
+
+        # prepare payload to send to HF Jobs API
+        input_json: Dict[str, Any] = {
+            "command": command,
+            "arguments": [],
+            "environment": env or {},
+            "flavor": flavor,
+        }
+        # secrets are optional
+        if secrets:
+            input_json["secrets"] = secrets
+        # timeout is optional
+        if timeout:
+            time_units_factors = {"s": 1, "m": 60, "h": 3600, "d": 3600 * 24}
+            if isinstance(timeout, str) and timeout[-1] in time_units_factors:
+                input_json["timeoutSeconds"] = int(float(timeout[:-1]) * time_units_factors[timeout[-1]])
+            else:
+                input_json["timeoutSeconds"] = int(timeout)
+        # input is either from docker hub or from HF spaces
+        for prefix in (
+            "https://huggingface.co/spaces/",
+            "https://hf.co/spaces/",
+            "huggingface.co/spaces/",
+            "hf.co/spaces/",
+        ):
+            if image.startswith(prefix):
+                input_json["spaceId"] = image[len(prefix) :]
+                break
+        else:
+            input_json["dockerImage"] = image
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+        response = get_session().post(
+            f"https://huggingface.co/api/jobs/{namespace}",
+            json=input_json,
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+        job_info = response.json()
+        return JobInfo(**job_info, endpoint=self.endpoint)
+
+    def fetch_job_logs(
+        self,
+        *,
+        job_id: str,
+        namespace: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> Iterable[str]:
+        """
+        Fetch all the logs from a compute Job on Hugging Face infrastructure.
+
+        Args:
+            job_id (`str`):
+                ID of the Job.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job is running. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import fetch_job_logs, run_job
+        >>> job = run_job(image="python:3.12", command=["python", "-c", "print('Hello from HF compute!')"])
+        >>> for log in fetch_job_logs(job_id=job.id):
+        ...     print(log)
+        Hello from HF compute!
+        ```
+        """
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+        logging_finished = logging_started = False
+        job_finished = False
+        # - We need to retry because sometimes the /logs doesn't return logs when the job just started.
+        #   (for example it can return only two lines: one for "Job started" and one empty line)
+        # - Timeouts can happen in case of build errors
+        # - ChunkedEncodingError can happen in case of stopped logging in the middle of streaming
+        # - Infinite empty log stream can happen in case of build error
+        #   (the logs stream is infinite and empty except for the Job started message)
+        # - there is a ": keep-alive" every 30 seconds
+
+        # We don't use http_backoff since we need to check ourselves if ConnectionError.__context__ is a TimeoutError
+        max_retries = 5
+        min_wait_time = 1
+        max_wait_time = 10
+        sleep_time = 0
+        for _ in range(max_retries):
+            time.sleep(sleep_time)
+            sleep_time = min(max_wait_time, max(min_wait_time, sleep_time * 2))
+            try:
+                resp = get_session().get(
+                    f"https://huggingface.co/api/jobs/{namespace}/{job_id}/logs",
+                    headers=self._build_hf_headers(token=token),
+                    stream=True,
+                    timeout=120,
+                )
+                log = None
+                for line in resp.iter_lines(chunk_size=1):
+                    line = line.decode("utf-8")
+                    if line and line.startswith("data: {"):
+                        data = json.loads(line[len("data: ") :])
+                        # timestamp = data["timestamp"]
+                        if not data["data"].startswith("===== Job started"):
+                            logging_started = True
+                            log = data["data"]
+                            yield log
+                logging_finished = logging_started
+            except requests.exceptions.ChunkedEncodingError:
+                # Response ended prematurely
+                break
+            except KeyboardInterrupt:
+                break
+            except requests.exceptions.ConnectionError as err:
+                is_timeout = err.__context__ and isinstance(getattr(err.__context__, "__cause__", None), TimeoutError)
+                if logging_started or not is_timeout:
+                    raise
+            if logging_finished or job_finished:
+                break
+            job_status = (
+                get_session()
+                .get(
+                    f"https://huggingface.co/api/jobs/{namespace}/{job_id}",
+                    headers=self._build_hf_headers(token=token),
+                )
+                .json()
+            )
+            if "status" in job_status and job_status["status"]["stage"] not in ("RUNNING", "UPDATING"):
+                job_finished = True
+
+    def list_jobs(
+        self,
+        *,
+        timeout: Optional[int] = None,
+        namespace: Optional[str] = None,
+        token: Union[bool, str, None] = None,
+    ) -> List[JobInfo]:
+        """
+        List compute Jobs on Hugging Face infrastructure.
+
+        Args:
+            timeout (`int`, *optional*):
+                Timeout in seconds for the request to the Hub.
+
+            namespace (`str`, *optional*):
+                The namespace from which jobs are listed. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+ """ + if namespace is None: + namespace = whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}", + headers=self._build_hf_headers(token=token), + timeout=timeout, + ) + response.raise_for_status() + return [JobInfo(**job_info, endpoint=self.endpoint) for job_info in response.json()] + + def inspect_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> JobInfo: + """ + Inspect a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Example: + + ```python + >>> from huggingface_hub import inspect_job, run_job + >>> job = run_job("python:3.12", ["python", "-c" ,"print('Hello from HF compute!')"]) + >>> inspect_job(job.job_id) + JobInfo( + id='68780d00bbe36d38803f645f', + created_at=datetime.datetime(2025, 7, 16, 20, 35, 12, 808000, tzinfo=datetime.timezone.utc), + docker_image='python:3.12', + space_id=None, + command=['python', '-c', "print('Hello from HF compute!')"], + arguments=[], + environment={}, + secrets={}, + flavor='cpu-basic', + status=JobStatus(stage='RUNNING', message=None) + ) + ``` + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + response = get_session().get( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}", + headers=self._build_hf_headers(token=token), + ) + response.raise_for_status() + return JobInfo(**response.json(), endpoint=self.endpoint) + + def cancel_job( + self, + *, + job_id: str, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> None: + """ + Cancel a compute Job on Hugging Face infrastructure. + + Args: + job_id (`str`): + ID of the Job. + + namespace (`str`, *optional*): + The namespace where the Job is running. Defaults to the current user's namespace. + + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + """ + if namespace is None: + namespace = self.whoami(token=token)["name"] + get_session().post( + f"{self.endpoint}/api/jobs/{namespace}/{job_id}/cancel", + headers=self._build_hf_headers(token=token), + ).raise_for_status() + + @experimental + def run_uv_job( + self, + script: str, + *, + script_args: Optional[List[str]] = None, + dependencies: Optional[List[str]] = None, + python: Optional[str] = None, + image: Optional[str] = None, + env: Optional[Dict[str, Any]] = None, + secrets: Optional[Dict[str, Any]] = None, + flavor: Optional[SpaceHardware] = None, + timeout: Optional[Union[int, float, str]] = None, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + _repo: Optional[str] = None, + ) -> JobInfo: + """ + Run a UV script Job on Hugging Face infrastructure. + + Args: + script (`str`): + Path or URL of the UV script. + + script_args (`List[str]`, *optional*) + Arguments to pass to the script. 
+
+            dependencies (`List[str]`, *optional*):
+                Dependencies to use to run the UV script.
+
+            python (`str`, *optional*):
+                Use a specific Python version. Default is 3.12.
+
+            image (`str`, *optional*, defaults to "ghcr.io/astral-sh/uv:python3.12-bookworm-slim"):
+                Use a custom Docker image with `uv` installed.
+
+            env (`Dict[str, Any]`, *optional*):
+                Defines the environment variables for the Job.
+
+            secrets (`Dict[str, Any]`, *optional*):
+                Defines the secret environment variables for the Job.
+
+            flavor (`str`, *optional*):
+                Flavor for the hardware, as in Hugging Face Spaces. See [`SpaceHardware`] for possible values.
+                Defaults to `"cpu-basic"`.
+
+            timeout (`Union[int, float, str]`, *optional*):
+                Max duration for the Job: int/float with s (seconds, default), m (minutes), h (hours) or d (days).
+                Example: `300` or `"5m"` for 5 minutes.
+
+            namespace (`str`, *optional*):
+                The namespace where the Job will be created. Defaults to the current user's namespace.
+
+            token (`Union[bool, str, None]`, *optional*):
+                A valid user access token. If not provided, the locally saved token will be used, which is the
+                recommended authentication method. Set to `False` to disable authentication.
+                Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication.
+
+        Example:
+
+        ```python
+        >>> from huggingface_hub import run_uv_job
+        >>> script = "https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/trl/scripts/sft.py"
+        >>> run_uv_job(script, dependencies=["trl"], flavor="a10g-small")
+        ```
+        """
+        image = image or "ghcr.io/astral-sh/uv:python3.12-bookworm-slim"
+        env = env or {}
+        secrets = secrets or {}
+
+        # Build command
+        uv_args = []
+        if dependencies:
+            for dependency in dependencies:
+                uv_args += ["--with", dependency]
+        if python:
+            uv_args += ["--python", python]
+        script_args = script_args or []
+
+        if namespace is None:
+            namespace = self.whoami(token=token)["name"]
+
+        if script.startswith("http://") or script.startswith("https://"):
+            # Direct URL execution - no upload needed
+            command = ["uv", "run"] + uv_args + [script] + script_args
+        else:
+            # Local file - upload to HF
+            script_path = Path(script)
+            filename = script_path.name
+            # Parse repo (qualify it with the namespace if needed)
+            if _repo:
+                repo_id = _repo
+                if "/" not in repo_id:
+                    repo_id = f"{namespace}/{repo_id}"
+            else:
+                repo_id = f"{namespace}/hf-cli-jobs-uv-run-scripts"
+
+            # Create repo if needed
+            try:
+                self.repo_info(repo_id, repo_type="dataset")
+                logger.debug(f"Using existing repository: {repo_id}")
+            except RepositoryNotFoundError:
+                logger.info(f"Creating repository: {repo_id}")
+                self.create_repo(repo_id, repo_type="dataset", private=True, exist_ok=True)
+
+            # Upload script
+            logger.info(f"Uploading {script_path.name} to {repo_id}...")
+            with open(script_path, "r") as f:
+                script_content = f.read()
+
+            self.upload_file(
+                path_or_fileobj=script_content.encode(),
+                path_in_repo=filename,
+                repo_id=repo_id,
+                repo_type="dataset",
+            )
+
+            script_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"
+            repo_url = f"https://huggingface.co/datasets/{repo_id}"
+
+            logger.debug(f"✓ Script uploaded to: {repo_url}/blob/main/{filename}")
+
+            # Create and upload minimal README
+            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S UTC")
+            readme_content = dedent(
+                f"""
+                ---
+                tags:
+                - hf-cli-jobs-uv-script
+                - ephemeral
+                viewer: false
+                ---
+
+                # UV Script: {filename}
+
+                Executed via `huggingface-cli jobs uv run` on {timestamp}
+
+                ## Run this script
+
+                ```bash
+                huggingface-cli jobs uv run {filename}
diff --git a/src/huggingface_hub/utils/_dotenv.py b/src/huggingface_hub/utils/_dotenv.py
new file mode 100644
index 0000000000..6e3c13d611
--- /dev/null
+++ b/src/huggingface_hub/utils/_dotenv.py
@@ -0,0 +1,51 @@
+# AI-generated module (ChatGPT)
+import re
+from typing import Dict
+
+
+def load_dotenv(dotenv_str: str) -> Dict[str, str]:
+    """
+    Parse a DOTENV-format string and return a dictionary of key-value pairs.
+    Handles quoted values, comments, export keyword, and blank lines.
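+
+    Example (an illustrative doctest-style sketch; the behavior shown matches the
+    cases covered in `tests/test_utils_dotenv.py`):
+
+        >>> load_dotenv("export KEY=value  # inline comment")
+        {'KEY': 'value'}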
+ \s*(?:\#.*)?$ # optional inline comment + """, + re.VERBOSE, + ) + + for line in dotenv_str.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue # Skip comments and empty lines + + match = line_pattern.match(line) + if not match: + continue # Skip malformed lines + + key, raw_val = match.group(1), match.group(2) or "" + val = raw_val.strip() + + # Remove surrounding quotes if quoted + if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")): + val = val[1:-1] + val = val.replace(r"\n", "\n").replace(r"\t", "\t").replace(r"\"", '"').replace(r"\\", "\\") + if raw_val.startswith('"'): + val = val.replace(r"\$", "$") # only in double quotes + + env[key] = val + + return env diff --git a/tests/test_cli.py b/tests/test_cli.py index 21ea90b409..1c25f42783 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,7 @@ from huggingface_hub.commands.delete_cache import DeleteCacheCommand from huggingface_hub.commands.download import DownloadCommand +from huggingface_hub.commands.jobs import JobsCommands, RunCommand from huggingface_hub.commands.repo_files import DeleteFilesSubCommand, RepoFilesCommand from huggingface_hub.commands.scan_cache import ScanCacheCommand from huggingface_hub.commands.tag import TagCommands @@ -837,3 +838,46 @@ def test_delete(self, delete_files_mock: Mock) -> None: assert kwargs == delete_files_args delete_files_mock.reset_mock() + + +class DummyResponse: + def __init__(self, json): + self._json = json + + def raise_for_status(self): + pass + + def json(self): + return self._json + + +class TestJobsCommand(unittest.TestCase): + def setUp(self) -> None: + """ + Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`. + """ + self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") + commands_parser = self.parser.add_subparsers() + JobsCommands.register_subcommand(commands_parser) + + @patch( + "requests.Session.post", + return_value=DummyResponse( + {"id": "my-job-id", "owner": {"id": "userid", "name": "my-username"}, "status": {"stage": "RUNNING"}} + ), + ) + @patch("huggingface_hub.hf_api.HfApi.whoami", return_value={"name": "my-username"}) + def test_run(self, whoami: Mock, requests_post: Mock) -> None: + input_args = ["jobs", "run", "--detach", "ubuntu", "echo", "hello"] + cmd = RunCommand(self.parser.parse_args(input_args)) + cmd.run() + assert requests_post.call_count == 1 + args, kwargs = requests_post.call_args_list[0] + assert args == ("https://huggingface.co/api/jobs/my-username",) + assert kwargs["json"] == { + "command": ["echo", "hello"], + "arguments": [], + "environment": {}, + "flavor": "cpu-basic", + "dockerImage": "ubuntu", + } diff --git a/tests/test_utils_dotenv.py b/tests/test_utils_dotenv.py new file mode 100644 index 0000000000..ae622262b4 --- /dev/null +++ b/tests/test_utils_dotenv.py @@ -0,0 +1,64 @@ +# AI-generated module (ChatGPT) +from huggingface_hub.utils._dotenv import load_dotenv + + +def test_basic_key_value(): + data = "KEY=value" + assert load_dotenv(data) == {"KEY": "value"} + + +def test_whitespace_and_comments(): + data = """ + # This is a comment + KEY = value # inline comment + EMPTY= + """ + assert load_dotenv(data) == {"KEY": "value", "EMPTY": ""} + + +def test_quoted_values(): + data = """ + SINGLE='single quoted' + DOUBLE="double quoted" + ESCAPED="line\\nbreak" + """ + assert load_dotenv(data) == {"SINGLE": "single quoted", "DOUBLE": "double quoted", "ESCAPED": "line\nbreak"} + + +def test_export_and_inline_comment(): + 
data = "export KEY=value # this is a comment" + assert load_dotenv(data) == {"KEY": "value"} + + +def test_ignore_invalid_lines(): + data = """ + this is not valid + KEY=value + """ + assert load_dotenv(data) == {"KEY": "value"} + + +def test_complex_quotes(): + data = r""" + QUOTED="some value with # not comment" + ESCAPE="escaped \$dollar and \\backslash" + """ + assert load_dotenv(data) == { + "QUOTED": "some value with # not comment", + "ESCAPE": "escaped $dollar and \\backslash", + } + + +def test_no_value(): + data = "NOVALUE=" + assert load_dotenv(data) == {"NOVALUE": ""} + + +def test_multiple_lines(): + data = """ + A=1 + B="two" + C='three' + D=4 + """ + assert load_dotenv(data) == {"A": "1", "B": "two", "C": "three", "D": "4"}