Skip to content

Commit cf5c297

Browse files
authored
List files in local mode (#60)
* recurse_fix
* added option to follow links
1 parent 376f6b3 commit cf5c297

File tree

3 files changed: 32 additions and 2 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.21.0"
3+
version = "0.21.1"
44
description = """\
55
SMASHED is a toolkit designed to apply transformations to samples in \
66
datasets, such as fields extraction, tokenization, prompting, batching, \

src/smashed/utils/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,19 @@
22
from .convert import bytes_from_int, int_from_bytes
33
from .io_utils import (
44
MultiPath,
5+
compress_stream,
6+
copy_directory,
7+
decompress_stream,
8+
exists,
9+
is_dir,
10+
is_file,
511
open_file_for_read,
612
open_file_for_write,
713
recursively_list_files,
14+
remove_directory,
15+
remove_file,
16+
stream_file_for_read,
17+
upload_on_success,
818
)
919
from .version import get_name, get_name_and_version, get_version
1020
from .warnings import SmashedWarnings
@@ -13,15 +23,25 @@
1323
__all__ = [
1424
"BlingFireSplitter",
1525
"bytes_from_int",
26+
"compress_stream",
27+
"copy_directory",
28+
"decompress_stream",
29+
"exists",
1630
"get_cache_dir",
1731
"get_name_and_version",
1832
"get_name",
1933
"get_version",
2034
"int_from_bytes",
35+
"is_dir",
36+
"is_file",
2137
"MultiPath",
2238
"open_file_for_read",
2339
"open_file_for_write",
2440
"recursively_list_files",
41+
"remove_directory",
42+
"remove_file",
2543
"SmashedWarnings",
44+
"stream_file_for_read",
45+
"upload_on_success",
2646
"WhitespaceSplitter",
2747
]

src/smashed/utils/io_utils/operations.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def recursively_list_files(
329329
include_dirs: bool = False,
330330
include_files: bool = True,
331331
client: Optional[ClientType] = None,
332+
local_follow_links: bool = False,
332333
) -> Iterable[str]:
333334
"""Recursively list all files in the given directory for a given
334335
path, local or remote.
@@ -343,6 +344,8 @@ def recursively_list_files(
343344
listing. Defaults to True.
344345
client (boto3.client, optional): The boto3 client to use. If not
345346
provided, one will be created if necessary.
347+
local_follow_links (bool, optional): Whether to follow symlinks when
348+
listing local files. Defaults to False.
346349
"""
347350

348351
path = MultiPath.parse(path)
@@ -373,7 +376,14 @@ def recursively_list_files(
373376
yield str(path)
374377

375378
if path.is_local:
376-
for _root, dirnames, filenames in local_walk(path.as_str):
379+
if not path.as_path.is_dir():
380+
# yield the path itself if it's not a directory; this matches
381+
# the behavior that we get for S3.
382+
yield path.as_str
383+
384+
for _root, dirnames, filenames in local_walk(
385+
top=path.as_str, followlinks=local_follow_links
386+
):
377387
root = Path(_root)
378388
to_list = [
379389
*(dirnames if include_dirs else []),

0 commit comments

Comments
 (0)