Skip to content

Commit b1162e9

Browse files
committed
new io funcs
1 parent 5cf01e6 commit b1162e9

File tree

2 files changed

+93
-12
lines changed

2 files changed

+93
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.19.4"
3+
version = "0.19.5"
44
description = """\
55
SMASHED is a toolkit designed to apply transformations to samples in \
66
datasets, such as fields extraction, tokenization, prompting, batching, \

src/smashed/utils/io_utils.py

Lines changed: 92 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434

3535
__all__ = [
3636
"copy_directory",
37+
"exists",
38+
"is_dir",
39+
"is_file",
3740
"open_file_for_read",
3841
"open_file_for_write",
3942
"recursively_list_files",
@@ -274,6 +277,66 @@ def open_file_for_read(
274277
remove_local_file(str(path))
275278

276279

280+
def is_dir(
281+
path: PathType,
282+
client: Optional[ClientType] = None,
283+
raise_if_not_exists: bool = False,
284+
) -> bool:
285+
"""Check if a path is a directory."""
286+
287+
path = MultiPath.parse(path)
288+
client = client or get_client_if_needed(path)
289+
290+
if path.is_local:
291+
if not (e := path.as_path.exists()) and raise_if_not_exists:
292+
raise FileNotFoundError(f"Path does not exist: {path}")
293+
elif not e:
294+
return False
295+
return path.as_path.is_dir()
296+
elif path.is_s3:
297+
assert client is not None, "Could not get S3 client"
298+
resp = client.list_objects_v2(
299+
Bucket=path.bucket, Prefix=path.key.lstrip("/"), Delimiter="/"
300+
)
301+
if "CommonPrefixes" in resp:
302+
return True
303+
elif "Contents" in resp:
304+
return False
305+
elif raise_if_not_exists:
306+
raise FileNotFoundError(f"Path does not exist: {path}")
307+
return False
308+
else:
309+
raise FileNotFoundError(f"Unsupported protocol: {path.prot}")
310+
311+
312+
def is_file(
313+
path: PathType,
314+
client: Optional[ClientType] = None,
315+
raise_if_not_exists: bool = False,
316+
) -> bool:
317+
"""Check if a path is a file."""
318+
319+
try:
320+
return not is_dir(path=path, client=client, raise_if_not_exists=True)
321+
except FileNotFoundError as e:
322+
if raise_if_not_exists:
323+
raise FileNotFoundError(f"Path does not exist: {path}") from e
324+
return False
325+
326+
327+
def exists(
328+
path: PathType,
329+
client: Optional[ClientType] = None,
330+
) -> bool:
331+
"""Check if a path exists"""
332+
333+
try:
334+
is_dir(path=path, client=client, raise_if_not_exists=True)
335+
return True
336+
except FileNotFoundError:
337+
return False
338+
339+
277340
@contextmanager
278341
def open_file_for_write(
279342
path: PathType,
@@ -350,16 +413,24 @@ def open_file_for_write(
350413

351414
def recursively_list_files(
352415
path: PathType,
353-
ignore_hidden_files: bool = True,
416+
ignore_hidden: bool = True,
417+
include_dirs: bool = False,
418+
include_files: bool = True,
354419
client: Optional[ClientType] = None,
355420
) -> Iterable[str]:
356421
"""Recursively list all files in the given directory for a given
357422
path, local or remote.
358423
359424
Args:
360425
path (Union[str, Path, MultiPath]): The path to list content at.
361-
ignore_hidden_files (bool, optional): Whether to ignore hidden files
362-
(i.e. files that start with a dot) when listing. Defaults to True.
426+
ignore_hidden (bool, optional): Whether to ignore hidden files and
427+
directories when listing. Defaults to True.
428+
include_dirs (bool, optional): Whether to include directories in the
429+
listing. Defaults to False.
430+
include_files (bool, optional): Whether to include files in the
431+
listing. Defaults to True.
432+
client (boto3.client, optional): The boto3 client to use. If not
433+
provided, one will be created if necessary.
363434
"""
364435

365436
path = MultiPath.parse(path)
@@ -377,17 +448,27 @@ def recursively_list_files(
377448
for page in pages:
378449
for obj in page["Contents"]:
379450
key = obj["Key"]
380-
if key[-1] == "/": # last char is a slash
451+
path = MultiPath(prot="s3", root=path.root, path=key)
452+
if key[-1] == "/" and key != prefix:
453+
# last char is a slash, so it's a directory
454+
# we don't want to re-include the prefix though, so we
455+
# check that it's not the same
381456
prefixes.append(key)
457+
if include_dirs:
458+
yield str(path)
382459
else:
383-
p = MultiPath(prot="s3", root=path.root, path=key)
384-
yield str(p)
460+
if include_files:
461+
yield str(path)
385462

386463
if path.is_local:
387-
for _root, _, files in local_walk(path.as_str):
464+
for _root, dirnames, filenames in local_walk(path.as_str):
388465
root = Path(_root)
389-
for f in files:
390-
if ignore_hidden_files and f.startswith("."):
466+
to_list = [
467+
*(dirnames if include_dirs else []),
468+
*(filenames if include_files else []),
469+
]
470+
for f in to_list:
471+
if ignore_hidden and f.startswith("."):
391472
continue
392473
yield str(MultiPath.parse(root / f))
393474

@@ -423,7 +504,7 @@ def copy_directory(
423504
client = client or get_client_if_needed(src) or get_client_if_needed(dst)
424505

425506
for sp in recursively_list_files(
426-
path=src, ignore_hidden_files=ignore_hidden_files
507+
path=src, ignore_hidden=ignore_hidden_files
427508
):
428509
# parse the source path
429510
source_path = MultiPath.parse(sp)
@@ -475,7 +556,7 @@ def remove_directory(path: PathType, client: Optional[ClientType] = None):
475556
assert client is not None, "Could not get S3 client"
476557

477558
for fn in recursively_list_files(
478-
path=path, ignore_hidden_files=False, client=client
559+
path=path, ignore_hidden=False, client=client
479560
):
480561
remove_file(fn, client=client)
481562

0 commit comments

Comments
 (0)