diff --git a/docs/source/tune_cli.rst b/docs/source/tune_cli.rst index 2c432b9894..5981fd8b80 100644 --- a/docs/source/tune_cli.rst +++ b/docs/source/tune_cli.rst @@ -39,19 +39,25 @@ to download files using the CLI. Download a model ---------------- -The ``tune download `` command downloads any model from the Hugging Face Hub. +The ``tune download `` command downloads any model from the Hugging Face Hub or Kaggle Model Hub. .. list-table:: :widths: 30 60 * - \--output-dir - - Directory in which to save the model. + - Directory in which to save the model. Note: this option is not yet supported when `--source` is set to `kaggle`. * - \--output-dir-use-symlinks - To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be either duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if already exists) or downloaded from the Hub and not cached. * - \--hf-token - Hugging Face API token. Needed for gated models like Llama. * - \--ignore-patterns - If provided, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights. + * - \--source {huggingface,kaggle} + - If provided, downloads model weights from the provided repo_id on the designated source hub. + * - \--kaggle-username + - Kaggle username for authentication. Needed for private models or gated models like Llama2. + * - \--kaggle-api-key + - Kaggle API key. Needed for private models or gated models like Llama2. You can find your API key at https://kaggle.com/settings. .. code-block:: bash @@ -62,6 +68,13 @@ The ``tune download `` command downloads any model from the Hugging Face H ./model/model-00001-of-00002.bin ... +.. 
code-block:: bash + + $ tune download metaresearch/llama-3.2/pytorch/1b --source kaggle + Successfully downloaded model repo and wrote to the following locations: + /tmp/llama-3.2/pytorch/1b/tokenizer.model + /tmp/llama-3.2/pytorch/1b/params.json + /tmp/llama-3.2/pytorch/1b/consolidated.00.pth **Download a gated model** diff --git a/pyproject.toml b/pyproject.toml index c2920ff4d3..396531bcb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "huggingface_hub", "safetensors", + # Kaggle Integrations + "kagglehub", + # Tokenization "sentencepiece", "tiktoken", diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index 8a6d6ba0ab..3338adc870 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -4,8 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import runpy import sys +from unittest import mock import pytest from tests.common import TUNE_PATH @@ -106,3 +108,251 @@ def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_downloa "Please ensure you have access to the repository and have provided the proper Hugging Face API token" not in out_err.err ) + + # only valid --source parameters supported (expect prompt for supported values) + def test_source_parameter(self, capsys, monkeypatch): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source invalid".split() + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr() + assert "argument --source: invalid choice: 'invalid'" in output.err + + def test_download_from_kaggle(self, capsys, monkeypatch, mocker, tmpdir): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user 
--kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + + def test_download_from_kaggle_warn_when_output_dir_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --output-dir /requested/model/path".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + with pytest.warns( + UserWarning, + match="--output-dir flag is not supported for Kaggle model downloads", + ): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + + def test_download_from_kaggle_warn_when_ignore_patterns_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f'tune download {model} --source kaggle --ignore-patterns "*.glob-pattern"'.split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + with pytest.warns( + UserWarning, + match="--ignore-patterns flag is not supported for Kaggle model downloads", + ): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + + # tests when --kaggle-username and --kaggle-api-key are provided as CLI args + def test_download_from_kaggle_when_credentials_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + expected_username = 
"username" + expected_api_key = "api_key" + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = ( + f"tune download {model} " + f"--source kaggle " + f"--kaggle-username {expected_username} " + f"--kaggle-api-key {expected_api_key}" + ).split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + set_kaggle_credentials_spy = mocker.patch( + "torchtune._cli.download.set_kaggle_credentials" + ) + + runpy.run_path(TUNE_PATH, run_name="__main__") + + set_kaggle_credentials_spy.assert_called_once_with( + expected_username, expected_api_key + ) + output = capsys.readouterr().out + assert ( + "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key" + in output + ) + assert ( + "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + in output + ) + + # passes partial credentials with just --kaggle-username (expect fallback to environment variables) + @mock.patch.dict(os.environ, {"KAGGLE_KEY": "env_api_key"}) + def test_download_from_kaggle_when_partial_credentials_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + expected_username = "username" + expected_api_key = "env_api_key" + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username {expected_username}".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + set_kaggle_credentials_spy = mocker.patch( + "torchtune._cli.download.set_kaggle_credentials" + ) + + runpy.run_path(TUNE_PATH, run_name="__main__") + + set_kaggle_credentials_spy.assert_called_once_with( + expected_username, expected_api_key + ) + output = capsys.readouterr().out + assert ( + "TIP: you can avoid passing in the --kaggle-username and 
--kaggle-api-key" + in output + ) + assert ( + "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + in output + ) + + def test_download_from_kaggle_when_set_kaggle_credentials_throws( + self, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username u --kaggle-api-key k".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + mocker.patch( + "torchtune._cli.download.set_kaggle_credentials", + side_effect=Exception("some error"), + ) + + with pytest.warns( + UserWarning, + match="Failed to set Kaggle credentials with error", + ): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # KaggleApiHTTPError::Unauthorized even with --kaggle-username and --kaggle-api-key provided (expect prompt to check access and credentials) + def test_download_from_kaggle_unauthorized_credentials( + self, capsys, monkeypatch, mocker + ): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username username --kaggle-api-key key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "Unauthorized", + response=mocker.MagicMock(status_code=HTTPStatus.UNAUTHORIZED), + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert ( + "Please ensure you have access to the model and have provided the proper Kaggle credentials" + in out_err.err + ) + assert "You can also set these to environment variables" in out_err.err + + # KaggleApiHTTPError::NotFound + def 
test_download_from_kaggle_model_not_found(self, capsys, monkeypatch, mocker): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "mockorganizations/mockmodel/pytorch/mockvariation" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "NotFound", response=mocker.MagicMock(status_code=HTTPStatus.NOT_FOUND) + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert f"'{model}' not found on the Kaggle Model Hub." in out_err.err + + # KaggleApiHTTPError::InternalServerError + def test_download_from_kaggle_api_error(self, capsys, monkeypatch, mocker): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "InternalError", + response=mocker.MagicMock(status_code=HTTPStatus.INTERNAL_SERVER_ERROR), + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert "Failed to download" in out_err.err + + def test_download_from_kaggle_warn_on_nonmeta_pytorch_models( + self, monkeypatch, mocker, tmpdir + ): + model = "kaggle/kaggle-model-name/pytorch/1b" + testargs = f"tune download {model} --source kaggle".split() + monkeypatch.setattr(sys, "argv", testargs) + + # stub out model_download to guarantee success + mocker.patch( + 
"torchtune._cli.download.model_download", + return_value=tmpdir, + ) + + with pytest.warns(UserWarning, match="may not be compatible with torchtune"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + def test_download_from_kaggle_warn_on_nonpytorch_nontransformers_model( + self, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/some-model/some-madeup-framework/1b" + testargs = f"tune download {model} --source kaggle".split() + monkeypatch.setattr(sys, "argv", testargs) + + # stub out model_download to guarantee success + mocker.patch( + "torchtune._cli.download.model_download", + return_value=tmpdir, + ) + + with pytest.warns(UserWarning, match="may not be compatible with torchtune"): + runpy.run_path(TUNE_PATH, run_name="__main__") diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index 82b4935c01..a60cf370ab 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -7,12 +7,20 @@ import argparse import os import textwrap +import traceback +from http import HTTPStatus from pathlib import Path from typing import Literal, Union +from warnings import warn from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + +from kagglehub import model_download +from kagglehub.auth import set_kaggle_credentials +from kagglehub.exceptions import KaggleApiHTTPError +from kagglehub.handle import parse_model_handle from torchtune._cli.subcommand import Subcommand @@ -25,8 +33,8 @@ def __init__(self, subparsers: argparse._SubParsersAction): "download", prog="tune download", usage="tune download [OPTIONS]", - help="Download a model from the Hugging Face Hub.", - description="Download a model from the Hugging Face Hub.", + help="Download a model from the Hugging Face Hub or Kaggle Model Hub.", + description="Download a model from the Hugging Face Hub or Kaggle Model Hub.", epilog=textwrap.dedent( """\ examples: @@ -46,7 +54,16 @@ def __init__(self, subparsers: 
argparse._SubParsersAction): /tmp/model/model-00001-of-00002.bin ... - For a list of all models, visit the Hugging Face Hub https://huggingface.co/models. + # Download a model from the Kaggle Model Hub + $ tune download metaresearch/llama-3.2/pytorch/1b --source kaggle + Successfully downloaded model repo and wrote to the following locations: + /tmp/llama-3.2/pytorch/1b/tokenizer.model + /tmp/llama-3.2/pytorch/1b/params.json + /tmp/llama-3.2/pytorch/1b/consolidated.00.pth + ... + + For a list of all models, visit the Hugging Face Hub + https://huggingface.co/models or Kaggle Model Hub https://kaggle.com/models. """ ), formatter_class=argparse.RawTextHelpFormatter, @@ -59,7 +76,7 @@ def _add_arguments(self) -> None: self._parser.add_argument( "repo_id", type=str, - help="Name of the repository on Hugging Face Hub.", + help="Name of the repository on Hugging Face Hub or model handle on Kaggle Model Hub.", ) self._parser.add_argument( "--output-dir", @@ -93,10 +110,38 @@ def _add_arguments(self) -> None: required=False, default="*.safetensors", help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring " - "safetensors files to avoid downloading duplicate weights.", + "safetensors files to avoid downloading duplicate weights. Only supported for Hugging Face Hub models.", + ) + self._parser.add_argument( + "--source", + type=str, + required=False, + default="huggingface", + choices=["huggingface", "kaggle"], + help="If provided, downloads model weights from the provided repo_id on the designated source hub.", + ) + self._parser.add_argument( + "--kaggle-username", + type=str, + required=False, + help="Kaggle username for authentication. Needed for private models or gated models like Llama2.", + ) + self._parser.add_argument( + "--kaggle-api-key", + type=str, + required=False, + help="Kaggle API key. Needed for private models or gated models like Llama2. 
You can find your " + "API key at https://kaggle.com/settings.", ) def _download_cmd(self, args: argparse.Namespace) -> None: + # Note: we're relying on argparse to validate if the provided args.source is supported + if args.source == "huggingface": + return self._download_from_huggingface(args) + if args.source == "kaggle": + return self._download_from_kaggle(args) + + def _download_from_huggingface(self, args: argparse.Namespace) -> None: """Downloads a model from the Hugging Face Hub.""" # Download the tokenizer and PyTorch model files @@ -148,8 +193,6 @@ def _download_cmd(self, args: argparse.Namespace) -> None: f"Repository '{args.repo_id}' not found on the Hugging Face Hub." ) except Exception as e: - import traceback - tb = traceback.format_exc() msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}" self._parser.error(msg) @@ -159,3 +202,104 @@ def _download_cmd(self, args: argparse.Namespace) -> None: *list(Path(true_output_dir).iterdir()), sep="\n", ) + + def _download_from_kaggle(self, args: argparse.Namespace) -> None: + """Downloads a model from the Kaggle Model Hub.""" + + # Note: Kaggle doesn't actually use the "repository" terminology, but we still reuse args.repo_id here for simplicity + model_handle = args.repo_id + self._validate_kaggle_model_handle(model_handle) + self._set_kaggle_credentials(args) + + # kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub + if args.output_dir: + warn( + "--output-dir flag is not supported for Kaggle model downloads. " + "This argument will be ignored." + ) + if args.ignore_patterns: + warn( + "--ignore-patterns flag is not supported for Kaggle model downloads. " + "This argument will be ignored." 
+ ) + + try: + output_dir = model_download(model_handle) + print( + "Successfully downloaded model repo and wrote to the following locations:", + *list(Path(output_dir).iterdir()), + sep="\n", + ) + except KaggleApiHTTPError as e: + if e.response.status_code in { + HTTPStatus.UNAUTHORIZED, + HTTPStatus.FORBIDDEN, + }: + self._parser.error( + "It looks like you are trying to access a gated model. Please ensure you " + "have access to the model and have provided the proper Kaggle credentials " + "using the options `--kaggle-username` and `--kaggle-api-key`. You can also " + "set these to environment variables as detailed in " + "https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate." + ) + elif e.response.status_code == HTTPStatus.NOT_FOUND: + self._parser.error( + f"'{model_handle}' not found on the Kaggle Model Hub." + ) + tb = traceback.format_exc() + msg = f"Failed to download {model_handle} with error: '{e}' and traceback: {tb}" + self._parser.error(msg) + except Exception as e: + tb = traceback.format_exc() + msg = f"Failed to download {model_handle} with error: '{e}' and traceback: {tb}" + self._parser.error(msg) + + def _validate_kaggle_model_handle(self, handle: str) -> None: + try: + parsed_handle = parse_model_handle(handle) + if ( + parsed_handle.framework == "pytorch" + and parsed_handle.owner != "metaresearch" + ): + warn( + f"Requested PyTorch model {handle} was not published from Meta, and therefore " + "may not be compatible with torchtune." + ) + if parsed_handle.framework not in {"pytorch", "transformers"}: + warn( + f"Requested model {handle} is neither a PyTorch nor a Transformers model, and " + "therefore may not be compatible with torchtune." + ) + except Exception as e: + msg = f"Failed to validate {handle} with error {e}." 
+ self._parser.error(msg) + + def _set_kaggle_credentials(self, args: argparse.Namespace): + try: + if args.kaggle_username or args.kaggle_api_key: + print( + "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key " + "arguments by storing them as the environment variables KAGGLE_USERNAME and " + "KAGGLE_KEY, respectively. For more details, see " + "https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + ) + + # Fallback to known Kaggle environment variables in case user omits one + # of the CLI arguments. Note, there's no need to fallback when both + # --kaggle-username and --kaggle-api-key are omitted since kagglehub + # will check the environment variables itself. + kaggle_username = ( + args.kaggle_username + if args.kaggle_username + else os.environ.get("KAGGLE_USERNAME") + ) + kaggle_api_key = ( + args.kaggle_api_key + if args.kaggle_api_key + else os.environ.get("KAGGLE_KEY") + ) + set_kaggle_credentials(kaggle_username, kaggle_api_key) + except Exception as e: + msg = f"Failed to set Kaggle credentials with error: '{e}'" + # not all Kaggle downloads require credentials, so there's no need to terminate + warn(msg)