From 6e5ebbcff2915df0992e3c1ed06faa7c72a92bc9 Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Wed, 6 Nov 2024 19:44:43 +0000 Subject: [PATCH 1/8] feat(cli): add support for kaggle downloads --- pyproject.toml | 3 + torchtune/_cli/download.py | 128 +++++++++++++++++++++++++++++++++++-- 2 files changed, 124 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c2920ff4d3..396531bcb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ dependencies = [ "huggingface_hub", "safetensors", + # Kaggle Integrations + "kagglehub", + # Tokenization "sentencepiece", "tiktoken", diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index 82b4935c01..7fc3b615b4 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -7,12 +7,20 @@ import argparse import os import textwrap +import traceback +from http import HTTPStatus from pathlib import Path from typing import Literal, Union +from warnings import warn from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + +from kagglehub import model_download +from kagglehub.auth import set_kaggle_credentials +from kagglehub.exceptions import KaggleApiHTTPError +from kagglehub.handle import parse_model_handle from torchtune._cli.subcommand import Subcommand @@ -25,8 +33,8 @@ def __init__(self, subparsers: argparse._SubParsersAction): "download", prog="tune download", usage="tune download [OPTIONS]", - help="Download a model from the Hugging Face Hub.", - description="Download a model from the Hugging Face Hub.", + help="Download a model from the Hugging Face Hub or Kaggle Model Hub.", + description="Download a model from the Hugging Face Hub or Kaggle Model Hub.", epilog=textwrap.dedent( """\ examples: @@ -46,7 +54,15 @@ def __init__(self, subparsers: argparse._SubParsersAction): /tmp/model/model-00001-of-00002.bin ... 
- For a list of all models, visit the Hugging Face Hub https://huggingface.co/models. + # Download a model from the Kaggle Model Hub + $ tune download metaresearch/llama-3.2/pytorch/1b --source kaggle + Successfully downloaded model repo and wrote to the following locations: + /tmp/llama-3.2/pytorch/1b/tokenizer.model + /tmp/llama-3.2/pytorch/1b/params.json + /tmp/llama-3.2/pytorch/1b/consolidated.00.pth + ... + + For a list of all models, visit the Hugging Face Hub https://huggingface.co/models or Kaggle Model Hub https://kaggle.com/models. """ ), formatter_class=argparse.RawTextHelpFormatter, @@ -59,7 +75,7 @@ def _add_arguments(self) -> None: self._parser.add_argument( "repo_id", type=str, - help="Name of the repository on Hugging Face Hub.", + help="Name of the repository on Hugging Face Hub or model handle on Kaggle Model Hub.", ) self._parser.add_argument( "--output-dir", @@ -93,10 +109,37 @@ def _add_arguments(self) -> None: required=False, default="*.safetensors", help="If provided, files matching any of the patterns are not downloaded. Defaults to ignoring " - "safetensors files to avoid downloading duplicate weights.", + "safetensors files to avoid downloading duplicate weights. Only supported for Hugging Face Hub models.", + ) + self._parser.add_argument( + "--source", + type=str, + required=False, + default="huggingface", + choices=["huggingface", "kaggle"], + help="If provided, downloads model weights from the provided repo_id on the designated source hub.", + ) + self._parser.add_argument( + "--kaggle-username", + type=str, + required=False, + help="Kaggle username for authentication. Needed for private models or gated models like Llama2.", + ) + self._parser.add_argument( + "--kaggle-api-key", + type=str, + required=False, + help="Kaggle API key. Needed for private models or gated models like Llama2. 
You can find your API key at https://kaggle.com/settings.", ) def _download_cmd(self, args: argparse.Namespace) -> None: + # Note: we're relying on argparse to validate if the provided args.source is supported + if args.source == "huggingface": + return self._download_from_huggingface(args) + if args.source == "kaggle": + return self._download_from_kaggle(args) + + def _download_from_huggingface(self, args: argparse.Namespace) -> None: """Downloads a model from the Hugging Face Hub.""" # Download the tokenizer and PyTorch model files @@ -148,8 +191,6 @@ def _download_cmd(self, args: argparse.Namespace) -> None: f"Repository '{args.repo_id}' not found on the Hugging Face Hub." ) except Exception as e: - import traceback - tb = traceback.format_exc() msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}" self._parser.error(msg) @@ -159,3 +200,76 @@ def _download_cmd(self, args: argparse.Namespace) -> None: *list(Path(true_output_dir).iterdir()), sep="\n", ) + + def _download_from_kaggle(self, args: argparse.Namespace) -> None: + """Downloads a model from the Kaggle Model Hub.""" + + # Note: Kaggle doesn't actually use the "repository" terminology, but we still reuse args.repo_id here for simplicity + model_handle = args.repo_id + self._validate_kaggle_model_handle(model_handle) + + # kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub + if args.output_dir is not None: + warn("--output-dir flag is not supported for Kaggle model downloads.") + + if args.kaggle_username is not None and args.kaggle_api_key is not None: + set_kaggle_credentials(args.kaggle_username, args.kaggle_api_key) + elif args.kaggle_username is not None and args.kaggle_api_key is None: + self._parser.error( + "Missing --kaggle-api-key. Please provide both your Kaggle username " + "and API key. Find your API key at https://kaggle.com/settings." 
+ ) + elif args.kaggle_username is None and args.kaggle_api_key is not None: + self._parser.error( + "Missing --kaggle-username. Please provide both your Kaggle username " + "and API key." + ) + + try: + output_dir = model_download(model_handle) + print( + "Successfully downloaded model repo and wrote to the following locations:", + *list(Path(output_dir).iterdir()), + sep="\n", + ) + except KaggleApiHTTPError as e: + if e.response.status_code in { + HTTPStatus.UNAUTHORIZED, + HTTPStatus.FORBIDDEN, + }: + self._parser.error( + "It looks like you are trying to access a gated model. Please ensure you " + "have access to the model and have provided the proper Kaggle credentials " + "using the options `--kaggle-username` and `--kaggle-api-key`. You can also " + "set these to environment variables as detailed in " + "https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate." + ) + elif e.response.status_code == HTTPStatus.NOT_FOUND: + self._parser.error( + f"'{model_handle}' not found on the Kaggle Model Hub." + ) + tb = traceback.format_exc() + msg = f"Failed to download {model_handle} with error: '{e}' and traceback: {tb}" + self._parser.error(msg) + except Exception as e: + tb = traceback.format_exc() + msg = f"Failed to download {model_handle} with error: '{e}' and traceback: {tb}" + self._parser.error(msg) + + def _validate_kaggle_model_handle(self, handle: str) -> None: + try: + parsed_handle = parse_model_handle(handle) + if ( + parsed_handle.framework == "pytorch" + and parsed_handle.owner != "metaresearch" + ): + warn( + f"Requested PyTorch model {handle} was not published from Meta, and therefore may not be compatible with torchtune." + ) + if parsed_handle.framework not in {"pytorch", "transformers"}: + warn( + f"Requested model {handle} is neither a PyTorch nor a Transformers model, and therefore may not be compatible with torchtune." + ) + except Exception as e: + msg = f"Failed to validate {handle} with error {e}." 
+ self._parser.error(msg) From ca4787da3b56401de85716779d0b66e52dd34597 Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Wed, 6 Nov 2024 19:45:19 +0000 Subject: [PATCH 2/8] tests(cli): add test cases for kaggle downloads --- tests/torchtune/_cli/test_download.py | 175 ++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index 8a6d6ba0ab..8539aabfbb 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -106,3 +106,178 @@ def test_gated_repo_error_with_token(self, capsys, monkeypatch, snapshot_downloa "Please ensure you have access to the repository and have provided the proper Hugging Face API token" not in out_err.err ) + + # only valid --source parameters supported (expect prompt for supported values) + def test_source_parameter(self, capsys, monkeypatch): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source invalid".split() + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr() + assert "argument --source: invalid choice: 'invalid'" in output.err + + def test_download_from_kaggle(self, capsys, monkeypatch, mocker, tmpdir): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + + def test_download_from_kaggle_warn_when_output_dir_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + model = 
"metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --output-dir /requested/model/path".split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + with pytest.warns( + UserWarning, + match="--output-dir flag is not supported for Kaggle model downloads", + ): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + + # passes partial credentials with just --kaggle-api-key (expect prompt for all necessary credentials) + def test_download_from_kaggle_partial_credentials_provided( + self, capsys, monkeypatch + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = ( + f"tune download {model} --source kaggle --kaggle-api-key apikey".split() + ) + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert "Missing --kaggle-username." in out_err.err + assert "Please provide both your Kaggle username and API key." in out_err.err + + testargs = ( + f"tune download {model} --source kaggle --kaggle-username username".split() + ) + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert "Missing --kaggle-api-key." in out_err.err + assert "Please provide both your Kaggle username and API key." in out_err.err + assert "Find your API key at https://kaggle.com/settings." 
in out_err.err + + # KaggleApiHTTPError::Unauthorized without --kaggle-username and --kaggle-api-key (expect prompt for credentials) + def test_download_from_kaggle_unauthorized_credentials( + self, capsys, monkeypatch, mocker + ): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username username --kaggle-api-key key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "Unauthorized", + response=mocker.MagicMock(status_code=HTTPStatus.UNAUTHORIZED), + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert ( + "Please ensure you have access to the model and have provided the proper Kaggle credentials" + in out_err.err + ) + assert "You can also set these to environment variables" in out_err.err + + # KaggleApiHTTPError::NotFound + def test_download_from_kaggle_model_not_found(self, capsys, monkeypatch, mocker): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "mockorganizations/mockmodel/pytorch/mockvariation" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "NotFound", response=mocker.MagicMock(status_code=HTTPStatus.NOT_FOUND) + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert f"'{model}' not found on the Kaggle Model Hub." 
in out_err.err + + # KaggleApiHTTPError::InternalServerError + def test_download_from_kaggle_api_error(self, capsys, monkeypatch, mocker): + from http import HTTPStatus + + from kagglehub.exceptions import KaggleApiHTTPError + + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username kaggle_user --kaggle-api-key kaggle_api_key".split() + monkeypatch.setattr(sys, "argv", testargs) + + mock_model_download = mocker.patch("torchtune._cli.download.model_download") + mock_model_download.side_effect = KaggleApiHTTPError( + "InternalError", + response=mocker.MagicMock(status_code=HTTPStatus.INTERNAL_SERVER_ERROR), + ) + + with pytest.raises(SystemExit, match="2"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + out_err = capsys.readouterr() + assert "Failed to download" in out_err.err + + def test_download_from_kaggle_warn_on_nonmeta_pytorch_models( + self, monkeypatch, mocker, tmpdir + ): + model = "kaggle/kaggle-model-name/pytorch/1b" + testargs = f"tune download {model} --source kaggle".split() + monkeypatch.setattr(sys, "argv", testargs) + + # stub out model_download to guarantee success + mocker.patch( + "torchtune._cli.download.model_download", + return_value=tmpdir, + ) + + with pytest.warns(UserWarning, match="may not be compatible with torchtune"): + runpy.run_path(TUNE_PATH, run_name="__main__") + + def test_download_from_kaggle_warn_on_nonpytorch_nontransformers_model( + self, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/some-model/some-madeup-framework/1b" + testargs = f"tune download {model} --source kaggle".split() + monkeypatch.setattr(sys, "argv", testargs) + + # stub out model_download to guarantee success + mocker.patch( + "torchtune._cli.download.model_download", + return_value=tmpdir, + ) + + with pytest.warns(UserWarning, match="may not be compatible with torchtune"): + runpy.run_path(TUNE_PATH, run_name="__main__") From 5cf43f412d25321a0894c4cfc7e7df22c9d4ac30 Mon Sep 17 
00:00:00 2001 From: Brandon Horton Date: Wed, 6 Nov 2024 19:45:49 +0000 Subject: [PATCH 3/8] docs(cli): add info on kaggle downloads --- docs/source/tune_cli.rst | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/source/tune_cli.rst b/docs/source/tune_cli.rst index 2c432b9894..5981fd8b80 100644 --- a/docs/source/tune_cli.rst +++ b/docs/source/tune_cli.rst @@ -39,19 +39,25 @@ to download files using the CLI. Download a model ---------------- -The ``tune download `` command downloads any model from the Hugging Face Hub. +The ``tune download `` command downloads any model from the Hugging Face or Kaggle Model Hub. .. list-table:: :widths: 30 60 * - \--output-dir - - Directory in which to save the model. + - Directory in which to save the model. Note: this option is not yet supported when `--source` is set to `kaggle`. * - \--output-dir-use-symlinks - To be used with `output-dir`. If set to 'auto', the cache directory will be used and the file will be either duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if already exists) or downloaded from the Hub and not cached. * - \--hf-token - Hugging Face API token. Needed for gated models like Llama. * - \--ignore-patterns - If provided, files matching any of the patterns are not downloaded. Defaults to ignoring safetensors files to avoid downloading duplicate weights. + * - \--source {huggingface,kaggle} + - If provided, downloads model weights from the provided repo_id on the designated source hub. + * - \--kaggle-username + - Kaggle username for authentication. Needed for private models or gated models like Llama2. + * - \--kaggle-api-key + - Kaggle API key. Needed for private models or gated models like Llama2. You can find your API key at https://kaggle.com/settings. .. 
code-block:: bash @@ -62,6 +68,13 @@ The ``tune download `` command downloads any model from the Hugging Face H ./model/model-00001-of-00002.bin ... +.. code-block:: bash + + $ tune download metaresearch/llama-3.2/pytorch/1b --source kaggle + Successfully downloaded model repo and wrote to the following locations: + /tmp/llama-3.2/pytorch/1b/tokenizer.model + /tmp/llama-3.2/pytorch/1b/params.json + /tmp/llama-3.2/pytorch/1b/consolidated.00.pth **Download a gated model** From deda29aaaec7ea14a65beb1480c1ef1210df1b12 Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Thu, 14 Nov 2024 18:18:55 +0000 Subject: [PATCH 4/8] fix: flake8 linting --- torchtune/_cli/download.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index 7fc3b615b4..fdda329c6c 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -62,7 +62,8 @@ def __init__(self, subparsers: argparse._SubParsersAction): /tmp/llama-3.2/pytorch/1b/consolidated.00.pth ... - For a list of all models, visit the Hugging Face Hub https://huggingface.co/models or Kaggle Model Hub https://kaggle.com/models. + For a list of all models, visit the Hugging Face Hub + https://huggingface.co/models or Kaggle Model Hub https://kaggle.com/models. """ ), formatter_class=argparse.RawTextHelpFormatter, @@ -129,7 +130,8 @@ def _add_arguments(self) -> None: "--kaggle-api-key", type=str, required=False, - help="Kaggle API key. Needed for private models or gated models like Llama2. You can find your API key at https://kaggle.com/settings.", + help="Kaggle API key. Needed for private models or gated models like Llama2. 
You can find your " + "API key at https://kaggle.com/settings.", ) def _download_cmd(self, args: argparse.Namespace) -> None: @@ -264,11 +266,13 @@ def _validate_kaggle_model_handle(self, handle: str) -> None: and parsed_handle.owner != "metaresearch" ): warn( - f"Requested PyTorch model {handle} was not published from Meta, and therefore may not be compatible with torchtune." + f"Requested PyTorch model {handle} was not published from Meta, and therefore " + "may not be compatible with torchtune." ) if parsed_handle.framework not in {"pytorch", "transformers"}: warn( - f"Requested model {handle} is neither a PyTorch nor a Transformers model, and therefore may not be compatible with torchtune." + f"Requested model {handle} is neither a PyTorch nor a Transformers model, and " + "therefore may not be compatible with torchtune." ) except Exception as e: msg = f"Failed to validate {handle} with error {e}." From e748237be2142d851ab5abf80a1958c1fb6d73d4 Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Fri, 15 Nov 2024 22:41:04 +0000 Subject: [PATCH 5/8] fix: add explicit 'argument ignored' warning --- torchtune/_cli/download.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index fdda329c6c..e540d4e67b 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -212,9 +212,13 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None: # kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub if args.output_dir is not None: - warn("--output-dir flag is not supported for Kaggle model downloads.") + warn( + "--output-dir flag is not supported for Kaggle model downloads. " + "This argument will be ignored." 
+ ) if args.kaggle_username is not None and args.kaggle_api_key is not None: + set_kaggle_credentials(args.kaggle_username, args.kaggle_api_key) elif args.kaggle_username is not None and args.kaggle_api_key is None: self._parser.error( From 36993f9a1f2d37cb91f7096a922789268ba265ac Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Mon, 18 Nov 2024 18:56:24 +0000 Subject: [PATCH 6/8] refactor: fallback --kaggle-username and --kaggle-api-key to env vars --- tests/torchtune/_cli/test_download.py | 86 +++++++++++++++++++++------ torchtune/_cli/download.py | 45 +++++++++----- 2 files changed, 99 insertions(+), 32 deletions(-) diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index 8539aabfbb..3063c78243 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -4,8 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import runpy import sys +from unittest import mock import pytest from tests.common import TUNE_PATH @@ -149,36 +151,84 @@ def test_download_from_kaggle_warn_when_output_dir_provided( output = capsys.readouterr().out assert "Successfully downloaded model repo" in output - # passes partial credentials with just --kaggle-api-key (expect prompt for all necessary credentials) - def test_download_from_kaggle_partial_credentials_provided( - self, capsys, monkeypatch + # tests when --kaggle-username and --kaggle-api-key are provided as CLI args + def test_download_from_kaggle_when_credentials_provided( + self, capsys, monkeypatch, mocker ): + expected_username = "username" + expected_api_key = "api_key" model = "metaresearch/llama-3.2/pytorch/1b" testargs = ( - f"tune download {model} --source kaggle --kaggle-api-key apikey".split() - ) + f"tune download {model} " + f"--source kaggle --kaggle-username {expected_username} " + f"--kaggle-api-key {expected_api_key}" + ).split() 
monkeypatch.setattr(sys, "argv", testargs) + set_kaggle_credentials_spy = mocker.patch( + "torchtune._cli.download.set_kaggle_credentials" + ) - with pytest.raises(SystemExit, match="2"): - runpy.run_path(TUNE_PATH, run_name="__main__") + runpy.run_path(TUNE_PATH, run_name="__main__") - out_err = capsys.readouterr() - assert "Missing --kaggle-username." in out_err.err - assert "Please provide both your Kaggle username and API key." in out_err.err + set_kaggle_credentials_spy.assert_called_once_with( + expected_username, expected_api_key + ) + output = capsys.readouterr().out + assert ( + "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key" + in output + ) + assert ( + "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + in output + ) - testargs = ( - f"tune download {model} --source kaggle --kaggle-username username".split() + # passes partial credentials with just --kaggle-username (expect fallback to environment variables) + @mock.patch.dict(os.environ, {"KAGGLE_KEY": "env_api_key"}) + def test_download_from_kaggle_when_partial_credentials_provided( + self, capsys, monkeypatch, mocker + ): + expected_username = "username" + expected_api_key = "env_api_key" + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username {expected_username}".split() + monkeypatch.setattr(sys, "argv", testargs) + set_kaggle_credentials_spy = mocker.patch( + "torchtune._cli.download.set_kaggle_credentials" + ) + + runpy.run_path(TUNE_PATH, run_name="__main__") + + set_kaggle_credentials_spy.assert_called_once_with( + expected_username, expected_api_key + ) + output = capsys.readouterr().out + assert ( + "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key" + in output + ) + assert ( + "For more details, see https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + in output ) + + def 
test_download_from_kaggle_when_set_kaggle_credentials_throws( + self, monkeypatch, mocker + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f"tune download {model} --source kaggle --kaggle-username u --kaggle-api-key k".split() monkeypatch.setattr(sys, "argv", testargs) + mocker.patch( + "torchtune._cli.download.set_kaggle_credentials", + side_effect=Exception("some error"), + ) - with pytest.raises(SystemExit, match="2"): + with pytest.warns( + UserWarning, + match="Failed to set Kaggle credentials with error", + ): runpy.run_path(TUNE_PATH, run_name="__main__") - out_err = capsys.readouterr() - assert "Missing --kaggle-api-key." in out_err.err - assert "Please provide both your Kaggle username and API key." in out_err.err - assert "Find your API key at https://kaggle.com/settings." in out_err.err - # KaggleApiHTTPError::Unauthorized without --kaggle-username and --kaggle-api-key (expect prompt for credentials) def test_download_from_kaggle_unauthorized_credentials( self, capsys, monkeypatch, mocker diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index e540d4e67b..52ad757e1e 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -209,6 +209,7 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None: # Note: Kaggle doesn't actually use the "repository" terminology, but we still reuse args.repo_id here for simplicity model_handle = args.repo_id self._validate_kaggle_model_handle(model_handle) + self._set_kaggle_credentials(args) # kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub if args.output_dir is not None: @@ -217,20 +218,6 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None: "This argument will be ignored." 
) - if args.kaggle_username is not None and args.kaggle_api_key is not None: - - set_kaggle_credentials(args.kaggle_username, args.kaggle_api_key) - elif args.kaggle_username is not None and args.kaggle_api_key is None: - self._parser.error( - "Missing --kaggle-api-key. Please provide both your Kaggle username " - "and API key. Find your API key at https://kaggle.com/settings." - ) - elif args.kaggle_username is None and args.kaggle_api_key is not None: - self._parser.error( - "Missing --kaggle-username. Please provide both your Kaggle username " - "and API key." - ) - try: output_dir = model_download(model_handle) print( @@ -281,3 +268,33 @@ def _validate_kaggle_model_handle(self, handle: str) -> None: except Exception as e: msg = f"Failed to validate {handle} with error {e}." self._parser.error(msg) + + def _set_kaggle_credentials(self, args: argparse.Namespace): + try: + if args.kaggle_username or args.kaggle_api_key: + print( + "TIP: you can avoid passing in the --kaggle-username and --kaggle-api-key " + "arguments by storing them as the environment variables KAGGLE_USERNAME and " + "KAGGLE_KEY, respectively. For more details, see " + "https://github.com/Kaggle/kagglehub/blob/main/README.md#authenticate" + ) + + # Fallback to known Kaggle environment variables in case user omits one + # of the CLI arguments. Note, there's no need to fallback when both + # --kaggle-username and --kaggle-api-key are omitted since kagglehub + # will check the environment variables itself. 
+ kaggle_username = ( + args.kaggle_username + if args.kaggle_username + else os.environ.get("KAGGLE_USERNAME") + ) + kaggle_api_key = ( + args.kaggle_api_key + if args.kaggle_api_key + else os.environ.get("KAGGLE_KEY") + ) + set_kaggle_credentials(kaggle_username, kaggle_api_key) + except Exception as e: + msg = f"Failed to set Kaggle credentials with error: '{e}'" + # not all Kaggle downloads require credentials, so there's no need to terminate + warn(msg) From 007d57c7250e8520b32bcf43825bddc3e6a237ef Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Tue, 19 Nov 2024 00:23:19 +0000 Subject: [PATCH 7/8] refactor: add warning that --ignore-patterns is ignored --- tests/torchtune/_cli/test_download.py | 18 ++++++++++++++++++ torchtune/_cli/download.py | 7 ++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index 3063c78243..f0643498e6 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -151,6 +151,24 @@ def test_download_from_kaggle_warn_when_output_dir_provided( output = capsys.readouterr().out assert "Successfully downloaded model repo" in output + def test_download_from_kaggle_warn_when_ignore_patterns_provided( + self, capsys, monkeypatch, mocker, tmpdir + ): + model = "metaresearch/llama-3.2/pytorch/1b" + testargs = f'tune download {model} --source kaggle --ignore-patterns "*.glob-pattern"'.split() + monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) + + with pytest.warns( + UserWarning, + match="--ignore-patterns flag is not supported for Kaggle model downloads", + ): + runpy.run_path(TUNE_PATH, run_name="__main__") + + output = capsys.readouterr().out + assert "Successfully downloaded model repo" in output + # tests when --kaggle-username and --kaggle-api-key are provided as CLI args 
def test_download_from_kaggle_when_credentials_provided( self, capsys, monkeypatch, mocker diff --git a/torchtune/_cli/download.py b/torchtune/_cli/download.py index 52ad757e1e..a60cf370ab 100644 --- a/torchtune/_cli/download.py +++ b/torchtune/_cli/download.py @@ -212,11 +212,16 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None: self._set_kaggle_credentials(args) # kagglehub doesn't currently support `local_dir` and `ignore_patterns` like huggingface_hub - if args.output_dir is not None: + if args.output_dir: warn( "--output-dir flag is not supported for Kaggle model downloads. " "This argument will be ignored." ) + if args.ignore_patterns: + warn( + "--ignore-patterns flag is not supported for Kaggle model downloads. " + "This argument will be ignored." + ) try: output_dir = model_download(model_handle) From 00493ef2b493a206b46fa67632962f9e7dce8b99 Mon Sep 17 00:00:00 2001 From: Brandon Horton Date: Tue, 19 Nov 2024 18:39:06 +0000 Subject: [PATCH 8/8] fix(tests): add missing kagglehub.model_download mocks --- tests/torchtune/_cli/test_download.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/torchtune/_cli/test_download.py b/tests/torchtune/_cli/test_download.py index f0643498e6..3338adc870 100644 --- a/tests/torchtune/_cli/test_download.py +++ b/tests/torchtune/_cli/test_download.py @@ -171,17 +171,20 @@ def test_download_from_kaggle_warn_when_ignore_patterns_provided( # tests when --kaggle-username and --kaggle-api-key are provided as CLI args def test_download_from_kaggle_when_credentials_provided( - self, capsys, monkeypatch, mocker + self, capsys, monkeypatch, mocker, tmpdir ): expected_username = "username" expected_api_key = "api_key" model = "metaresearch/llama-3.2/pytorch/1b" testargs = ( f"tune download {model} " - f"--source kaggle --kaggle-username {expected_username} " + f"--source kaggle " + f"--kaggle-username {expected_username} " f"--kaggle-api-key {expected_api_key}" ).split() 
monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) set_kaggle_credentials_spy = mocker.patch( "torchtune._cli.download.set_kaggle_credentials" ) @@ -204,13 +207,15 @@ def test_download_from_kaggle_when_credentials_provided( # passes partial credentials with just --kaggle-username (expect fallback to environment variables) @mock.patch.dict(os.environ, {"KAGGLE_KEY": "env_api_key"}) def test_download_from_kaggle_when_partial_credentials_provided( - self, capsys, monkeypatch, mocker + self, capsys, monkeypatch, mocker, tmpdir ): expected_username = "username" expected_api_key = "env_api_key" model = "metaresearch/llama-3.2/pytorch/1b" testargs = f"tune download {model} --source kaggle --kaggle-username {expected_username}".split() monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) set_kaggle_credentials_spy = mocker.patch( "torchtune._cli.download.set_kaggle_credentials" ) @@ -231,11 +236,13 @@ def test_download_from_kaggle_when_partial_credentials_provided( ) def test_download_from_kaggle_when_set_kaggle_credentials_throws( - self, monkeypatch, mocker + self, monkeypatch, mocker, tmpdir ): model = "metaresearch/llama-3.2/pytorch/1b" testargs = f"tune download {model} --source kaggle --kaggle-username u --kaggle-api-key k".split() monkeypatch.setattr(sys, "argv", testargs) + # mock out kagglehub.model_download to get around key storage + mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir) mocker.patch( "torchtune._cli.download.set_kaggle_credentials", side_effect=Exception("some error"),