Skip to content

Commit 5fd542c

Browse files
authored
Added tools to stream from S3 (#57)
* added compression lib * renamed * style * added extra function to decompress * added option to compress a stream * docs
1 parent d2b214a commit 5fd542c

File tree

14 files changed

+748
-252
lines changed

14 files changed

+748
-252
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.19.6"
3+
version = "0.20.0"
44
description = """\
55
SMASHED is a toolkit designed to apply transformations to samples in \
66
datasets, such as fields extraction, tokenization, prompting, batching, \
@@ -97,6 +97,7 @@ dev = [
9797
"ipdb>=0.13.0",
9898
"flake8-pyi>=22.8.1",
9999
"Flake8-pyproject>=1.1.0",
100+
"moto[ec2,s3,all] >= 4.0.0",
100101
]
101102
remote = [
102103
"smart-open>=5.2.1",

src/smashed/utils/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
from .caching import get_cache_dir
22
from .convert import bytes_from_int, int_from_bytes
3+
from .io_utils import (
4+
MultiPath,
5+
open_file_for_read,
6+
open_file_for_write,
7+
recursively_list_files,
8+
)
39
from .version import get_name, get_name_and_version, get_version
410
from .warnings import SmashedWarnings
511
from .wordsplitter import BlingFireSplitter, WhitespaceSplitter
@@ -12,6 +18,10 @@
1218
"get_name",
1319
"get_version",
1420
"int_from_bytes",
21+
"MultiPath",
22+
"open_file_for_read",
23+
"open_file_for_write",
24+
"recursively_list_files",
1525
"SmashedWarnings",
1626
"WhitespaceSplitter",
1727
]

src/smashed/utils/install_blingfire_macos.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#! /usr/bin/env python3
22

3+
4+
import platform
35
from subprocess import call
6+
from warnings import warn
47

58
BASH_SCRIPT = """
69
#! /usr/bin/env bash
@@ -39,7 +42,17 @@
3942

4043

4144
def main():
    """Run the bash script stored in `BASH_SCRIPT`, only on arm64 macOS.

    The script is skipped (with a warning) when the operating system is not
    macOS or when the machine architecture is not arm64.

    Returns:
        The value returned by `subprocess.call` for the script, or `None`
        when one of the platform checks fails and the script is skipped.
    """
    # Probes are stored as callables so that the architecture is only
    # queried once the operating-system check has passed.
    platform_checks = (
        (
            platform.system,
            "Darwin",
            "This script is only meant to be run on MacOS; skipping...",
        ),
        (
            platform.machine,
            "arm64",
            "This script is only meant to be run on arm64; skipping...",
        ),
    )

    for probe, expected, message in platform_checks:
        if probe() != expected:
            warn(message)
            return None

    install_script = BASH_SCRIPT.strip()
    return call(install_script, shell=True)
4356

4457

4558
if __name__ == "__main__":
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from .closures import upload_on_success
2+
from .compression import compress_stream, decompress_stream
3+
from .multipath import MultiPath
4+
from .operations import (
5+
copy_directory,
6+
exists,
7+
is_dir,
8+
is_file,
9+
open_file_for_read,
10+
open_file_for_write,
11+
recursively_list_files,
12+
remove_directory,
13+
remove_file,
14+
stream_file_for_read,
15+
)
16+
17+
__all__ = [
18+
"compress_stream",
19+
"copy_directory",
20+
"decompress_stream",
21+
"exists",
22+
"is_dir",
23+
"is_file",
24+
"MultiPath",
25+
"open_file_for_read",
26+
"open_file_for_write",
27+
"recursively_list_files",
28+
"remove_directory",
29+
"remove_file",
30+
"stream_file_for_read",
31+
"upload_on_success",
32+
]
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from contextlib import AbstractContextManager, ExitStack
2+
from functools import partial
3+
from tempfile import TemporaryDirectory
4+
from typing import Callable, Optional, TypeVar
5+
6+
from typing_extensions import Concatenate, ParamSpec
7+
8+
from .multipath import MultiPath
9+
from .operations import PathType, copy_directory, remove_directory
10+
11+
T = TypeVar("T")
12+
P = ParamSpec("P")
13+
14+
15+
class upload_on_success(AbstractContextManager):
    """Context manager to upload a directory of results to a remote
    location if the execution in the context manager is successful.

    You can use this class in two ways:

    1. As a context manager

        ```python
        with upload_on_success('s3://my-bucket/my-results') as path:
            # run training, save temporary results in `path`
            ...
        ```

    2. As a function decorator

        ```python
        @upload_on_success('s3://my-bucket/my-results')
        def my_function(path: str, ...):
            # run training, save temporary results in `path`
            ...
        ```

       The decorated function receives the local path (as a string) as its
       first positional argument; every other positional or keyword
       argument given by the caller is forwarded unchanged.

    You can specify a local destination by passing `local_path` to
    `upload_on_success`. Otherwise, a temporary directory is created for you.
    """

    def __init__(
        self,
        remote_path: PathType,
        local_path: Optional[PathType] = None,
        keep_local: bool = False,
    ) -> None:
        """Constructor for upload_on_success context manager

        Args:
            remote_path (str or urllib.parse.ParseResult): The remote location
                to upload to (e.g., an S3 prefix for a bucket you have
                access to).
            local_path (str or Path): The local path where to temporarily
                store files before upload. If not provided, a temporary
                directory is created for you and returned by the context
                manager. It will be deleted at the end of the context
                (unless keep_local is set to True). Defaults to None
            keep_local (bool, optional): Whether to keep the local results
                as well as uploading to the remote path. Only available
                if `local_path` is provided.

        Raises:
            ValueError: If `keep_local` is True but `local_path` is None.
        """

        # Validate first, so that an invalid combination of arguments does
        # not create a temporary directory that nobody will clean up.
        if local_path is None and keep_local:
            raise ValueError(
                "Cannot keep local destination if `local_path` is None"
            )

        # The exit stack owns the temporary directory (if one is needed);
        # it is closed in `__exit__`.
        self._ctx = ExitStack()
        self.remote_path = remote_path
        self.local_path = MultiPath.parse(
            local_path or self._ctx.enter_context(TemporaryDirectory())
        )
        self.keep_local = keep_local

        super().__init__()

    def _decorated(
        self,
        func: Callable[Concatenate[str, P], T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T:
        """Call `func` inside a fresh context built from this instance.

        A new manager with the same remote path, local path and
        `keep_local` flag is entered for every call; `func` receives
        `path.as_str` as its first argument, followed by `args`/`kwargs`.
        """
        with type(self)(
            local_path=self.local_path,
            remote_path=self.remote_path,
            keep_local=self.keep_local,
        ) as path:
            output = func(path.as_str, *args, **kwargs)
        return output

    def __call__(
        self, func: Callable[Concatenate[str, P], T]
    ) -> Callable[P, T]:
        """Decorate `func` so that it runs inside this context manager."""
        # `func` must be bound positionally: binding it by keyword makes any
        # positional argument passed to the decorated function collide with
        # the `func` parameter of `_decorated` and raise a TypeError.
        return partial(self._decorated, func)  # type: ignore

    def __enter__(self):
        """Return the local path in which results should be written."""
        return self.local_path

    def __exit__(self, exc_type, exc_value, traceback):
        """Upload on success, then clean up the local copy.

        Nothing is returned, so an exception raised inside the context is
        never suppressed.
        """
        if exc_type is None:
            # all went well, so we copy the local directory to the remote
            copy_directory(src=self.local_path, dst=self.remote_path)

        if not self.keep_local:
            remove_directory(self.local_path)

        # releases the temporary directory, if one was created in __init__
        self._ctx.close()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import gzip as gz
2+
import io
3+
from contextlib import contextmanager
4+
from typing import IO, Iterator, Literal, Optional, cast
5+
6+
from .io_wrappers import BytesZLibDecompressorIO, TextZLibDecompressorIO
7+
8+
9+
@contextmanager
def decompress_stream(
    stream: IO,
    mode: Literal["r", "rt", "rb"] = "rt",
    encoding: Optional[str] = "utf-8",
    errors: str = "strict",
    chunk_size: int = io.DEFAULT_BUFFER_SIZE,
    gzip: bool = True,
) -> Iterator[IO]:
    """Wrap `stream` in a decompressing reader for the duration of a `with`.

    Args:
        stream: The compressed stream to read from.
        mode: "r" or "rb" to obtain a bytes reader, "rt" to obtain a text
            reader.
        encoding: Text encoding; required (not None) when mode is "rt".
        errors: Error handling scheme forwarded to the text reader.
        chunk_size: Forwarded to the underlying decompressor wrapper.
        gzip: Forwarded to the underlying decompressor wrapper
            (presumably selects gzip rather than raw zlib framing;
            confirm against `io_wrappers`).

    Yields:
        The decompressing wrapper; it is closed when the context exits,
        whether or not the body raised.

    Raises:
        ValueError: If `mode` is not one of the supported values.
    """
    out: io.IOBase

    if mode == "rb" or mode == "r":
        out = BytesZLibDecompressorIO(
            stream=stream, chunk_size=chunk_size, gzip=gzip
        )
    elif mode == "rt":
        assert encoding is not None, "encoding must be provided for text mode"
        out = TextZLibDecompressorIO(
            stream=stream,
            chunk_size=chunk_size,
            gzip=gzip,
            encoding=encoding,
            errors=errors,
        )
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    try:
        # cast to IO to satisfy mypy, then yield
        yield cast(IO, out)
    finally:
        # Close the wrapper even if the body of the `with` block raised;
        # without the `finally`, an exception would skip the close.
        out.close()
41+
42+
43+
@contextmanager
def compress_stream(
    stream: IO,
    mode: Literal["w", "wt", "wb"] = "wt",
    encoding: Optional[str] = "utf-8",
    errors: str = "strict",
    gzip: bool = True,
) -> Iterator[IO]:
    """Wrap `stream` in a gzip-compressing writer for the duration of a `with`.

    Args:
        stream: The destination stream that receives the compressed bytes.
        mode: "w" or "wb" to obtain a bytes writer, "wt" to obtain a text
            writer.
        encoding: Text encoding; required (not None) when mode is "wt".
        errors: Error handling scheme used when encoding text.
        gzip: Must be True; only gzip compression is supported.

    Yields:
        The compressing writer returned by `gzip.open`; it is closed when
        the context exits, whether or not the body raised.

    Raises:
        ValueError: If `mode` is not one of the supported values.
    """

    assert gzip, "Only gzip compression is supported at this time"

    if mode == "wb" or mode == "w":
        out = gz.open(stream, mode=mode)
    elif mode == "wt":
        assert encoding is not None, "encoding must be provided for text mode"
        out = gz.open(stream, mode=mode, encoding=encoding, errors=errors)
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    try:
        # cast to IO to satisfy mypy, then yield
        yield cast(IO, out)
    finally:
        # Flush and close the writer even if the body of the `with` block
        # raised; otherwise the gzip trailer is never written and the
        # writer is left open.
        out.close()

0 commit comments

Comments
 (0)