34
34
35
35
__all__ = [
36
36
"copy_directory" ,
37
+ "exists" ,
38
+ "is_dir" ,
39
+ "is_file" ,
37
40
"open_file_for_read" ,
38
41
"open_file_for_write" ,
39
42
"recursively_list_files" ,
@@ -274,6 +277,66 @@ def open_file_for_read(
274
277
remove_local_file (str (path ))
275
278
276
279
280
def is_dir(
    path: PathType,
    client: Optional[ClientType] = None,
    raise_if_not_exists: bool = False,
) -> bool:
    """Check if a path is a directory.

    Args:
        path (Union[str, Path, MultiPath]): The local or S3 path to check.
        client (boto3.client, optional): The boto3 client to use. If not
            provided, one will be created if necessary.
        raise_if_not_exists (bool, optional): Whether to raise instead of
            returning False when nothing exists at the path. Defaults to
            False.

    Returns:
        bool: True if the path is a directory; False if it is something
            else, or if it does not exist and `raise_if_not_exists` is False.

    Raises:
        FileNotFoundError: If the path does not exist and
            `raise_if_not_exists` is True, or if the protocol of the path
            is neither local nor S3.
    """

    path = MultiPath.parse(path)
    client = client or get_client_if_needed(path)

    if path.is_local:
        local_path = path.as_path
        if local_path.is_dir():
            return True
        if raise_if_not_exists and not local_path.exists():
            raise FileNotFoundError(f"Path does not exist: {path}")
        return False

    if path.is_s3:
        assert client is not None, "Could not get S3 client"

        key = path.key.lstrip("/")
        stem = key.rstrip("/")

        # S3 has no real directories: a key is a "directory" when at least
        # one object lives under "<key>/". Probing with the trailing slash
        # (and without a Delimiter) avoids two pitfalls of listing the raw
        # key: "di" matching a sibling "dir/", and "dir/" being reported as
        # a non-directory when it only contains files (no CommonPrefixes).
        dir_prefix = f"{stem}/" if stem else ""
        resp = client.list_objects_v2(
            Bucket=path.bucket, Prefix=dir_prefix, MaxKeys=1
        )
        if resp.get("Contents"):
            return True

        if key and not key.endswith("/"):
            # Keys are listed in lexicographic order, so if an object named
            # exactly `key` exists it is the first one returned for this
            # prefix; anything else is just a longer key sharing the prefix.
            resp = client.list_objects_v2(
                Bucket=path.bucket, Prefix=key, MaxKeys=1
            )
            if any(obj["Key"] == key for obj in resp.get("Contents", [])):
                return False

        if raise_if_not_exists:
            raise FileNotFoundError(f"Path does not exist: {path}")
        return False

    # Kept as FileNotFoundError so that callers catching it to detect a
    # missing path keep working for unsupported protocols too.
    raise FileNotFoundError(f"Unsupported protocol: {path.prot}")
310
+
311
+
312
def is_file(
    path: PathType,
    client: Optional[ClientType] = None,
    raise_if_not_exists: bool = False,
) -> bool:
    """Check if a path is a file.

    A path counts as a file when it exists and `is_dir` reports that it is
    not a directory.

    Args:
        path (Union[str, Path, MultiPath]): The path to check.
        client (boto3.client, optional): The boto3 client, forwarded to
            `is_dir`.
        raise_if_not_exists (bool, optional): Whether to raise instead of
            returning False when `is_dir` signals a missing path. Defaults
            to False.

    Returns:
        bool: True if the path exists and is not a directory.

    Raises:
        FileNotFoundError: If `is_dir` raised FileNotFoundError and
            `raise_if_not_exists` is True.
    """

    try:
        # always ask is_dir to raise, so that "missing" can be told apart
        # from "exists but is not a directory"
        directory = is_dir(path=path, client=client, raise_if_not_exists=True)
    except FileNotFoundError as err:
        if not raise_if_not_exists:
            return False
        raise FileNotFoundError(f"Path does not exist: {path}") from err

    return not directory
325
+
326
+
327
def exists(
    path: PathType,
    client: Optional[ClientType] = None,
) -> bool:
    """Check if a path exists.

    Args:
        path (Union[str, Path, MultiPath]): The path to check.
        client (boto3.client, optional): The boto3 client, forwarded to
            `is_dir`.

    Returns:
        bool: False if `is_dir` raises FileNotFoundError for the path,
            True otherwise (whether it is a directory or not).
    """

    try:
        # the return value is irrelevant here; only whether is_dir raises
        is_dir(path=path, client=client, raise_if_not_exists=True)
    except FileNotFoundError:
        return False
    else:
        return True
338
+
339
+
277
340
@contextmanager
278
341
def open_file_for_write (
279
342
path : PathType ,
@@ -350,16 +413,24 @@ def open_file_for_write(
350
413
351
414
def recursively_list_files (
352
415
path : PathType ,
353
- ignore_hidden_files : bool = True ,
416
+ ignore_hidden : bool = True ,
417
+ include_dirs : bool = False ,
418
+ include_files : bool = True ,
354
419
client : Optional [ClientType ] = None ,
355
420
) -> Iterable [str ]:
356
421
"""Recursively list all files in the given directory for a given
357
422
path, local or remote.
358
423
359
424
Args:
360
425
path (Union[str, Path, MultiPath]): The path to list content at.
361
- ignore_hidden_files (bool, optional): Whether to ignore hidden files
362
- (i.e. files that start with a dot) when listing. Defaults to True.
426
+ ignore_hidden (bool, optional): Whether to ignore hidden files and
427
+ directories when listing. Defaults to True.
428
+ include_dirs (bool, optional): Whether to include directories in the
429
+ listing. Defaults to False.
430
+ include_files (bool, optional): Whether to include files in the
431
+ listing. Defaults to True.
432
+ client (boto3.client, optional): The boto3 client to use. If not
433
+ provided, one will be created if necessary.
363
434
"""
364
435
365
436
path = MultiPath .parse (path )
@@ -377,17 +448,27 @@ def recursively_list_files(
377
448
for page in pages :
378
449
for obj in page ["Contents" ]:
379
450
key = obj ["Key" ]
380
- if key [- 1 ] == "/" : # last char is a slash
451
+ path = MultiPath (prot = "s3" , root = path .root , path = key )
452
+ if key [- 1 ] == "/" and key != prefix :
453
+ # last char is a slash, so it's a directory
454
+ # we don't want to re-include the prefix though, so we
455
+ # check that it's not the same
381
456
prefixes .append (key )
457
+ if include_dirs :
458
+ yield str (path )
382
459
else :
383
- p = MultiPath ( prot = "s3" , root = path . root , path = key )
384
- yield str (p )
460
+ if include_files :
461
+ yield str (path )
385
462
386
463
if path .is_local :
387
- for _root , _ , files in local_walk (path .as_str ):
464
+ for _root , dirnames , filenames in local_walk (path .as_str ):
388
465
root = Path (_root )
389
- for f in files :
390
- if ignore_hidden_files and f .startswith ("." ):
466
+ to_list = [
467
+ * (dirnames if include_dirs else []),
468
+ * (filenames if include_files else []),
469
+ ]
470
+ for f in to_list :
471
+ if ignore_hidden and f .startswith ("." ):
391
472
continue
392
473
yield str (MultiPath .parse (root / f ))
393
474
@@ -423,7 +504,7 @@ def copy_directory(
423
504
client = client or get_client_if_needed (src ) or get_client_if_needed (dst )
424
505
425
506
for sp in recursively_list_files (
426
- path = src , ignore_hidden_files = ignore_hidden_files
507
+ path = src , ignore_hidden = ignore_hidden_files
427
508
):
428
509
# parse the source path
429
510
source_path = MultiPath .parse (sp )
@@ -475,7 +556,7 @@ def remove_directory(path: PathType, client: Optional[ClientType] = None):
475
556
assert client is not None , "Could not get S3 client"
476
557
477
558
for fn in recursively_list_files (
478
- path = path , ignore_hidden_files = False , client = client
559
+ path = path , ignore_hidden = False , client = client
479
560
):
480
561
remove_file (fn , client = client )
481
562
0 commit comments