Commit ffa9ff7
feat(code): add function to download notebook outputs (#184)
BUG=b/371574828 CHILD=#185 BLOCKED_BY=go/kaggle-pr/32581,go/kaggle-pr/32632 CC=@rosbo,@dster2,@jplotts

Extends the same functionality from [`models.py`](https://github.com/Kaggle/kagglehub/blob/8b1fae8632f9d381cebb14ec50c56d7ff5fbeb1b/src/kagglehub/models.py), [`datasets.py`](https://github.com/Kaggle/kagglehub/blob/8b1fae8632f9d381cebb14ec50c56d7ff5fbeb1b/src/kagglehub/datasets.py), and [`competition.py`](https://github.com/Kaggle/kagglehub/blob/8b1fae8632f9d381cebb14ec50c56d7ff5fbeb1b/src/kagglehub/competition.py) to the notebooks at https://kaggle.com/code.

### Changes

[handle.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-4d86981fdd4a6e41ce621dd3dcafa482e4ca96d5cca7da9b6cbaff147cc3bffb)
- added a new `*Handle` data type, using `Code` to align with the route on the main site https://kaggle.com/code

[cache.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-3468a9a96dc65a3a8770b887cdc452e4975b9c934f0376c56a7e39dff7fd778a)
- added new functions that dictate the cached path for notebook outputs based on the properties in `CodeHandle`; this mostly mirrors the structure of the `model`, `dataset`, and `competition` paths
- the cache structure is split into the _output\_path_, the _archive\_path_, a _completion\_file\_marker\_path_ for individual files in the download payload, and a _completion\_file\_marker\_path_ for the entire download payload
- the structure is as follows:
```
<cache_root>/
└── notebooks/
    └── username/
        └── notebook_slug/
            ├── output.complete        <-- tracker for the entire output
            ├── .complete/             <-- per-file trackers within the output
            │   └── output/
            │       ├── file1.txt.complete
            │       └── file2.txt.complete
            ├── output.archive         <-- the compressed output (.tar.gz or .zip)
            └── output/                <-- the uncompressed output
                ├── file1.txt
                └── file2.txt
```

[http_resolver.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-bac8b2fc0706f6a2f83562279c3095d907c84f377a4f49829bb06935d3e6773a)
- implemented the `NotebookOutputHttpResolver`
- note: we don't currently have an API endpoint to download notebook output in a kagglehub-compatible compression format (left a TODO with our internal tracker for this)
- it leverages our existing `KaggleApiV1Client` plus the new cache location mentioned above

[registry.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-3a93c8bc0f26d8267eb445e7e90f03b2c85c54e8c59e37f997c82104cb5d1541) + [\_\_init\_\_.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-214a9613a5b623f57fb158c6c784e29e051f95fe72694b446cae592f76b825d8)
- bootstraps the `NotebookOutputHttpResolver` so that it can be called by `kagglehub.notebook_output_download` in `code.py`

[code.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-e0c454ef6e8643b0efb25013db58209adc53499f2a5a485a8b5bd63cd280899a)
- the entry point to the notebook-output downloading functionality
- the file is named `code.py` to align with our navigation paths at https://kaggle.com (similar to `models`, `datasets`, and `competitions`). Open to changing if needed.
- the function is named `notebook_output_download` to be more specific about what's being downloaded

[test_notebook_output_download.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-04eda2a8f4c98b098d47a0dc61b70c4bfce3a2fba4205abcb9830c69a05001e3)
- integration tests for the new `kagglehub.notebook_output_download` function
- TODO(#185): adding tests in a followup since that requires propping up a [stubbed API server](https://github.com/Kaggle/kagglehub/tree/8b1fae8632f9d381cebb14ec50c56d7ff5fbeb1b/tests/server_stubs). Trying to keep this diff from getting any bigger.

[gcs_upload.py](https://github.com/Kaggle/kagglehub/pull/184/files#diff-2dcd4fa7b008e12dc0d76354035bc2cdaf48dd2422e997476443a4fbe550d8ea)
- a miscellaneous lint error that slipped through; fixing here as a drive-by change as per [this comment](#184 (comment))
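The cache layout described above reduces to straightforward path construction. A minimal sketch (the cache root and the `owner`/`slug` values are placeholders, and `notebook_output_cache_paths` is a hypothetical helper, not the library's API):

```python
import os


def notebook_output_cache_paths(cache_root: str, owner: str, slug: str) -> dict[str, str]:
    # Mirrors the <cache_root>/notebooks/<username>/<notebook_slug>/ tree above.
    base = os.path.join(cache_root, "notebooks", owner, slug)
    return {
        "output": os.path.join(base, "output"),  # the uncompressed output
        "archive": os.path.join(base, "output.archive"),  # the compressed output
        "payload_marker": os.path.join(base, "output.complete"),  # tracker for the entire output
        "file_marker_dir": os.path.join(base, ".complete", "output"),  # per-file trackers
    }


paths = notebook_output_cache_paths("/tmp/kagglehub", "username", "notebook_slug")
# paths["archive"] → "/tmp/kagglehub/notebooks/username/notebook_slug/output.archive"
```

Keeping the per-file markers under a separate `.complete/` subtree means the marker files never collide with real output filenames.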
1 parent 8b1fae8 commit ffa9ff7

File tree: 9 files changed (+222, −6 lines)
test_notebook_output_download.py

Lines changed: 55 additions & 0 deletions

```python
import unittest

from requests import HTTPError

from kagglehub import notebook_output_download

from .utils import assert_files, create_test_cache, unauthenticated


class TestNotebookOutputDownload(unittest.TestCase):
    def test_download_notebook_output_succeeds(self) -> None:
        with create_test_cache():
            actual_path = notebook_output_download("alexisbcook/titanic-tutorial")

            expected_files = ["submission.csv"]
            assert_files(self, actual_path, expected_files)

    def test_download_public_notebook_output_as_unauthenticated_succeeds(self) -> None:
        with create_test_cache():
            with unauthenticated():
                actual_path = notebook_output_download("alexisbcook/titanic-tutorial")

                expected_files = ["submission.csv"]
                assert_files(self, actual_path, expected_files)

    def test_download_private_notebook_output_succeeds(self) -> None:
        with create_test_cache():
            actual_path = notebook_output_download("integrationtester/private-titanic-tutorial")

            expected_files = ["submission-01.csv", "submission-02.csv"]

            assert_files(self, actual_path, expected_files)

    def test_download_private_notebook_output_single_file_succeeds(self) -> None:
        with create_test_cache():
            actual_path = notebook_output_download(
                "integrationtester/private-titanic-tutorial", path="submission-02.csv"
            )

            expected_files = ["submission-02.csv"]

            assert_files(self, actual_path, expected_files)

    def test_download_large_notebook_output_warns(self) -> None:
        handle = "integrationtester/titanic-tutorial-many-output-files"
        with create_test_cache():
            # If the notebook output has > 25 files, we warn the user that it's not supported yet
            # TODO(b/379761520): add support for .tar.gz archived downloads
            notebook_output_download(handle)
            msg = f"Too many files in {handle} (capped at 25). Unable to download notebook output."
            self.assertLogs(msg, "WARNING")

    def test_download_private_notebook_output_with_incorrect_file_path_fails(self) -> None:
        with create_test_cache(), self.assertRaises(HTTPError):
            notebook_output_download("integrationtester/titanic-tutorial", path="submission-03.csv")
```

src/kagglehub/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@
 from kagglehub.competition import competition_download
 from kagglehub.datasets import dataset_download, dataset_upload
 from kagglehub.models import model_download, model_upload
+from kagglehub.notebooks import notebook_output_download

 registry.model_resolver.add_implementation(http_resolver.ModelHttpResolver())
 registry.model_resolver.add_implementation(kaggle_cache_resolver.ModelKaggleCacheResolver())
@@ -17,3 +18,6 @@

 registry.competition_resolver.add_implementation(http_resolver.CompetitionHttpResolver())
 registry.competition_resolver.add_implementation(kaggle_cache_resolver.CompetitionKaggleCacheResolver())
+
+# TODO(b/380340624): implement a kaggle_cache_resolver for notebook outputs
+registry.notebook_output_resolver.add_implementation(http_resolver.NotebookOutputHttpResolver())
```

src/kagglehub/cache.py

Lines changed: 31 additions & 1 deletion

```diff
@@ -4,9 +4,10 @@
 from typing import Optional

 from kagglehub.config import get_cache_folder
-from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, ResourceHandle
+from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle

 DATASETS_CACHE_SUBFOLDER = "datasets"
+NOTEBOOKS_CACHE_SUBFOLDER = "notebooks"  # for resources under kaggle.com/code
 COMPETITIONS_CACHE_SUBFOLDER = "competitions"
 MODELS_CACHE_SUBFOLDER = "models"
 FILE_COMPLETION_MARKER_FOLDER = ".complete"
@@ -35,6 +36,8 @@ def get_cached_path(handle: ResourceHandle, path: Optional[str] = None) -> str:
         return _get_dataset_path(handle, path)
     elif isinstance(handle, CompetitionHandle):
         return _get_competition_path(handle, path)
+    elif isinstance(handle, NotebookHandle):
+        return _get_notebook_output_path(handle, path)
     else:
         msg = "Invalid handle"
         raise ValueError(msg)
@@ -47,6 +50,8 @@ def get_cached_archive_path(handle: ResourceHandle) -> str:
         return _get_dataset_archive_path(handle)
     elif isinstance(handle, CompetitionHandle):
         return _get_competition_archive_path(handle)
+    elif isinstance(handle, NotebookHandle):
+        return _get_notebook_output_archive_path(handle)
     else:
         msg = "Invalid handle"
         raise ValueError(msg)
@@ -105,6 +110,8 @@ def _get_completion_marker_filepath(handle: ResourceHandle, path: Optional[str]
         return _get_datasets_completion_marker_filepath(handle, path)
     elif isinstance(handle, CompetitionHandle):
         return _get_competitions_completion_marker_filepath(handle, path)
+    elif isinstance(handle, NotebookHandle):
+        return _get_notebook_output_completion_marker_filepath(handle, path)
     else:
         msg = "Invalid handle"
         raise ValueError(msg)
@@ -118,6 +125,11 @@ def _get_dataset_path(handle: DatasetHandle, path: Optional[str] = None) -> str:
     return os.path.join(base_path, path) if path else base_path


+def _get_notebook_output_path(handle: NotebookHandle, path: Optional[str] = None) -> str:
+    base_path = os.path.join(get_cache_folder(), NOTEBOOKS_CACHE_SUBFOLDER, handle.owner, handle.notebook, "output")
+    return os.path.join(base_path, path) if path else base_path
+
+
 def _get_competition_path(handle: CompetitionHandle, path: Optional[str] = None) -> str:
     base_path = os.path.join(get_cache_folder(), COMPETITIONS_CACHE_SUBFOLDER, handle.competition)
     return os.path.join(base_path, path) if path else base_path
@@ -167,6 +179,10 @@ def _get_competition_archive_path(handle: CompetitionHandle) -> str:
     )


+def _get_notebook_output_archive_path(handle: NotebookHandle) -> str:
+    return os.path.join(get_cache_folder(), NOTEBOOKS_CACHE_SUBFOLDER, handle.owner, handle.notebook, "output.archive")
+
+
 def _get_models_completion_marker_filepath(handle: ModelHandle, path: Optional[str] = None) -> str:
     if path:
         return os.path.join(
@@ -213,6 +229,20 @@ def _get_datasets_completion_marker_filepath(handle: DatasetHandle, path: Option
     )


+def _get_notebook_output_completion_marker_filepath(handle: NotebookHandle, path: Optional[str] = None) -> str:
+    if path:
+        return os.path.join(
+            get_cache_folder(),
+            NOTEBOOKS_CACHE_SUBFOLDER,
+            handle.owner,
+            handle.notebook,
+            FILE_COMPLETION_MARKER_FOLDER,
+            "output",
+            f"{path}.complete",
+        )
+    return os.path.join(get_cache_folder(), NOTEBOOKS_CACHE_SUBFOLDER, handle.owner, handle.notebook, "output.complete")
+
+
 def _get_competitions_completion_marker_filepath(handle: CompetitionHandle, path: Optional[str] = None) -> str:
     if path:
         return os.path.join(
```

src/kagglehub/gcs_upload.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -10,7 +10,7 @@
 from typing import Optional, Union

 import requests
-from requests.exceptions import ConnectionError, Timeout
+from requests.exceptions import Timeout
 from tqdm import tqdm
 from tqdm.utils import CallbackIOWrapper

@@ -66,7 +66,7 @@ def get_size(size: float, precision: int = 0) -> str:
     while size >= 1024 and suffix_index < 4:  # noqa: PLR2004
         suffix_index += 1
         size /= 1024.0
-    return "%.*f%s" % (precision, size, suffixes[suffix_index])
+    return f"{size:.{precision}f}{suffixes[suffix_index]}"


 def filtered_walk(*, base_dir: str, ignore_patterns: Sequence[str]) -> Iterable[tuple[str, list[str], list[str]]]:
@@ -109,7 +109,7 @@ def _check_uploaded_size(session_uri: str, file_size: int, backoff_factor: int =
             return 0  # If no Range header, assume no bytes were uploaded
         else:
            return file_size
-    except (ConnectionError, Timeout):
+    except (requests.ConnectionError, Timeout):
         logger.info(f"Network issue while checking uploaded size, retrying in {backoff_factor} seconds...")
         time.sleep(backoff_factor)
         backoff_factor = min(backoff_factor * 2, 60)
```
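The `"%.*f"`-to-f-string change in `get_size` is behavior-preserving. A standalone sketch, with a hypothetical `suffixes` list since its definition sits outside this diff:

```python
def get_size(size: float, precision: int = 0) -> str:
    # Hypothetical suffix list; the real `suffixes` definition is not shown in this diff.
    suffixes = ["B", "KB", "MB", "GB", "TB"]
    suffix_index = 0
    while size >= 1024 and suffix_index < 4:
        suffix_index += 1
        size /= 1024.0
    # f"{size:.{precision}f}" formats with a runtime precision, exactly like
    # the old "%.*f%s" % (precision, size, suffixes[suffix_index]).
    return f"{size:.{precision}f}{suffixes[suffix_index]}"


print(get_size(2048))               # → 2KB
print(get_size(1536, precision=1))  # → 1.5KB
```

The f-string form also avoids shadowing concerns: the old import of `requests.exceptions.ConnectionError` shadowed Python's builtin `ConnectionError`, which is what the `requests.ConnectionError` qualification in the same diff cleans up.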

src/kagglehub/handle.py

Lines changed: 25 additions & 0 deletions

```diff
@@ -12,6 +12,8 @@
 NUM_VERSIONED_MODEL_PARTS = 5  # e.g.: <owner>/<model>/<framework>/<variation>/<version>
 NUM_UNVERSIONED_MODEL_PARTS = 4  # e.g.: <owner>/<model>/<framework>/<variation>

+NUM_UNVERSIONED_NOTEBOOK_PARTS = 2  # e.g.: <owner>/<notebook>
+

 @dataclass
 class ResourceHandle:
@@ -83,6 +85,21 @@ def to_url(self) -> str:
         return base_url


+@dataclass
+class NotebookHandle(ResourceHandle):
+    owner: str
+    notebook: str
+
+    def __str__(self) -> str:
+        handle_str = f"{self.owner}/{self.notebook}"
+        return handle_str
+
+    def to_url(self) -> str:
+        endpoint = get_kaggle_api_endpoint()
+        base_url = f"{endpoint}/code/{self.owner}/{self.notebook}"
+        return base_url
+
+
 def parse_dataset_handle(handle: str) -> DatasetHandle:
     parts = handle.split("/")

@@ -152,3 +169,11 @@ def parse_competition_handle(handle: str) -> CompetitionHandle:
         raise ValueError(msg)

     return CompetitionHandle(competition=handle)
+
+
+def parse_notebook_handle(handle: str) -> NotebookHandle:
+    parts = handle.split("/")
+    if len(parts) != NUM_UNVERSIONED_NOTEBOOK_PARTS:
+        msg = f"Invalid notebook handle: {handle}"
+        raise ValueError(msg)
+    return NotebookHandle(owner=parts[0], notebook=parts[1])
```
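The new handle type round-trips between the `owner/notebook` string form and the `/code/` URL form. A self-contained sketch of that behavior, with the endpoint hardcoded here since the real class resolves it via `get_kaggle_api_endpoint()`:

```python
from dataclasses import dataclass


@dataclass
class NotebookHandle:
    owner: str
    notebook: str

    def __str__(self) -> str:
        return f"{self.owner}/{self.notebook}"

    def to_url(self) -> str:
        # Hardcoded endpoint for illustration only.
        return f"https://www.kaggle.com/code/{self.owner}/{self.notebook}"


def parse_notebook_handle(handle: str) -> NotebookHandle:
    parts = handle.split("/")
    if len(parts) != 2:  # NUM_UNVERSIONED_NOTEBOOK_PARTS: <owner>/<notebook>
        msg = f"Invalid notebook handle: {handle}"
        raise ValueError(msg)
    return NotebookHandle(owner=parts[0], notebook=parts[1])


h = parse_notebook_handle("alexisbcook/titanic-tutorial")
print(str(h))       # → alexisbcook/titanic-tutorial
print(h.to_url())   # → https://www.kaggle.com/code/alexisbcook/titanic-tutorial
```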

src/kagglehub/http_resolver.py

Lines changed: 64 additions & 1 deletion

```diff
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import zipfile
+from pathlib import Path
 from typing import Optional

 import requests
@@ -16,7 +17,7 @@
 )
 from kagglehub.clients import KaggleApiV1Client
 from kagglehub.exceptions import UnauthenticatedError
-from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, ResourceHandle
+from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle
 from kagglehub.resolver import Resolver

 DATASET_CURRENT_VERSION_FIELD = "currentVersionNumber"
@@ -199,6 +200,68 @@ def _inner_download_file(file: str) -> None:
     return out_path


+class NotebookOutputHttpResolver(Resolver[NotebookHandle]):
+    def is_supported(self, *_, **__) -> bool:  # noqa: ANN002, ANN003
+        # Downloading files over HTTP is supported in all environments for all handles / paths.
+        return True
+
+    def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
+        api_client = KaggleApiV1Client()
+
+        cached_response = load_from_cache(h, path)
+        if cached_response and not force_download:
+            return cached_response  # Already cached
+        elif cached_response and force_download:
+            delete_from_cache(h, path)
+
+        download_url_root = f"kernels/output/download/{h.owner}/{h.notebook}"
+        output_root = Path(get_cached_path(h, path))
+
+        # List the files and decide how to download them:
+        #   - <= 25 files: download the files in parallel
+        #   - >  25 files: download the archive and uncompress
+        (files, has_more) = self._list_files(api_client, h) if not path else ([path], False)
+        if has_more:
+            # TODO(b/379761520): add support for .tar.gz archived downloads
+            logger.warning(
+                f"Too many files in {h} (capped at {MAX_NUM_FILES_DIRECT_DOWNLOAD}). "
+                "Unable to download notebook output."
+            )
+            return ""
+
+        # Download files individually in parallel
+        def _inner_download_file(filepath: str) -> None:
+            download_url_path = f"{download_url_root}/{filepath}"
+            full_output_filepath = output_root / filepath
+
+            os.makedirs(os.path.dirname(full_output_filepath), exist_ok=True)
+            api_client.download_file(download_url_path, str(full_output_filepath), h)
+
+        thread_map(
+            _inner_download_file,
+            files,
+            desc=f"Downloading {len(files)} files",
+            max_workers=8,  # Never use more than 8 threads in parallel to download files.
+        )
+
+        mark_as_complete(h, path)
+
+        # TODO(b/377510971): when notebook is a Kaggle utility script, update sys.path
+        return str(output_root)
+
+    def _list_files(self, api_client: KaggleApiV1Client, h: NotebookHandle) -> tuple[list[str], bool]:
+        query = f"kernels/output/list/{h.owner}/{h.notebook}?page_size={MAX_NUM_FILES_DIRECT_DOWNLOAD}"
+        json_response = api_client.get(query, h)
+        if "files" not in json_response:
+            msg = "Invalid ApiListKernelSessionOutput API response. Expected to include a 'files' field"
+            raise ValueError(msg)
+
+        files = [f["fileName"].lstrip("/") for f in json_response["files"]]
+        has_more = "nextPageToken" in json_response and json_response["nextPageToken"] != ""
+
+        return (files, has_more)
+
+
 def _extract_archive(archive_path: str, out_path: str) -> None:
     logger.info("Extracting files...")
     if tarfile.is_tarfile(archive_path):
```
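The pagination check in `_list_files` can be exercised against a canned response. A minimal sketch of just the parsing step (the dict below is a hypothetical `ApiListKernelSessionOutput` payload, not a real API capture, and `parse_list_response` is a stand-in name):

```python
def parse_list_response(json_response: dict) -> tuple[list[str], bool]:
    # Mirrors the parsing inside NotebookOutputHttpResolver._list_files above.
    if "files" not in json_response:
        msg = "Invalid ApiListKernelSessionOutput API response. Expected to include a 'files' field"
        raise ValueError(msg)
    # File names may arrive with a leading slash; strip it so they join cleanly
    # onto the download URL root and the local output root.
    files = [f["fileName"].lstrip("/") for f in json_response["files"]]
    # A non-empty nextPageToken means the output exceeded the requested page
    # size, which the resolver treats as "too many files to download directly".
    has_more = "nextPageToken" in json_response and json_response["nextPageToken"] != ""
    return (files, has_more)


response = {
    "files": [{"fileName": "/submission.csv"}, {"fileName": "logs/train.log"}],
    "nextPageToken": "abc123",
}
print(parse_list_response(response))  # → (['submission.csv', 'logs/train.log'], True)
```

Because the page size is set to `MAX_NUM_FILES_DIRECT_DOWNLOAD` (25 here), a single list call both fetches the file names and detects the over-limit case in one round trip.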

src/kagglehub/notebooks.py

Lines changed: 27 additions & 0 deletions

```python
import logging
from typing import Optional

from kagglehub import registry
from kagglehub.handle import parse_notebook_handle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK

logger = logging.getLogger(__name__)


def notebook_output_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
    """[WORK IN PROGRESS]

    Download notebook output files.

    Args:
        handle: (string) the notebook handle under https://kaggle.com/code.
        path: (string) Optional path to a file within the notebook output.
        force_download: (bool) Optional flag to force download the notebook output, even if it's cached.

    Returns:
        A string representing the path to the requested notebook output files.
    """
    h = parse_notebook_handle(handle)
    logger.info(f"Downloading Notebook Output: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
    return registry.notebook_output_resolver(h, path, force_download=force_download)
```

src/kagglehub/registry.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -31,3 +31,4 @@ def __call__(self, *args, **kwargs):  # noqa: ANN002, ANN003
 model_resolver = MultiImplRegistry("ModelResolver")
 dataset_resolver = MultiImplRegistry("DatasetResolver")
 competition_resolver = MultiImplRegistry("CompetitionResolver")
+notebook_output_resolver = MultiImplRegistry("NotebookOutputResolver")
```
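For context, the resolver-registry pattern being extended here can be sketched as follows. This is a simplified stand-in, not the actual `MultiImplRegistry` from `registry.py` (the dispatch order and error type are assumptions):

```python
class MultiImplRegistry:
    """Sketch: dispatches a call to the first registered implementation
    whose is_supported() returns True."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._impls = []

    def add_implementation(self, impl) -> None:
        self._impls.append(impl)

    def __call__(self, *args, **kwargs):
        for impl in self._impls:
            if impl.is_supported(*args, **kwargs):
                return impl(*args, **kwargs)
        raise RuntimeError(f"No implementation in {self.name} supports this call")


class FakeNotebookOutputHttpResolver:
    # Stand-in for NotebookOutputHttpResolver: HTTP download works everywhere.
    def is_supported(self, *_, **__) -> bool:
        return True

    def __call__(self, handle, path=None, *, force_download=False) -> str:
        return f"/cache/notebooks/{handle}/output"


notebook_output_resolver = MultiImplRegistry("NotebookOutputResolver")
notebook_output_resolver.add_implementation(FakeNotebookOutputHttpResolver())
print(notebook_output_resolver("owner/notebook"))  # → /cache/notebooks/owner/notebook/output
```

This is why the one-line registration above is all `__init__.py` needs: once the TODO'd `kaggle_cache_resolver` lands, it becomes a second `add_implementation` call on the same registry.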

tests/test_handle.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -1,4 +1,4 @@
-from kagglehub.handle import parse_competition_handle, parse_dataset_handle, parse_model_handle
+from kagglehub.handle import parse_competition_handle, parse_dataset_handle, parse_model_handle, parse_notebook_handle
 from tests.fixtures import BaseTestCase


@@ -68,3 +68,14 @@ def test_competition_handle(self) -> None:
         h = parse_competition_handle(handle)

         self.assertEqual("titanic", h.competition)
+
+    def test_code_handle(self) -> None:
+        handle = "owner/notebook"
+        h = parse_notebook_handle(handle)
+
+        self.assertEqual("owner", h.owner)
+        self.assertEqual("notebook", h.notebook)
+
+    def test_invalid_code_handle(self) -> None:
+        with self.assertRaises(ValueError):
+            parse_notebook_handle("notebook")
```
