Skip to content

Commit ba6291f

Browse files
authored
io utilities (#43)
1 parent b3cc218 commit ba6291f

File tree

3 files changed

+445
-1
lines changed

3 files changed

+445
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ dev = [
8383
"Flake8-pyproject>=1.1.0",
8484
]
8585
remote = [
86-
"smart-open>=5.2.1"
86+
"smart-open>=5.2.1",
87+
"boto3>=1.25.5",
8788
]
8889
datasets = [
8990
"datasets>=2.4.0",

src/smashed/utils/io_utils.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import os
2+
import shutil
3+
from contextlib import contextmanager
4+
from logging import INFO, Logger, getLogger
5+
from pathlib import Path
6+
from tempfile import NamedTemporaryFile
7+
from typing import (
8+
IO,
9+
TYPE_CHECKING,
10+
Any,
11+
Callable,
12+
Dict,
13+
Generator,
14+
Iterable,
15+
Optional,
16+
Union,
17+
)
18+
from urllib.parse import urlparse
19+
20+
from necessary import Necessary, necessary
21+
22+
with necessary("boto3", soft=True) as BOTO_AVAILABLE:
23+
if TYPE_CHECKING or BOTO_AVAILABLE:
24+
import boto3
25+
26+
27+
# Public API of this module. `copy_directory` is a documented, public helper
# defined below, so it is exported alongside the other entry points.
__all__ = [
    'copy_directory',
    'open_file_for_read',
    'open_file_for_write',
    'recursively_list_files',
    'remove_directory',
]
33+
34+
35+
def get_logger() -> Logger:
    """Return this module's logger with its level (re)set to INFO.

    The logger is looked up by this file's path (``__file__``), so repeated
    calls return the same logger object; each call sets its level to INFO.
    """
    module_logger = getLogger(__file__)
    module_logger.setLevel(INFO)
    return module_logger
39+
40+
41+
@Necessary("boto3")
@contextmanager
def open_file_for_read(
    path: Union[str, Path],
    mode: str = "r",
    open_fn: Optional[Callable] = None,
    logger: Optional[Logger] = None,
    open_kwargs: Optional[Dict[str, Any]] = None,
) -> Generator[IO, None, None]:
    """Get a context manager to read in a file that is either on
    S3 or local.

    S3 objects are first downloaded to a temporary local file, which is
    opened with ``open_fn`` and deleted when the context exits (including
    when the download or the caller's code raises).

    Args:
        path (Union[str, Path]): The path to the file to read. Can be an
            ``s3://bucket/key`` URL, a ``file://`` URL, or a local path.
        mode (str, optional): The mode to open the file in. Defaults to "r".
            Only read modes are supported (e.g. 'rb', 'rt', 'r').
        open_fn (Callable, optional): The function to use to open the file.
            It is called as ``open_fn(file=..., mode=..., **open_kwargs)``.
            Defaults to the built-in open function.
        logger (Logger, optional): The logger to use. Defaults to the built-in
            logger at INFO level.
        open_kwargs (Dict[str, Any], optional): Any additional keyword to pass
            to the open function. Defaults to None.

    Yields:
        The file object returned by ``open_fn``.

    Raises:
        ValueError: If the path uses a scheme other than "s3", "file",
            or no scheme at all.
    """
    open_kwargs = open_kwargs or {}
    logger = logger or get_logger()
    open_fn = open_fn or open
    parse = urlparse(str(path))
    remove = False

    assert "r" in mode, "Only read mode is supported"

    try:
        if parse.scheme == "s3":
            client = boto3.client("s3")
            logger.info(f"Downloading {path} to a temporary file")
            with NamedTemporaryFile(delete=False) as f:
                # From here on `path` is the local copy. Flag it for removal
                # *before* downloading so a failed download does not leak it.
                path = f.name
                remove = True
                client.download_fileobj(
                    parse.netloc, parse.path.lstrip("/"), f
                )
        elif parse.scheme == "file":
            # open() does not understand "file://" URLs; use the path part.
            path = parse.path
        elif parse.scheme != "":
            raise ValueError(f"Unsupported scheme {parse.scheme}")

        with open_fn(file=path, mode=mode, **open_kwargs) as f:
            yield f
    finally:
        if remove:
            os.remove(path)
91+
92+
93+
@Necessary("boto3")
@contextmanager
def open_file_for_write(
    path: Union[str, Path],
    mode: str = "w",
    skip_if_empty: bool = False,
    open_fn: Optional[Callable] = None,
    logger: Optional[Logger] = None,
    open_kwargs: Optional[Dict[str, Any]] = None,
) -> Generator[IO, None, None]:
    """Get a context manager to write to a file that is either on
    S3 or local.

    For local destinations the enclosing directory is created if needed and
    the file is opened with ``open_fn``. For S3 destinations the caller
    writes to a temporary local file, which is uploaded when the context
    exits without an exception and is always deleted afterwards. Note that
    on the S3 path the yielded handle is the temporary file itself:
    ``open_fn`` and ``open_kwargs`` are not applied to it.

    Args:
        path (Union[str, Path]): The path to the file to write. Can be a
            local path, a ``file://`` URL, or an ``s3://bucket/key`` URL.
        mode (str, optional): The mode to open the file in. Defaults to "w".
            Only write/append modes are supported (e.g. 'wb', 'w', 'a').
        skip_if_empty (bool, optional): If True, an empty output is
            discarded: an empty local file is deleted after writing, and an
            empty temporary file is not uploaded to S3. Defaults to False.
        open_fn (Callable, optional): The function to use to open the file.
            It is called as ``open_fn(file=..., mode=..., **open_kwargs)``.
            Defaults to the built-in open function.
        logger (Logger, optional): The logger to use. Defaults to the built-in
            logger at INFO level.
        open_kwargs (Dict[str, Any], optional): Any additional keyword to pass
            to the open function. Defaults to None.

    Yields:
        The writable file object.

    Raises:
        ValueError: If the path uses a scheme other than "s3", "file",
            or no scheme at all. Raised before anything is written.
    """

    parse = urlparse(str(path))
    logger = logger or get_logger()
    open_fn = open_fn or open
    open_kwargs = open_kwargs or {}

    assert "w" in mode or "a" in mode, "Only write/append mode is supported"

    if parse.scheme == "file" or parse.scheme == "":
        # open() does not understand "file://" URLs; use the path part.
        target = parse.path if parse.scheme == "file" else path

        # make enclosing directory if it doesn't exist
        Path(target).parent.mkdir(parents=True, exist_ok=True)

        try:
            with open_fn(file=target, mode=mode, **open_kwargs) as f:
                yield f
        finally:
            # The existence check keeps a failed open from being masked by
            # a FileNotFoundError raised from os.stat.
            if (
                skip_if_empty
                and os.path.exists(target)
                and os.stat(target).st_size == 0
            ):
                logger.info(f"Skipping empty file {target}")
                os.remove(target)

    elif parse.scheme == "s3":
        key = parse.path.lstrip("/")
        dst = f"{parse.netloc}/{key}"
        local = None

        try:
            with NamedTemporaryFile(delete=False, mode=mode) as f:
                local = f.name
                yield f

            # Only reached when the caller's block finished without raising;
            # the temporary file is closed (and flushed) at this point.
            if skip_if_empty and os.stat(local).st_size == 0:
                logger.info(f"Skipping upload to {dst} since {local} is empty")
            else:
                logger.info(f"Uploading {local} to {dst}")
                client = boto3.client("s3")
                client.upload_file(local, parse.netloc, key)
        finally:
            # Always clean up the temporary file, even if the caller's block
            # or the upload raised.
            if local is not None:
                os.remove(local)

    else:
        raise ValueError(f"Unsupported scheme {parse.scheme}")
154+
155+
156+
@Necessary("boto3")
def recursively_list_files(
    path: Union[str, Path], ignore_hidden_files: bool = True
) -> Iterable[str]:
    """Recursively list all files in the given directory or network prefix.

    This is a generator: nothing is listed (and no error is raised) until
    it is iterated.

    Args:
        path (Union[str, Path]): The path to list content at. Can be a local
            path, a ``file://`` URL, or an ``s3://bucket/prefix`` URL.
        ignore_hidden_files (bool, optional): Whether to ignore hidden files
            (i.e. files that start with a dot) when listing. Defaults to True.
            This filter is only applied to local listings; S3 listings
            return every object under the prefix.

    Yields:
        str: ``s3://bucket/key`` URLs for S3 listings, or filesystem paths
        for local listings.

    Raises:
        NotImplementedError: If the path uses a scheme other than "s3",
            "file", or no scheme at all.
    """

    parse = urlparse(str(path))

    if parse.scheme == "s3":
        cl = boto3.client("s3")
        paginator = cl.get_paginator("list_objects_v2")

        # Without a Delimiter, list_objects_v2 already returns every key
        # under the prefix, however deeply nested, so one paginated listing
        # is enough. Re-listing "directory" keys would return the same
        # objects again (including the marker itself) and never terminate.
        pages = paginator.paginate(
            Bucket=parse.netloc, Prefix=parse.path.lstrip("/")
        )
        for page in pages:
            # "Contents" is absent from the response when nothing matches.
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.endswith("/"):
                    # zero-byte "directory" placeholder, not a file
                    continue
                yield f"s3://{parse.netloc}/{key}"

    elif parse.scheme == "file" or parse.scheme == "":
        # For plain paths use the original string: urlparse would cut the
        # path at the first ';', '?' or '#'.
        root_dir = parse.path if parse.scheme == "file" else str(path)
        for root, _, files in os.walk(root_dir):
            for f in files:
                if ignore_hidden_files and f.startswith("."):
                    continue
                yield os.path.join(root, f)
    else:
        raise NotImplementedError(f"Unknown scheme: {parse.scheme}")
194+
195+
196+
@Necessary("boto3")
def copy_directory(
    src: Union[str, Path],
    dst: Union[str, Path],
    ignore_hidden_files: bool = True,
    logger: Optional[Logger] = None,
):
    """Copy a directory from one location to another. Source or target
    locations can be local, remote, or a mix of both.

    Files are copied one at a time, streaming from the source handle to
    the destination handle; the layout below ``src`` is reproduced below
    ``dst``.

    Args:
        src (Union[str, Path]): The location to copy from. Can be local
            or a location on S3.
        dst (Union[str, Path]): The location to copy to. Can be local or S3.
        ignore_hidden_files (bool, optional): Whether to ignore hidden files
            on copy. Defaults to True.
        logger (Logger, optional): The logger to use. Defaults to the built-in
            logger at INFO level.
    """

    logger = logger or get_logger()

    # Locations are handled as strings: pathlib would collapse the "//" in
    # "s3://bucket/key" and turn S3 URLs into unusable paths.
    src_str = str(src)
    dst_str = str(dst)
    src_parse = urlparse(src_str)
    dst_is_url = urlparse(dst_str).scheme != ""

    # Prefix shared by every path yielded by recursively_list_files(src).
    if src_parse.scheme == "s3":
        listing_root = f's3://{src_parse.netloc}/{src_parse.path.lstrip("/")}'
    elif src_parse.scheme == "file":
        listing_root = src_parse.path
    else:
        listing_root = src_str

    cnt = 0

    for source_path in recursively_list_files(
        src_str, ignore_hidden_files=ignore_hidden_files
    ):
        if src_parse.scheme == "s3":
            relative = source_path[len(listing_root):].lstrip("/")
        else:
            relative = os.path.relpath(source_path, listing_root)

        if relative in ("", "."):
            # src points at a single object rather than a directory
            destination = dst_str
        elif dst_is_url:
            # URLs always use forward slashes, whatever the local separator
            relative_url = relative.replace(os.sep, "/")
            destination = f'{dst_str.rstrip("/")}/{relative_url}'
        else:
            destination = os.path.join(dst_str, relative)

        logger.info(f"Copying {source_path} to {destination}; {cnt:,} so far")

        with open_file_for_read(source_path, mode="rb") as s:
            with open_file_for_write(destination, mode="wb") as d:
                # stream in chunks instead of loading the whole file
                shutil.copyfileobj(s, d)

        cnt += 1
237+
238+
239+
@Necessary("boto3")
def remove_directory(path: Union[str, Path]):
    """Completely remove a directory at the provided path.

    Args:
        path (Union[str, Path]): The directory to remove. Can be a local
            path, a ``file://`` URL, or an ``s3://bucket/prefix`` URL. For
            S3, every object returned by ``recursively_list_files`` for the
            prefix is deleted. Local removal ignores errors (including a
            missing directory).

    Raises:
        NotImplementedError: If the path uses a scheme other than "s3",
            "file", or no scheme at all.
    """

    parse = urlparse(str(path))

    if parse.scheme == "s3":
        client = boto3.client("s3")
        # Listed entries look like "s3://<bucket>/<key>". Strip that known
        # prefix rather than re-parsing the URL, which would cut keys at the
        # first '?' or '#'.
        url_prefix = f"s3://{parse.netloc}/"
        for fn in recursively_list_files(path, ignore_hidden_files=False):
            client.delete_object(
                Bucket=parse.netloc, Key=str(fn)[len(url_prefix):]
            )
    elif parse.scheme == "file":
        # rmtree needs a filesystem path; with ignore_errors=True a raw
        # "file://" URL would silently remove nothing.
        shutil.rmtree(parse.path, ignore_errors=True)
    elif parse.scheme == "":
        shutil.rmtree(path, ignore_errors=True)
    else:
        raise NotImplementedError(f"Unknown scheme: {parse.scheme}")

0 commit comments

Comments
 (0)