Skip to content

Commit 2ab965f

Browse files
authored
Added TruncateMultipleNestedFieldsMapper (#50)
* added TruncateMultipleNestedFieldsMapper * accidentally commited .dmypy * formatting * lowered requirements to 3.8 * added new ref to mapper * style
1 parent d63dfc3 commit 2ab965f

File tree

5 files changed

+212
-6
lines changed

5 files changed

+212
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
[project]
22
name = "smashed"
3-
version = "0.15.5"
3+
version = "0.16.0"
44
description = """\
55
SMASHED is a toolkit designed to apply transformations to samples in \
66
datasets, such as fields extraction, tokenization, prompting, batching, \
77
and more. Supports datasets from Huggingface, torchdata iterables, or \
88
simple lists of dictionaries.\
99
"""
10-
# authors = [
11-
# {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"},
12-
# {name = "Luca Soldaini", email = "luca@soldaini.net"}
13-
# ]
1410
license = {text = "Apache-2.0"}
1511
readme = "README.md"
16-
requires-python = ">=3.9"
12+
requires-python = ">=3.8"
1713
dependencies = [
1814
"torch>=1.9",
1915
"transformers>=4.5",

src/smashed/mappers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
FillEncodedPromptMapper,
4444
FillTextPromptMapper,
4545
TruncateMultipleFieldsMapper,
46+
TruncateMultipleNestedFieldsMapper,
4647
)
4748
from .promptsource import FewShotJinjaMapper, JinjaMapper, PromptsourceMapper
4849
from .shape import (
@@ -112,6 +113,7 @@
112113
"TokenTypeIdsSequencePaddingMapper",
113114
"Torch2PythonMapper",
114115
"TruncateMultipleFieldsMapper",
116+
"TruncateMultipleNestedFieldsMapper",
115117
"TruncateSingleFieldMapper",
116118
"UnpackingMapper",
117119
"ValidUnicodeMapper",

src/smashed/mappers/prompting.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
99

1010
from ..base import SingleBaseMapper, TransformElementType
11+
from ..utils.shape_utils import flatten_with_indices, reconstruct_from_indices
1112
from .tokenize import GetTokenizerOutputFieldsAndNamesMixIn
1213

1314
__all__ = [
1415
"EncodeFieldsMapper",
1516
"FillEncodedPromptMapper",
1617
"FillTextPromptMapper",
1718
"TruncateMultipleFieldsMapper",
19+
"TruncateMultipleNestedFieldsMapper",
1820
]
1921

2022

@@ -291,6 +293,31 @@ def transform(self, data: TransformElementType) -> TransformElementType:
291293
return output
292294

293295

296+
class TruncateMultipleNestedFieldsMapper(TruncateMultipleFieldsMapper):
297+
"""Like TruncateMultipleFieldsMapper, but works on nested fields."""
298+
299+
def transform(self, data: TransformElementType) -> TransformElementType:
300+
# gather fields to truncate in flatted_data, keep track of
301+
# the indices of the fields in flatted_index
302+
flatted_index: dict = {}
303+
flatted_data: dict = {}
304+
305+
for k in self.input_fields:
306+
flatted_data[k], flatted_index[k] = flatten_with_indices(data[k])
307+
308+
flatted_output = super().transform(flatted_data)
309+
310+
output = {
311+
k: (
312+
reconstruct_from_indices(flatted_output[k], flatted_index[k])
313+
if k in flatted_output
314+
else data[k]
315+
)
316+
for k in data
317+
}
318+
return output
319+
320+
294321
@dataclass
295322
class PromptSegment:
296323
"""Class to represent a segment of a prompt. Not meant to be used

src/smashed/utils/shape_utils.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from collections.abc import Sequence as SequenceABC
2+
from typing import Any, List, Sequence, Tuple, TypeVar, Union, cast
3+
4+
from typing_extensions import TypeAlias
5+
6+
T = TypeVar("T")
7+
8+
LocTupleType: TypeAlias = Tuple[int, int]
9+
KeysType: TypeAlias = Union[LocTupleType, List["KeysType"]]
10+
NestedSequenceType: TypeAlias = Union[
11+
Sequence[T], Sequence["NestedSequenceType[T]"]
12+
]
13+
NestedListType: TypeAlias = Union[List[T], List["NestedListType[T]"]]
14+
15+
16+
def is_sequence_but_not_str(obj: Any) -> bool:
17+
"""Check if an object is a sequence but not a string."""
18+
return isinstance(obj, SequenceABC) and not isinstance(obj, (str, bytes))
19+
20+
21+
def flatten_with_indices(
22+
sequence: NestedSequenceType[T], __offset: int = 0
23+
) -> Tuple[List[T], Union[KeysType, None]]:
24+
"""Recursively flatten an iterable of iterables, returning both the
25+
flatten list, as well as the indices of the original list.
26+
27+
Args:
28+
sequence (NestedSequenceType[T]): Either a sequence or a sequence
29+
of sequences; if a sequence of sequences, will be flattened.
30+
__offset (int, optional): Internal offset to keep track of the
31+
position in the flattened list. Defaults to 0; DO NOT CHANGE.
32+
33+
Raises:
34+
ValueError: If the sequence contains both sequences and
35+
non-sequences.
36+
37+
Returns:
38+
List[T]: The flattened list; if the original list was not nested,
39+
will be the same as the original list.
40+
Union[KeysType, None]: The indices of the original list; if the
41+
original list was not nested, will be None.
42+
"""
43+
44+
it = iter(sequence)
45+
flattened: list = []
46+
keys: list = []
47+
is_nested_sequence = is_already_flat = False
48+
49+
while True:
50+
try:
51+
item = next(it)
52+
except StopIteration:
53+
break
54+
55+
if is_sequence_but_not_str(item):
56+
if is_already_flat:
57+
raise ValueError(
58+
"Cannot mix sequences and non-sequences when flattening."
59+
)
60+
is_nested_sequence = True
61+
62+
offset = len(flattened) + __offset
63+
# manual casting bc we know this is a sequence (see function
64+
# is_sequence_but_not_str) but if we don't cast mypy is going
65+
# to complain.
66+
item = cast(NestedSequenceType[T], item)
67+
68+
# must use type: ignore here because mypy doesn't like using
69+
# the __offset kwarg (which is a good idea in general but
70+
# we nee to use it during recursive calls)
71+
sub_flattened, sub_keys = flatten_with_indices( # type: ignore
72+
sequence=item, __offset=offset
73+
)
74+
75+
if sub_keys is None:
76+
sub_keys = (offset, offset + len(sub_flattened))
77+
78+
keys.append(sub_keys)
79+
flattened.extend(sub_flattened)
80+
else:
81+
if is_nested_sequence:
82+
raise ValueError(
83+
"Cannot mix sequences and non-sequences when flattening."
84+
)
85+
is_already_flat = True
86+
87+
flattened.append(item)
88+
89+
return flattened, (keys or None)
90+
91+
92+
def reconstruct_from_indices(
93+
flattened: List[T], keys: Union[KeysType, None]
94+
) -> NestedListType[T]:
95+
"""Recursively reconstruct a list from a flattened list and the keys that
96+
were returned from recursively_flatten_with_indices.
97+
98+
Args:
99+
flattened (List[T]): A flat list of items.
100+
101+
"""
102+
103+
if keys is None:
104+
return flattened
105+
106+
reconstructed: list = []
107+
for key in keys:
108+
if isinstance(key, list):
109+
reconstructed.append(reconstruct_from_indices(flattened, key))
110+
elif isinstance(key, tuple):
111+
start, end = key
112+
reconstructed.append(flattened[start:end])
113+
else:
114+
raise ValueError(
115+
f"Invalid key type: expected tuple or list, got {type(key)}"
116+
)
117+
118+
return reconstructed

tests/test_shape_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
3+
from smashed.utils.shape_utils import (
4+
flatten_with_indices,
5+
reconstruct_from_indices,
6+
)
7+
8+
9+
class TestFlatten(unittest.TestCase):
10+
def test_flatten(self):
11+
li = [
12+
[0, 1, 2, 3],
13+
["4", "5"],
14+
[6, 7],
15+
["8"],
16+
[9.0, 10.0, 11.0, 12.0, 13.0],
17+
[],
18+
[14, 15, 16],
19+
[17, 18, 19, "20"],
20+
[21, "22"],
21+
[""],
22+
[23, 24, 25, 26, 27, 28, 29, "30"],
23+
]
24+
25+
fl, idx = flatten_with_indices(li)
26+
new_li = reconstruct_from_indices(fl, idx)
27+
28+
self.assertEqual(li, new_li)
29+
30+
def test_deeply_nested(self):
31+
# a nested 4-deep nested list
32+
li = [
33+
[[[0, 1, 2, 3], ["4", "5"]], [[6, 7], ["8"]]],
34+
[
35+
[[9.0, 10.0, 11.0, 12.0, 13.0], []],
36+
[[14, 15, 16], [17, 18, 19, "20"], [21, "22"], [""]],
37+
[[23, 24, 25, 26, 27, 28, 29, "30"]],
38+
],
39+
]
40+
41+
fl, idx = flatten_with_indices(li)
42+
new_li = reconstruct_from_indices(fl, idx)
43+
44+
self.assertEqual(li, new_li)
45+
46+
def test_empty(self):
47+
li = []
48+
fl, idx = flatten_with_indices(li)
49+
new_li = reconstruct_from_indices(fl, idx)
50+
51+
self.assertEqual(li, new_li)
52+
53+
def test_already_flat(self):
54+
li = [0, 1, 2, 3]
55+
fl, idx = flatten_with_indices(li)
56+
new_li = reconstruct_from_indices(fl, idx)
57+
58+
self.assertEqual(li, new_li)
59+
60+
def test_error_when_mixed(self):
61+
li = [0, 1, 2, 3, [4, 5, 6]]
62+
with self.assertRaises(ValueError):
63+
flatten_with_indices(li)

0 commit comments

Comments
 (0)