Proposal: Adding output_dir to ModelHttpResolver for Direct Downloading
This proposal outlines a plan to add an output_dir argument to the __call__ method of the ModelHttpResolver class. This new argument will allow users to download models directly to a specified directory, bypassing the default caching mechanism and using a flat directory structure.
Motivation:
This feature is primarily driven by user feedback (#175) and the need for better integration between kagglehub and other libraries, specifically torchtune (see meta-pytorch/torchtune#1852).
NOTE:
torchtune currently uses huggingface_hub for model downloading, leveraging its snapshot_download function and its local_dir argument for specifying custom download locations.
Proposed Changes:
- Introduce a Cache class: This class will encapsulate the caching logic from cache.py, providing a cleaner interface and enabling the flexible directory structure based on the override_dir argument.

      class Cache:
          def __init__(self, override_dir: Optional[str] = None):
              self.override_dir = override_dir

          def get_path(self, handle: ResourceHandle, path: Optional[str] = None) -> str:
              if self.override_dir:
                  return os.path.join(self.override_dir, path or "")  # Flat structure
              else:
                  return get_cached_path(handle, path)  # Existing nested structure

          def get_archive_path(self, handle: ResourceHandle) -> str:
              if self.override_dir:
                  return os.path.join(self.override_dir, f"{handle.version!s}.archive")
              else:
                  return get_cached_archive_path(handle)

          def get_completion_marker_filepath(self, handle: ResourceHandle, path: Optional[str] = None) -> str:
              if self.override_dir:
                  return os.path.join(self.override_dir, f"{path or handle.version!s}.complete")
              else:
                  return _get_completion_marker_filepath(handle, path)

          def load_from_cache(self, handle: ResourceHandle, path: Optional[str] = None) -> Optional[str]:
              """Return path for the requested resource from the cache or output_dir."""
              marker_path = self.get_completion_marker_filepath(handle, path)
              full_path = self.get_path(handle, path)
              return full_path if os.path.exists(marker_path) and os.path.exists(full_path) else None

          # Migrate other relevant methods from cache.py (e.g., mark_as_complete,
          # delete_from_cache) to the Cache class, adapting them similarly.
- Modify ModelHttpResolver.__call__: Add the output_dir argument and instantiate the Cache class based on its value.

      def __call__(
          self,
          h: ModelHandle,
          path: Optional[str] = None,
          *,
          force_download: Optional[bool] = False,
          output_dir: Optional[str] = None
      ) -> str:
          cache = Cache(override_dir=output_dir)
          # ...use cache.get_path(), cache.get_archive_path(), cache.load_from_cache(),
          # and cache.get_completion_marker_filepath()...
- Pass output_dir through model_download:

      def model_download(
          # ...existing arguments
          output_dir: Optional[str] = None,
      ):
          # Existing code...
          return registry.model_resolver(h, path, force_download=force_download, output_dir=output_dir)
- Handle output_dir in other Resolvers: Add warning logs to other resolvers to inform users that the output_dir argument is not supported and will be ignored.
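The changes above can be sketched end to end as follows. This is only an illustration under stated assumptions, not kagglehub code: HttpResolverSketch, PreinstalledResolverSketch, the fetch callback, the string handle, and this simplified model_download are all hypothetical stand-ins, and the sketch handles a single file per call rather than archives.

```python
import logging
import os
from typing import Callable, Optional

logger = logging.getLogger("output_dir_sketch")


class HttpResolverSketch:
    """Hypothetical stand-in for ModelHttpResolver; fetch replaces the real HTTP call."""

    def __init__(self, cache_root: str, fetch: Callable[[str, str], bytes]):
        self.cache_root = cache_root
        self.fetch = fetch

    def __call__(self, handle: str, path: str, *, force_download: bool = False,
                 output_dir: Optional[str] = None) -> str:
        # Flat layout when output_dir is given, nested layout under the cache otherwise.
        base = output_dir if output_dir else os.path.join(self.cache_root, *handle.split("/"))
        target = os.path.join(base, path)
        marker = target + ".complete"
        if not force_download and os.path.exists(marker) and os.path.exists(target):
            return target  # Already downloaded and marked complete.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        if os.path.exists(marker):
            os.remove(marker)  # Drop a stale marker before rewriting the file.
        with open(target, "wb") as f:
            f.write(self.fetch(handle, path))
        open(marker, "w").close()  # Write the marker last, once the file is complete.
        return target


class PreinstalledResolverSketch:
    """Hypothetical resolver that cannot honor output_dir."""

    def __call__(self, handle: str, path: str, *, force_download: bool = False,
                 output_dir: Optional[str] = None) -> str:
        if output_dir is not None:
            logger.warning("output_dir is not supported by %s and will be ignored.",
                           type(self).__name__)
        return os.path.join("/preinstalled", handle, path)


def model_download(resolver, handle: str, path: str, *, force_download: bool = False,
                   output_dir: Optional[str] = None) -> str:
    # Forward output_dir unchanged, as the proposed model_download would.
    return resolver(handle, path, force_download=force_download, output_dir=output_dir)
```

Writing the completion marker last means an interrupted download is never mistaken for a finished one, which matches the marker-plus-file check in the proposed load_from_cache.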
Advantages:
- Flexibility: Users can choose between using the cache or downloading directly to a specific directory. This also allows users to bypass the deeply nested subdirectory structure in the default cache.
- Improved Code Organization: Encapsulating caching logic within the Cache class discourages use of global variables, promoting better maintainability and extensibility.
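To make the difference between the two layouts concrete, here is a small sketch. FakeHandle, nested_path, and flat_path are hypothetical names, and the nested shape shown is an assumption about the default cache layout; the real logic lives in cache.py.

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeHandle:
    """Hypothetical stand-in for kagglehub's model handle."""
    owner: str
    model: str
    framework: str
    variation: str
    version: int


def nested_path(cache_root: str, h: FakeHandle, path: Optional[str] = None) -> str:
    # Assumed shape of the default nested layout, one directory per handle field.
    base = os.path.join(cache_root, "models", h.owner, h.model,
                        h.framework, h.variation, str(h.version))
    return os.path.join(base, path) if path else base


def flat_path(override_dir: str, path: Optional[str] = None) -> str:
    # Mirrors the proposed Cache.get_path when override_dir is set.
    return os.path.join(override_dir, path or "")
```

Note that the flat layout drops every handle field from the path, so two different models written to the same output_dir would share file names such as the version-based archive and completion marker; callers would presumably use one directory per model.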
Request for Feedback:
This proposal outlines one possible implementation. Any feedback on alternative implementations is appreciated. Some other passing thoughts include:
- Cache class design: Is this the best way to encapsulate the caching logic? Are there alternative approaches?
- Flat directory structure: Is this the most user-friendly approach for direct downloads? Should we consider other directory structures?
- Error handling: How should errors related to output_dir (e.g., permissions issues) be handled?
- Interaction with existing cache: If output_dir is specified, should we completely bypass the cache or should we consider using it as a secondary storage location? How might we handle cache invalidation in such a case?
- Naming: I chose output_dir to align with torchtune, but download_path was suggested in #175 (Feature Request: Add Support for Custom Path in kagglehub.model_download). This may or may not make the intention clearer.
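On the error-handling question, one possible approach is to validate output_dir up front and fail early with a specific exception, before any bytes are downloaded. The helper below is a minimal sketch of that idea; prepare_output_dir is a hypothetical name and not part of kagglehub.

```python
import os


def prepare_output_dir(output_dir: str) -> str:
    """Create output_dir if needed and verify that it is a writable directory."""
    resolved = os.path.abspath(os.path.expanduser(output_dir))
    if os.path.exists(resolved) and not os.path.isdir(resolved):
        raise NotADirectoryError(f"output_dir is not a directory: {resolved}")
    # Raises PermissionError (or another OSError) if the directory cannot be created.
    os.makedirs(resolved, exist_ok=True)
    if not os.access(resolved, os.W_OK):
        raise PermissionError(f"output_dir is not writable: {resolved}")
    return resolved
```

Failing before the download starts avoids leaving a partially written archive behind, though a check like os.access is only advisory and a later write can still fail.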