
Commit bdbc617

Soldni/stride fix (#39)
* fix on how last stride is computed
* fixed small issue with collator in case ignore does not exist
* improvements to collator
* left padding support
* rename support
* pretty
* tests for tokenizer
* fixed issues with datasets 2.8.0 refactor
* necessary library syntax fix
1 parent 201e28d commit bdbc617

14 files changed: +334, -66 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "smashed"
-version = "0.12.0"
+version = "0.13.0"
 description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
 authors = [
     {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" },
```

src/smashed/base/interfaces.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -28,7 +28,15 @@

 with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
     if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
-        from datasets.arrow_dataset import Batch, Dataset
+        from datasets.arrow_dataset import Dataset
+
+        try:
+            from datasets.formatting.formatting import LazyBatch
+        except ImportError:
+            # pre datasets 2.8.0
+            from datasets.arrow_dataset import (
+                Batch as LazyBatch,  # pyright: ignore
+            )
         from datasets.iterable_dataset import IterableDataset

 HuggingFaceDataset = TypeVar(
@@ -284,12 +292,12 @@ def _map_huggingface_dataset(
         else:
             return transformed_dataset

-    @map.add_interface(dataset=Batch)
+    @map.add_interface(dataset=LazyBatch)
     def _map_huggingface_dataset_batch(
         self,
-        dataset: Batch,
+        dataset: LazyBatch,
         **map_kwargs: Any,
-    ) -> Batch:
+    ) -> LazyBatch:
         # explicitly casting to a boolean since this is all that is
         # supported by the simple mapper.
         # TODO[lucas]: maybe support specifying which fields to keep?
@@ -298,7 +306,7 @@ def _map_huggingface_dataset_batch(
             or self.always_remove_columns
         )

-        dtview: DataBatchView[Batch, str, Any] = DataBatchView(dataset)
+        dtview: DataBatchView[LazyBatch, str, Any] = DataBatchView(dataset)

         self._check_fields_datasets(
             provided_fields=dataset.keys(),
```
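
The import shim above is the crux of the datasets 2.8.0 fix: the batch wrapper moved from `datasets.arrow_dataset.Batch` to `datasets.formatting.formatting.LazyBatch`, so the code tries the new location first and falls back to the old name. A minimal standalone sketch of the pattern, including the `necessary` soft-import guard used throughout smashed (both taken directly from the diff):

```python
from typing import TYPE_CHECKING

from necessary import necessary

# With soft=True, `necessary` yields an availability flag instead of
# raising when the "datasets" package is not installed.
with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
    if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
        try:
            # datasets >= 2.8.0: batch wrapper lives in the formatting module
            from datasets.formatting.formatting import LazyBatch
        except ImportError:
            # pre datasets 2.8.0: same object under its old name
            from datasets.arrow_dataset import Batch as LazyBatch
```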

src/smashed/mappers/cache.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -13,7 +13,15 @@

 with necessary("datasets", soft=True) as HUGGINGFACE_DATASET_AVAILABLE:
     if HUGGINGFACE_DATASET_AVAILABLE or TYPE_CHECKING:
-        from datasets.arrow_dataset import Dataset, Batch
+        from datasets.arrow_dataset import Dataset
+
+        try:
+            from datasets.formatting.formatting import LazyBatch
+        except ImportError:
+            # pre datasets 2.8.0
+            from datasets.arrow_dataset import (
+                Batch as LazyBatch,  # pyright: ignore
+            )
         from datasets.iterable_dataset import IterableDataset
         from datasets.fingerprint import disable_caching, enable_caching

@@ -130,8 +138,8 @@ def get_dataset_fingerprint_hf_iterable(
         )
         return h.hexdigest()

-    @get_dataset_fingerprint.add_interface(dataset=Batch)
-    def get_dataset_fingerprint_hf_batch(self, dataset: Batch) -> str:
+    @get_dataset_fingerprint.add_interface(dataset=LazyBatch)
+    def get_dataset_fingerprint_hf_batch(self, dataset: LazyBatch) -> str:
         raise ValueError(
             "Cannot cache a Batch of a HuggingFace Dataset; please "
             "cache at the Dataset level instead."
@@ -198,7 +206,7 @@ def _save_hf_it(self, dataset: IterableDataset, path: Path):
             "Saving an IterableDataset is not implemented yet"
         )

-    @save_cache.add_interface(dataset=Batch)
+    @save_cache.add_interface(dataset=LazyBatch)
     def _save_hf_batch(self, dataset: Dataset, path: Path):
         raise ValueError(
             "Cannot cache a Batch of a HuggingFace Dataset; please "
@@ -274,7 +282,9 @@ def _load_list(

     if HUGGINGFACE_DATASET_AVAILABLE:

-        @load_cache.add_interface(dataset=(IterableDataset, Dataset, Batch))
+        @load_cache.add_interface(
+            dataset=(IterableDataset, Dataset, LazyBatch)
+        )
         def _load_hf(
             self,
             path: Path,
```
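
Same rename, applied to the caching interfaces. The `add_interface` decorator dispatches each method on the runtime type of `dataset`; a rough analogue using the standard library's `functools.singledispatch` (not smashed's actual implementation) illustrates the dispatch-by-type idea:

```python
from functools import singledispatch

# Rough analogue of @load_cache.add_interface: pick an implementation
# based on the runtime type of the dataset argument.
@singledispatch
def load_cache(dataset, path):
    raise NotImplementedError(f"no cache loader for {type(dataset).__name__}")

@load_cache.register(list)
def _load_list(dataset, path):
    # load a cached list-of-dicts dataset from `path`
    ...

@load_cache.register(dict)
def _load_dict(dataset, path):
    # load a cached columnar (dict-of-lists) dataset from `path`
    ...
```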

src/smashed/mappers/collators.py

Lines changed: 42 additions & 4 deletions
```diff
@@ -23,6 +23,7 @@ def __init__(
         pad_to_length: Optional[Union[int, Sequence[int]]] = None,
         fields_pad_ids: Optional[Mapping[str, int]] = None,
         unk_fields_pad_id: Optional[int] = None,
+        left_pad_fields: Optional[Sequence[str]] = None,
     ):
         """Create a collator.

@@ -41,10 +42,14 @@ def __init__(
             unk_fields_pad_id (int, optional): The padding value to use for
                 any field that is not in fields_pad_ids. If not provided, an
                 error will be raised if a field is not in fields_pad_ids.
+            left_pad_fields (Sequence[str], optional): A list of fields to
+                pad from the left instead of the right. By default, all fields
+                are padded from the right.
         """
         self.fields_pad_ids = fields_pad_ids or {}
         self.pad_to_length = pad_to_length
         self.unk_fields_pad_id = unk_fields_pad_id
+        self.left_pad_fields = set(left_pad_fields or [])

         if self.unk_fields_pad_id is None and self.fields_pad_ids is None:
             raise ValueError(
@@ -145,7 +150,25 @@ def _pad(
         pad_value: int,
         dim: int = 0,
         pad_to_length: Optional[Union[int, Sequence[int]]] = None,
+        right_pad: bool = True,
     ) -> torch.Tensor:
+        """Pad a sequence of tensors to the same length.
+
+        Args:
+            sequence (Sequence[torch.Tensor]): The sequence of tensors to pad.
+                It is assumed that all tensors in the sequence have the same
+                type; if not an error might be raised somewhere.
+            pad_value (int): The value to use for padding.
+            dim (int, optional): The dimension we are collating on. Defaults
+                to 0.
+            pad_to_length (Union[int, Sequence[int]], optional): If provided,
+                pad all sequences to this length. If provided as a sequence,
+                we assume we should pad each dimension to the corresponding
+                length. If None, sequences will be padded to the length of the
+                longest sequence. Defaults to None.
+            right_pad (bool, optional): If True, pad to the right. If False,
+                pad to the left. Defaults to True.
+        """

         # make sure type of input is right
         if not (
@@ -192,13 +215,14 @@ def _pad(
         pad_shapes = tuple(
             tuple(
                 chain.from_iterable(
-                    (0, m - s)
+                    (0, m - s) if right_pad else (m - s, 0)
                     for s, m in zip(t.size()[::-1], max_lengths[::-1])
                 )
             )
             # we do padding shapes for each tensor
             for t in sequence
         )
+
         # call each pad on each of the tensors with the appropriate padding
         to_stack = tuple(
             torch.nn.functional.pad(
@@ -218,6 +242,7 @@ def transform(  # type: ignore
                 sequence=list_of_tensors,
                 pad_value=self._get_padding_value(field_name=field_name),
                 pad_to_length=self.pad_to_length,
+                right_pad=(field_name not in self.left_pad_fields),
             )
             for field_name, list_of_tensors in data.items()
         }
@@ -270,23 +295,29 @@ def _get_list_shape_recursive(
         # this iterator will yield the shape of each element in the sequence
         inner_dims = (self._get_list_shape_recursive(s) for s in sequence)

-        # the acutal shape is the maximum of the inner dims
+        # the actual shape is the maximum of the inner dims
         inner_shape = tuple(max(dims) for dims in zip(*inner_dims))

         return (len(sequence), *inner_shape)

     def _pad_recursive(
-        self, sequence: List[Any], shape: Sequence[int], padding_symbol: Any
+        self,
+        sequence: List[Any],
+        shape: Sequence[int],
+        padding_symbol: Any,
+        pad_right: bool = True,
     ) -> List[Any]:
         """Recursively pads a list of [lists, ...].

         Args:
             sequence (List[Any]): The list to pad.
             shape (Sequence[int]): The shape to pad to.
             padding_symbol (Any): The symbol to pad with.
+            pad_right (bool, optional): If True, pads to the right. If False,
+                pads to the left. Defaults to True.

         Returns:
-            List[Any]: _description_
+            List[Any]: The padded list.
         """

         if len(shape) < 2:
@@ -321,7 +352,11 @@ def _pad_recursive(
         #
         # We do that in the following line:
         sequence_with_brand_new_padding = (
+            # the side we pad depends on whether pad_right is True or False
             sub_seq + [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq))
+            if pad_right
+            else [nested_pad_symbol] * (dim_to_pad_shape - len(sub_seq))
+            + sub_seq
             for sub_seq in sequence
         )

@@ -342,6 +377,7 @@ def _pad(
         self: "ListCollatorMapper",
         seq_of_seq_to_pad: List[Any],
         padding_symbol: Any,
+        pad_right: bool = True,
     ) -> List[Any]:

         padding_shape = self._get_list_shape_recursive(seq_of_seq_to_pad)
@@ -367,6 +403,7 @@ def _pad(
             sequence=seq_of_seq_to_pad,
             shape=padding_shape,
             padding_symbol=padding_symbol,
+            pad_right=pad_right,
         )
         return padded_sequence

@@ -377,6 +414,7 @@ def transform(self, data: TransformElementType) -> TransformElementType:
             field_name: self._pad(
                 seq_of_seq_to_pad=field_value,
                 padding_symbol=self._get_padding_value(field_name=field_name),
+                pad_right=(field_name not in self.left_pad_fields),
             )
             for field_name, field_value in data.items()
         }
```
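
The heart of the left-padding support is the pad-shape tuple handed to `torch.nn.functional.pad`, which takes (left, right) amounts per dimension, last dimension first: `(0, m - s)` appends padding, `(m - s, 0)` prepends it. A small sketch of the behavior for 1-D tensors (`pad_to` is an illustrative helper, not part of the commit):

```python
import torch
import torch.nn.functional as F

def pad_to(
    t: torch.Tensor, length: int, pad_value: int, right_pad: bool = True
) -> torch.Tensor:
    missing = length - t.size(-1)
    # (0, missing) pads on the right; (missing, 0) pads on the left
    pad_shape = (0, missing) if right_pad else (missing, 0)
    return F.pad(t, pad_shape, value=pad_value)

seq = torch.tensor([1, 2, 3])
print(pad_to(seq, 5, 0))                   # tensor([1, 2, 3, 0, 0])
print(pad_to(seq, 5, 0, right_pad=False))  # tensor([0, 0, 1, 2, 3])
```

Left padding is useful for decoder-only models, where generation continues from the right edge of the batch, so real tokens must be flush against it.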

src/smashed/mappers/fields.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -27,13 +27,16 @@ def __init__(
         self,
         keep_fields: Optional[List[str]] = None,
         drop_fields: Optional[List[str]] = None,
+        raise_on_missing: bool = True,
     ):
         """
         Args:
             keep_fields (List[str]): Fields to keep, all other fields
                 are dropped. Defaults to [].
             drop_fields (List[str]): Fields to drop, all other fields
                 are kept. Defaults to [].
+            raise_on_missing (bool): Whether to raise an error if a field
+                is missing. Defaults to True.
         """

         # xor between keep_fields and remove_fields
@@ -42,16 +45,22 @@ def __init__(
         ):
             raise ValueError("Must specify `keep_fields` or `drop_fields`")

-        super().__init__(input_fields=drop_fields, output_fields=keep_fields)
+        self.keep_fields = dict.fromkeys(keep_fields) if keep_fields else None
+        self.drop_fields = dict.fromkeys(drop_fields) if drop_fields else None
+
+        super().__init__(
+            input_fields=drop_fields if raise_on_missing else None,
+            output_fields=keep_fields if raise_on_missing else None,
+        )

     def transform(self, data: TransformElementType) -> TransformElementType:
-        if self.input_fields:
+        if self.drop_fields:
             new_data = {
-                k: v for k, v in data.items() if k not in self.input_fields
+                k: v for k, v in data.items() if k not in self.drop_fields
             }

-        elif self.output_fields:
-            new_data = {k: data[k] for k in self.output_fields}
+        elif self.keep_fields:
+            new_data = {k: data[k] for k in data if k in self.keep_fields}

         else:
             raise ValueError("Must specify `keep_fields` or `drop_fields`")
```
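
Two behavioral notes on this change: `dict.fromkeys` builds an order-preserving lookup for the field names, and the `keep_fields` branch now iterates over the data rather than the field list, so a missing field is skipped instead of raising a `KeyError` when `raise_on_missing` disables the upfront check. A standalone sketch with made-up field names:

```python
# Order-preserving "set" of fields to keep
keep_fields = dict.fromkeys(["input_ids", "labels"])

data = {"input_ids": [1, 2], "attention_mask": [1, 1]}  # no "labels"

# Old style, {k: data[k] for k in keep_fields}, raises KeyError on "labels".
# The new style iterates over the data itself, so absent fields are skipped:
new_data = {k: data[k] for k in data if k in keep_fields}
print(new_data)  # {'input_ids': [1, 2]}
```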

src/smashed/mappers/glom.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -8,7 +8,13 @@

 with necessary("datasets", soft=True) as DATASETS_AVAILABLE:
     if DATASETS_AVAILABLE:
-        from datasets.arrow_dataset import Example
+        try:
+            from datasets.formatting.formatting import LazyRow
+        except ImportError:
+            # pre datasets 2.8.0
+            from datasets.arrow_dataset import (
+                Example as LazyRow,  # pyright: ignore
+            )


 class ExtendGlommerMixin:
@@ -26,10 +32,10 @@ def glommer(self) -> glom.Glommer:

         if DATASETS_AVAILABLE:
             glommer.register(
-                target_type=Example,
-                get=Example.__getitem__,
-                iter=Example.__iter__,
-                exact=Example.__eq__,
+                target_type=LazyRow,
+                get=LazyRow.__getitem__,
+                iter=LazyRow.__iter__,
+                exact=LazyRow.__eq__,
             )

         glommer.register(
```
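
glom can only traverse types it has handlers for, so the mixin registers `LazyRow` with accessors that make it behave like a mapping. A sketch of the same registration pattern with a stand-in wrapper type (`Row` is invented here; the `Glommer` registration call mirrors the diff, and this assumes `Glommer` instances expose a `glom()` method as in glom's custom-registry support):

```python
import glom

class Row:
    """Mapping-like wrapper, standing in for datasets' LazyRow."""

    def __init__(self, data: dict):
        self._data = data

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __eq__(self, other):
        return isinstance(other, Row) and self._data == other._data

glommer = glom.Glommer()
glommer.register(
    target_type=Row,
    get=Row.__getitem__,
    iter=Row.__iter__,
    exact=Row.__eq__,
)

# "doc.title" first traverses the registered Row, then a plain dict
print(glommer.glom(Row({"doc": {"title": "smashed"}}), "doc.title"))
```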

src/smashed/mappers/multiseq.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -396,6 +396,7 @@ def transform(
         ) >= self.max_stride_count

         if stride_too_long or stride_has_too_many_seqs:
+
             yield {
                 k: (
                     # if a list of fields to strides has been provided,
@@ -423,7 +424,21 @@ def transform(
             cumulative_stride_length += current_seq_length

         # yield the last sequence
-        out = {k: v[seq_pos_start:] for k, v in sample.items()}
+        out = {
+            k: (
+                # same logic as above: if a list of fields to strides
+                # has been provided, then only stride this field if it
+                # is in the list and duplicate if it is not; if no list
+                # of fields to stride has been provided, then stride all.
+                v[seq_pos_start:]
+                if (
+                    self.fields_to_stride is None
+                    or k in self.fields_to_stride
+                )
+                else v
+            )
+            for k, v in sample.items()
+        }

         yield out

@@ -434,7 +449,7 @@ def __init__(
         single_value_field: str,
         like_field: str = "input_ids",
         strategy: Literal["first", "last", "all"] = "first",
-        padding_id: Union[int, float] = -100,
+        padding_id: Any = -100,
     ) -> None:
         """Mapper to create a sequence of values from single value.
         Useful when casting a sequence classification task to a sequence
```
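
This is the stride fix the commit is named for: the final chunk now applies the same `fields_to_stride` filter as the chunks emitted inside the loop, so fields excluded from striding are carried over whole instead of being sliced. A toy reproduction of the new behavior (`last_stride` and the sample contents are invented for illustration):

```python
from typing import Any, Dict, Optional, Sequence

def last_stride(
    sample: Dict[str, Any],
    seq_pos_start: int,
    fields_to_stride: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    # Slice only fields that are meant to be strided; copy the rest as-is.
    # Before the fix, every field was sliced unconditionally.
    return {
        k: (
            v[seq_pos_start:]
            if (fields_to_stride is None or k in fields_to_stride)
            else v
        )
        for k, v in sample.items()
    }

sample = {"input_ids": [[1], [2], [3]], "label": 0}
print(last_stride(sample, 2, fields_to_stride=["input_ids"]))
# {'input_ids': [[3]], 'label': 0} -- "label" is carried whole; slicing
# the int would have raised a TypeError under the old behavior
```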
