Skip to content

Commit 376f6b3

Browse files
authored
torch optional (#59)
1 parent b0faa0f commit 376f6b3

File tree

7 files changed

+44
-24
lines changed

7 files changed

+44
-24
lines changed

pyproject.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.20.0"
3+
version = "0.21.0"
44
description = """\
55
SMASHED is a toolkit designed to apply transformations to samples in \
66
datasets, such as fields extraction, tokenization, prompting, batching, \
@@ -11,7 +11,6 @@ license = {text = "Apache-2.0"}
1111
readme = "README.md"
1212
requires-python = ">=3.8"
1313
dependencies = [
14-
"torch>=1.9",
1514
"necessary>=0.4.1",
1615
"trouting>=0.3.3",
1716
"ftfy>=6.1.1",
@@ -103,12 +102,17 @@ remote = [
103102
"smart-open>=5.2.1",
104103
"boto3>=1.25.5",
105104
]
105+
torch = [
106+
"torch>=1.9",
107+
]
106108
datasets = [
109+
"smashed[torch]",
107110
"transformers>=4.5",
108111
"datasets>=2.8.0",
109112
"dill>=0.3.0",
110113
]
111114
prompting = [
115+
"smashed[torch]",
112116
"transformers>=4.5",
113117
"promptsource>=0.2.3",
114118
"blingfire>=0.1.8",
@@ -119,6 +123,7 @@ torchdata = [
119123
]
120124
all = [
121125
"smashed[dev]",
126+
"smashed[torch]",
122127
"smashed[datasets]",
123128
"smashed[torchdata]",
124129
"smashed[remote]",

src/smashed/base/interfaces.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
)
1616

1717
from necessary import necessary
18-
from torch._utils import classproperty
1918
from trouting import trouting
2019

2120
from .abstract import (
@@ -52,7 +51,7 @@ class MapMethodInterfaceMixIn(AbstractBaseMapper):
5251
and various interfaces. Do not inherit from this class directly,
5352
but use SingleBaseMapper/BatchedBaseMapper instead."""
5453

55-
@classproperty
54+
@classmethod
5655
def always_remove_columns(cls) -> bool:
5756
"""Whether this mapper should always remove its input columns
5857
from the dataset. If False, the mapper will only remove columns
@@ -191,7 +190,7 @@ def _map_list_of_dicts(
191190
# TODO[lucas]: maybe support specifying which fields to keep?
192191
remove_columns = (
193192
bool(map_kwargs.get("remove_columns", False))
194-
or self.always_remove_columns
193+
or self.always_remove_columns()
195194
)
196195

197196
if isinstance(dataset, abc.Sequence):
@@ -258,7 +257,7 @@ def _map_huggingface_dataset(
258257

259258
print_fingerprint = map_kwargs.pop("print_fingerprint", False)
260259

261-
if self.always_remove_columns:
260+
if self.always_remove_columns():
262261
remove_columns = list(dataset.features.keys())
263262
else:
264263
remove_columns = map_kwargs.get("remove_columns", [])
@@ -320,7 +319,7 @@ def _map_huggingface_dataset_batch(
320319
# TODO[lucas]: maybe support specifying which fields to keep?
321320
remove_columns = (
322321
bool(map_kwargs.get("remove_columns", False))
323-
or self.always_remove_columns
322+
or self.always_remove_columns()
324323
)
325324

326325
dtview: DataBatchView[LazyBatch, str, Any] = DataBatchView(dataset)

src/smashed/mappers/collators.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Union,
1515
)
1616

17-
import torch
1817
from necessary import necessary
1918

2019
from ..base import SingleBaseMapper, TransformElementType
@@ -26,6 +25,10 @@
2625
PreTrainedTokenizerBase,
2726
)
2827

28+
with necessary("torch", soft=True) as PYTORCH_AVAILABLE:
29+
if PYTORCH_AVAILABLE or TYPE_CHECKING:
30+
import torch
31+
2932

3033
__all__ = [
3134
"ListCollatorMapper",
@@ -170,15 +173,21 @@ class TensorCollatorMapper(BaseCollator, SingleBaseMapper):
170173
>>> data_loader = DataLoader(..., collate_fn=collator.transform)
171174
"""
172175

176+
def __init__(self, *args, **kwargs):
177+
if not PYTORCH_AVAILABLE:
178+
cls_name = self.__class__.__name__
179+
raise ImportError(f"Pytorch is required to use {cls_name}")
180+
super().__init__(*args, **kwargs)
181+
173182
@staticmethod
174183
def _pad(
175-
sequence: Sequence[torch.Tensor],
184+
sequence: Sequence["torch.Tensor"],
176185
pad_value: Union[int, float],
177186
dim: int = 0,
178187
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
179188
pad_to_multiple_of: Optional[int] = None,
180189
right_pad: bool = True,
181-
) -> torch.Tensor:
190+
) -> "torch.Tensor":
182191
"""Pad a sequence of tensors to the same length.
183192
184193
Args:
@@ -272,8 +281,8 @@ def _pad(
272281
return torch.cat(to_stack, dim=dim)
273282

274283
def transform( # type: ignore
275-
self: "TensorCollatorMapper", data: Dict[str, Sequence[torch.Tensor]]
276-
) -> Dict[str, torch.Tensor]:
284+
self: "TensorCollatorMapper", data: Dict[str, Sequence["torch.Tensor"]]
285+
) -> Dict[str, "torch.Tensor"]:
277286
collated_data = {
278287
field_name: self._pad(
279288
sequence=list_of_tensors,

src/smashed/mappers/converters.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, TypeVar, Union
22

3-
import torch
43
from necessary import necessary
54
from trouting import trouting
65

@@ -15,16 +14,22 @@
1514
"HuggingFaceDataset", Dataset, IterableDataset
1615
)
1716

17+
with necessary("torch", soft=True) as PYTORCH_AVAILABLE:
18+
if PYTORCH_AVAILABLE or TYPE_CHECKING:
19+
import torch
20+
1821

1922
class Python2TorchMapper(SingleBaseMapper):
2023
__slots__ = ["field_cast_map", "device"]
21-
field_cast_map: Dict[str, torch.dtype]
22-
device: Union[torch.device, None]
24+
field_cast_map: Dict[str, "torch.dtype"]
25+
device: Union["torch.device", None]
2326

2427
def __init__(
2528
self: "Python2TorchMapper",
26-
field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
27-
device: Optional[Union[torch.device, str]] = None,
29+
field_cast_map: Optional[
30+
Mapping[str, Union[str, "torch.dtype"]]
31+
] = None,
32+
device: Optional[Union["torch.device", str]] = None,
2833
) -> None:
2934
"""Mapper that converts Python types to Torch types. It can optionally
3035
cast the values of a field to a specific type, and move to a specific
@@ -37,6 +42,10 @@ def __init__(
3742
device (Union[torch.device, str], optional): Device to move the
3843
tensors to. Defaults to None, which means no moving occurs.
3944
"""
45+
if not PYTORCH_AVAILABLE:
46+
cls_name = self.__class__.__name__
47+
raise ImportError(f"{cls_name} requires PyTorch to be installed")
48+
4049
self.device = torch.device(device) if device else None
4150

4251
self.field_cast_map = {
@@ -49,7 +58,7 @@ def __init__(
4958
)
5059

5160
@staticmethod
52-
def _get_dtype(dtype: Any) -> torch.dtype:
61+
def _get_dtype(dtype: Any) -> "torch.dtype":
5362
if isinstance(dtype, str):
5463
dtype = getattr(torch, dtype, None)
5564
if dtype is None:
@@ -102,7 +111,7 @@ def __init__(self: "Torch2PythonMapper") -> None:
102111
super().__init__()
103112

104113
def transform( # type: ignore
105-
self: "Torch2PythonMapper", data: Dict[str, torch.Tensor]
114+
self: "Torch2PythonMapper", data: Dict[str, "torch.Tensor"]
106115
) -> TransformElementType:
107116
return {
108117
field_name: field_value.cpu().tolist()

src/smashed/mappers/fields.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
22

33
from necessary import necessary
4-
from torch._utils import classproperty
54

65
from ..base import SingleBaseMapper, TransformElementType
76

@@ -19,7 +18,7 @@ class ChangeFieldsMapper(SingleBaseMapper):
1918
"""Mapper that removes some of the fields in a dataset.
2019
Either `keep_fields` or `drop_fields` must be specified, but not both."""
2120

22-
@classproperty
21+
@classmethod
2322
def always_remove_columns(cls) -> bool:
2423
return True
2524

@@ -71,7 +70,7 @@ def transform(self, data: TransformElementType) -> TransformElementType:
7170
class RenameFieldsMapper(SingleBaseMapper):
7271
"""Mapper that renames some of the fields batch"""
7372

74-
@classproperty
73+
@classmethod
7574
def always_remove_columns(cls) -> bool:
7675
return True
7776

src/smashed/mappers/glom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ExtendGlommerMixin:
2323

2424
def __getstate__(self):
2525
state = super().__getstate__() # pyright: ignore
26-
state["__dict__"].pop("glommer", None)
26+
state["__dict__"].pop("glommer", None) # pyright: ignore
2727
return state
2828

2929
@cached_property

src/smashed/utils/io_utils/compression.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def compress_stream(
4848
errors: str = "strict",
4949
gzip: bool = True,
5050
) -> Iterator[IO]:
51-
5251
assert gzip, "Only gzip compression is supported at this time"
5352

5453
if mode == "wb" or mode == "w":

0 commit comments

Comments
 (0)