
Commit 67a40c0

Removed transformers dependency (#52)
* made transformers dep optional
* resolved issue with versioning
1 parent 85b9586 commit 67a40c0

File tree

9 files changed: +69 additions, -42 deletions

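Across the Python files in this commit, the change follows one pattern: the hard import of transformers is replaced by a guarded import via necessary("transformers", soft=True), and tokenizer type annotations are quoted so they are never evaluated at import time. A minimal sketch of the pattern, assuming only the necessary() usage shown in the diffs (ExampleMapper itself is hypothetical):

from necessary import necessary

# soft=True swallows the ImportError when transformers is not installed,
# so this module can still be imported without the optional dependency;
# PreTrainedTokenizerBase is simply left undefined in that case.
with necessary("transformers", soft=True):
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase


class ExampleMapper:  # hypothetical; stands in for the mappers below
    # the quoted annotation is stored as a string and never evaluated at
    # runtime, so it is safe even when the import above was skipped
    def __init__(self, tokenizer: "PreTrainedTokenizerBase") -> None:
        self.tokenizer = tokenizer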

pyproject.toml

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "smashed"
-version = "0.17.1"
+version = "0.18.0"
 description = """\
 SMASHED is a toolkit designed to apply transformations to samples in \
 datasets, such as fields extraction, tokenization, prompting, batching, \
@@ -12,8 +12,7 @@ readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
     "torch>=1.9",
-    "transformers>=4.5",
-    "necessary>=0.3.3",
+    "necessary>=0.4.1",
     "trouting>=0.3.3",
     "ftfy>=6.1.1",
     "platformdirs>=2.5.0",
@@ -104,10 +103,12 @@ remote = [
     "boto3>=1.25.5",
 ]
 datasets = [
-    "datasets>=2.8.0",
-    "dill>=0.3.0",
+    "transformers>=4.5",
+    "datasets>=2.8.0",
+    "dill>=0.3.0",
 ]
 prompting = [
+    "transformers>=4.5",
     "promptsource>=0.2.3",
     "blingfire>=0.1.8",
 ]
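The net effect: transformers leaves the core dependencies entirely and reappears under the datasets and prompting extras, so a bare install of smashed no longer pulls it in. Users who need the tokenizer-backed mappers opt in with the standard extras syntax, e.g. (illustrative command): pip install "smashed[prompting]". The necessary pin is bumped to >=0.4.1, and the project version moves to 0.18.0 to mark the new minor release.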

src/smashed/contrib/squad.py

Lines changed: 10 additions & 4 deletions
@@ -1,7 +1,7 @@
 from bisect import bisect_left, bisect_right
 from typing import Any, Literal, Optional, Sequence, Tuple, TypeVar, Union
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary, Necessary
 
 from smashed.base import BaseRecipe, SingleBaseMapper, TransformElementType
 from smashed.base.mappers import ChainableMapperMixIn
@@ -16,6 +16,10 @@
 )
 from smashed.recipes.prompting import PromptingRecipe
 
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+
 __all__ = [
     "AddEvidencesLocationMapper",
     "ConcatenateContextMapper",
@@ -317,7 +321,7 @@ def strider_mapper(self, **kwargs) -> SingleSequenceStriderMapper:
     def __init__(
         self,
         *args,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         context_field: str = "context",
         location_field: str = "locations",
         **kwargs,
@@ -331,13 +335,14 @@ def __init__(
 C = TypeVar("C", bound=ChainableMapperMixIn)
 
 
+@Necessary("transformers")
 class SquadPromptTrainRecipe(BaseRecipe):
     def unpacking(self, pipeline: C, **kwargs: Any) -> C:
         return pipeline >> UnpackingMapper(**kwargs)
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         source_template: str,
         context_length: int,
         context_stride: int,
@@ -439,11 +444,12 @@ def __init__(
         self.chain(pipeline)
 
 
+@Necessary("transformers")
 class SquadPromptValidRecipe(SquadPromptTrainRecipe):
     def __init__(
         self,
         *args,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         target_output_name: Optional[str] = None,
         answer_field: str = "answers",
         **kwargs,
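Two mechanisms are combined here: quoting the PreTrainedTokenizerBase annotations keeps them from being evaluated when transformers is absent, while the @Necessary("transformers") decorator gates the two recipe classes, which genuinely call tokenizer methods at runtime. The precise failure mode (presumably an informative ImportError when such a class is used without transformers installed) is defined by the necessary package rather than shown in this diff.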

src/smashed/mappers/collators.py

Lines changed: 6 additions & 2 deletions
@@ -4,11 +4,15 @@
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary
 
 from ..base import SingleBaseMapper, TransformElementType
 from ..base.abstract import AbstractBaseMapper
 
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+
 __all__ = [
     "ListCollatorMapper",
     "TensorCollatorMapper",
@@ -87,7 +91,7 @@ def collate(
 class FromTokenizerMixIn(BaseCollator):
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         pad_to_length: Optional[Union[int, Sequence[int]]] = None,
         fields_pad_ids: Optional[Mapping[str, int]] = None,
         unk_fields_pad_id: Optional[int] = None,

src/smashed/mappers/decoding.py

Lines changed: 7 additions & 2 deletions
@@ -6,17 +6,22 @@
 
 from typing import Any, Dict, Optional, Sequence, Union
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary
 
 from ..base import SingleBaseMapper, TransformElementType
 
+
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+
 __all__ = ["DecodingMapper"]
 
 
 class DecodingMapper(SingleBaseMapper):
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         fields: Union[str, Sequence[str]],
         decode_batch: bool = False,
         skip_special_tokens: bool = False,

src/smashed/mappers/multiseq.py

Lines changed: 9 additions & 6 deletions
@@ -12,10 +12,13 @@
     Union,
 )
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary
 
 from ..base import BatchedBaseMapper, SingleBaseMapper, TransformElementType
 
+with necessary('transformers', soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 
 class TokensSequencesPaddingMapper(SingleBaseMapper):
     bos: List[int]
@@ -24,7 +27,7 @@ class TokensSequencesPaddingMapper(SingleBaseMapper):
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         input_field: str = "input_ids",
     ) -> None:
         """Mapper that add BOS/SEP/EOS sequences of tokens.
@@ -42,7 +45,7 @@ def __init__(
 
     @staticmethod
     def _find_special_token_ids(
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
     ) -> Tuple[List[int], List[int], List[int]]:
         """By default, tokenizers only know how to concatenate 2 fields
         as input; However, for our purposes, we might care about more than
@@ -99,7 +102,7 @@ def transform(self, data: TransformElementType) -> TransformElementType:
 class AttentionMaskSequencePaddingMapper(TokensSequencesPaddingMapper):
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         input_field: str = "attention_mask",
     ) -> None:
         """Mapper to add BOS/SEP/EOS tokens to an attention mask sequence.
@@ -121,7 +124,7 @@ def __init__(
 class TokenTypeIdsSequencePaddingMapper(TokensSequencesPaddingMapper):
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         input_field: str = "token_type_ids",
     ) -> None:
         """Mapper to add BOS/SEP/EOS tokens to a token type ids sequence.
@@ -295,7 +298,7 @@ def __init__(
         fields_to_stride: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         extra_length_per_seq: Optional[int] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        tokenizer: Optional['PreTrainedTokenizerBase'] = None,
         max_step: Optional[int] = None,
     ) -> None:
         """Mapper to create multiple subset sequences from a single sequence

src/smashed/mappers/prompting.py

Lines changed: 17 additions & 12 deletions
@@ -4,13 +4,15 @@
 from string import Formatter
 from typing import Dict, List, Literal, Optional, Sequence, Union
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from necessary import necessary
 
 from ..base import SingleBaseMapper, TransformElementType
 from ..utils.shape_utils import flatten_with_indices, reconstruct_from_indices
 from .tokenize import GetTokenizerOutputFieldsAndNamesMixIn
 
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 __all__ = [
     "EncodeFieldsMapper",
     "FillEncodedPromptMapper",
@@ -23,7 +25,7 @@
 class EncodeFieldsMapper(SingleBaseMapper):
     """Simply encodes the fields in the input data using the tokenizer."""
 
-    tokenizer: PreTrainedTokenizerBase
+    tokenizer: 'PreTrainedTokenizerBase'
     is_split_into_words: bool
     fields_to_encode: Dict[str, None]
 
@@ -35,7 +37,7 @@ class EncodeFieldsMapper(SingleBaseMapper):
     def __init__(
         self,
         fields_to_encode: Sequence[str],
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
        is_split_into_words: bool = False,
         fields_to_return_offset_mapping: Union[Sequence[str], bool] = False,
         offset_prefix: str = "offset",
@@ -61,13 +63,16 @@ def __init__(
                 new field with offsets. Defaults to "pos_start".
         """
 
-        if fields_to_return_offset_mapping and not isinstance(
-            tokenizer, PreTrainedTokenizerFast
-        ):
-            raise TypeError(
-                "return_offsets_mapping is only supported for fast tokenizers,"
-                " i.e. those that inherit from PreTrainedTokenizerFast."
-            )
+        if fields_to_return_offset_mapping and necessary("transformers"):
+            from transformers.tokenization_utils_fast \
+                import PreTrainedTokenizerFast
+
+            if not isinstance(tokenizer, PreTrainedTokenizerFast):
+                raise TypeError(
+                    "return_offsets_mapping is only supported for fast "
+                    "tokenizers, i.e., those that inherit from "
+                    "PreTrainedTokenizerFast."
+                )
 
         if isinstance(fields_to_return_offset_mapping, bool):
             # if user provides true, it means they want to return the
@@ -139,7 +144,7 @@ def __init__(
         self,
         fields_to_truncate: List[str],
         fields_to_preserve: Optional[List[str]] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        tokenizer: Optional['PreTrainedTokenizerBase'] = None,
         max_length: Optional[int] = None,
         length_penalty: int = 0,
         strategy: Union[Literal["longest"], Literal["uniform"]] = "longest",
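Unlike the module-level guards, the PreTrainedTokenizerFast check moves inside __init__: necessary("transformers") without soft=True acts as a hard requirement, so transformers is only demanded at the moment a caller actually asks for offset mappings. A sketch of the idiom, under the assumption (consistent with its use above) that necessary() raises ImportError when the module is missing and returns a truthy value when it is present:

from necessary import necessary

def encode_with_offsets(tokenizer, want_offsets: bool) -> bool:
    # hypothetical helper; only the necessary() idiom mirrors the diff
    if want_offsets and necessary("transformers"):
        # reached only when transformers is importable; otherwise
        # necessary() has already raised an informative ImportError
        from transformers.tokenization_utils_fast import (
            PreTrainedTokenizerFast,
        )
        return isinstance(tokenizer, PreTrainedTokenizerFast)
    return False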

src/smashed/mappers/tokenize.py

Lines changed: 6 additions & 3 deletions
@@ -7,10 +7,13 @@
 import unicodedata
 from typing import Any, Dict, List, Optional
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary
 
 from ..base import SingleBaseMapper, TransformElementType
 
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 __all__ = [
     "PaddingMapper",
     "TokenizerMapper",
@@ -23,7 +26,7 @@ class GetTokenizerOutputFieldsAndNamesMixIn:
     """A mixin class that figures out the output fields based on the arguments
     that will be passed a to tokenizer.__call__ method."""
 
-    tokenizer: PreTrainedTokenizerBase
+    tokenizer: 'PreTrainedTokenizerBase'
     _prefix: Optional[str]
 
     def __init__(
@@ -81,7 +84,7 @@ class TokenizerMapper(SingleBaseMapper, GetTokenizerOutputFieldsAndNamesMixIn):
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         input_field: str,
         output_prefix: Optional[str] = None,
         output_rename_map: Optional[Dict[str, str]] = None,

src/smashed/recipes/prompting.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,6 @@
 from typing import Dict, Literal, Optional, Sequence, TypeVar, Union
 
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from necessary import necessary
 
 from ..base.mappers import ChainableMapperMixIn
 from ..base.recipes import BaseRecipe
@@ -12,6 +12,9 @@
 )
 from ..mappers.shape import SingleSequenceStriderMapper
 
+with necessary("transformers", soft=True):
+    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 C = TypeVar("C", bound=ChainableMapperMixIn)
 
 
@@ -34,7 +37,7 @@ def strider_mapper(self, **kwargs) -> SingleSequenceStriderMapper:
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         source_template: str,
         source_add_bos_token: bool = True,
         source_add_eos_token: bool = False,
@@ -229,7 +232,7 @@ def _add_truncation_and_striding(
         self,
         pipeline: C,
         prompt_mapper: FillEncodedPromptMapper,
-        tokenizer: PreTrainedTokenizerBase,
+        tokenizer: 'PreTrainedTokenizerBase',
         all_fields_to_truncate: Sequence[str],
         all_fields_to_stride: Sequence[str],
         strategy: Union[Literal["longest"], Literal["uniform"]],

src/smashed/utils/version.py

Lines changed: 2 additions & 5 deletions
@@ -4,7 +4,6 @@
 
 def get_version() -> str:
     """Get the version of the package."""
-
     # This is a workaround for the fact that if the package is installed
     # in editable mode, the version is not reliability available.
     # Therefore, we check for the existence of a file called EDITABLE,
@@ -16,7 +15,7 @@ def get_version() -> str:
     try:
         # package has been installed, so it has a version number
         # from pyproject.toml
-        version = importlib.metadata.version(__package__ or __name__)
+        version = importlib.metadata.version(get_name())
    except importlib.metadata.PackageNotFoundError:
         # package hasn't been installed, so set version to "dev"
         version = "dev"
@@ -26,9 +25,7 @@ def get_version() -> str:
 
 def get_name() -> str:
     """Get the name of the package."""
-    import smashed
-
-    return smashed.__package__ or smashed.__name__
+    return 'smashed'
 
 
 def get_name_and_version() -> str:
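This is the "issue with versioning" from the commit message: inside smashed/utils/version.py, __package__ is "smashed.utils" rather than "smashed", so importlib.metadata.version() raised PackageNotFoundError and get_version() reported "dev" even for properly installed copies (a likely reading; the diff shows only the fix). Hardcoding get_name() to the distribution name makes the lookup deterministic. An illustrative check, assuming smashed is installed:

import importlib.metadata

# the distribution is registered under the project name from pyproject.toml
print(importlib.metadata.version("smashed"))  # e.g. "0.18.0"

# ...but the old lookup effectively asked for "smashed.utils":
try:
    importlib.metadata.version("smashed.utils")
except importlib.metadata.PackageNotFoundError:
    print("not found -> get_version() fell back to 'dev'")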
