Skip to content

Commit 1061cf5

Browse files
authored
Promptsource recipe (#40)
* no-op for recipe * no-op for recipe * prompting recipe and test * small error handling * documentation * documentation
1 parent 7e025ed commit 1061cf5

File tree

6 files changed

+377
-13
lines changed

6 files changed

+377
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "smashed"
3-
version = "0.13.0"
3+
version = "0.14.0"
44
description = "Sequential MAppers for Sequences of HEterogeneous Dictionaries is a set of Python interfaces designed to apply transformations to samples in datasets, which are often implemented as sequences of dictionaries."
55
authors = [
66
{name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org" },

src/smashed/mappers/promptsource.py

Lines changed: 146 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict, Optional, cast
1+
from itertools import chain
2+
from typing import Any, Dict, List, Optional, Tuple, cast
23

34
from necessary import Necessary, necessary
45

@@ -30,6 +31,35 @@ def __init__(
3031
return_multiple_targets: bool = False,
3132
extra_variables: Optional[Dict[str, Any]] = None,
3233
):
34+
"""Uses a promptsource template to generate source and target sequence;
35+
in the returned dictionary of samples, the source sequence is stored
36+
under the key `source_field_name` and the target sequence is stored
37+
under the key `target_field_name`. If the template does not contain
38+
the control sequence `|||`, then no target sequence is generated.
39+
Args:
40+
template (promptsource.templates.Template): the promptsource
41+
template to use.
42+
source_field_name (str, optional): the name of the field in the
43+
returned dictionary of samples that will contain the source
44+
sequence. Defaults to "source".
45+
target_field_name (str, optional): the name of the field in the
46+
returned dictionary of samples that will contain the target
47+
sequence. Defaults to "target".
48+
truncate (bool, optional): whether to truncate the source and
49+
target sequences to the maximum length allowed by
50+
the promptsource library. Defaults to False.
51+
highlight_variables (bool, optional): whether to highlight the
52+
variables in the source and target sequences with special
53+
html tags. Defaults to False.
54+
return_multiple_targets (bool, optional): whether to return
55+
a list of target sequences for each sample. Defaults to False.
56+
If the template returns multiple targets, but this argument
57+
is set to False, then only the first target is returned.
58+
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
59+
of extra variables that will be passed to the promptsource
60+
template. Defaults to None.
61+
"""
62+
3363
self.template = template
3464
self.truncate = truncate
3565
self.highlight_vars = highlight_variables
@@ -44,23 +74,65 @@ def __init__(
4474

4575
# abstract syntax tree for the jinja template; we will use it
4676
# to find all fields that are required by the template
47-
ast = Environment().parse(self.template.jinja)
48-
input_fields = sorted(
49-
var_name
50-
for var_name in meta.find_undeclared_variables(ast)
51-
if var_name not in self.extra_vars
52-
)
5377

5478
output_fields = [self.src_fld_name]
5579
if "|||" in self.template.jinja:
5680
output_fields.append(self.tgt_fld_name)
5781

82+
input_src_fields, input_tgt_fields = self.approximate_input_fields
5883
super().__init__(
59-
input_fields=input_fields, output_fields=output_fields
84+
input_fields=set(input_src_fields + input_tgt_fields),
85+
output_fields=output_fields,
86+
)
87+
88+
def _approximate_input_fields(self, jinja_txt: str) -> List[str]:
89+
ast = Environment().parse(jinja_txt)
90+
return sorted(
91+
var_name
92+
for var_name in meta.find_undeclared_variables(ast)
93+
if var_name not in self.extra_vars
94+
)
95+
96+
@property
97+
def approximate_input_fields(self) -> Tuple[List[str], List[str]]:
98+
"""Input fields that are likely to be required by the template;
99+
it is approximate because we ignore nested variables."""
100+
101+
source_template, *target_templates = self.template.jinja.split("|||")
102+
source_fields = self._approximate_input_fields(source_template)
103+
target_fields = sorted(
104+
set(
105+
chain.from_iterable(
106+
self._approximate_input_fields(template)
107+
for template in target_templates
108+
)
109+
)
60110
)
111+
return source_fields, target_fields
112+
113+
def _approximate_text_from_template(self, txt: str) -> str:
114+
return "".join(part.split("}}")[-1] for part in txt.split("{{"))
115+
116+
@property
117+
def approximate_prompt_text(self) -> Tuple[str, List[str]]:
118+
"""The prompt without the variables; it is approximate because
119+
we might not be able to remove all variables."""
120+
121+
source_template, *target_templates = self.template.jinja.split("|||")
122+
123+
source_str = self._approximate_text_from_template(source_template)
124+
target_str = [
125+
self._approximate_text_from_template(template)
126+
for template in target_templates
127+
]
128+
return source_str, target_str
129+
130+
@property
131+
def has_target(self) -> bool:
132+
return "|||" in self.template.jinja
61133

62134
def __getstate__(self) -> dict:
63-
"""We need to serialize the template using yaml so the hash for this
135+
"""We need to serialize the template using yaml so the hash for this
64136
mapper is consistent across runs."""
65137
out = super().__getstate__()
66138
out["__dict__"]["template"] = yaml.dump(self.template)
@@ -113,6 +185,37 @@ def __init__(
113185
return_multiple_targets: bool = False,
114186
extra_variables: Optional[Dict[str, Any]] = None,
115187
):
188+
"""Use one of the existing promptsource templates to generate
189+
source and target sequences for a dataset. See the promptsource
190+
repository for a list of available templates:
191+
https://github.com/bigscience-workshop/promptsource
192+
193+
Args:
194+
dataset_name (str): the name of the dataset to use.
195+
template_name (str): the name of the template to use.
196+
subset_name (Optional[str], optional): the name of the subset
197+
to use. Defaults to None.
198+
source_field_name (str, optional): the name of the field in the
199+
returned dictionary of samples that will contain the source
200+
sequence. Defaults to "source".
201+
target_field_name (str, optional): the name of the field in the
202+
returned dictionary of samples that will contain the target
203+
sequence. Defaults to "target".
204+
truncate (bool, optional): whether to truncate the source and
205+
target sequences to the maximum length allowed by
206+
the promptsource library. Defaults to False.
207+
highlight_variables (bool, optional): whether to highlight the
208+
variables in the source and target sequences with special
209+
html tags. Defaults to False.
210+
return_multiple_targets (bool, optional): whether to return
211+
a list of target sequences for each sample. Defaults to False.
212+
If the template returns multiple targets, but this argument
213+
is set to False, then only the first target is returned.
214+
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
215+
of extra variables that will be passed to the promptsource
216+
template. Defaults to None.
217+
"""
218+
116219
# DatasetTemplates is not well annotated, so though subset_name
117220
# is optional, it is annotated as `str`, so we need to cast it.
118221
subset_name = cast(str, subset_name)
@@ -151,6 +254,40 @@ def __init__(
151254
return_multiple_targets: bool = False,
152255
extra_variables: Optional[Dict[str, Any]] = None,
153256
):
257+
"""Use a custom jinja template to obtain a template from the
258+
promptsource library. See the jinja documentation for a list of
259+
language features and syntax: https://jinja.palletsprojects.com/
260+
261+
Args:
262+
jinja (str): the jinja template to use. The template can access
263+
the data in each sample; the name of fields in the datasets
264+
are available as variables in the template.
265+
name (Optional[str], optional): the name of the template. Defaults
266+
to None.
267+
reference (Optional[str], optional): the reference for the
268+
template. Defaults to None.
269+
metadata (Optional["Template.Metadata"], optional): the metadata
270+
for the template. Defaults to None.
271+
source_field_name (str, optional): the name of the field in the
272+
returned dictionary of samples that will contain the source
273+
sequence. Defaults to "source".
274+
target_field_name (str, optional): the name of the field in the
275+
returned dictionary of samples that will contain the target
276+
sequence. Defaults to "target".
277+
truncate (bool, optional): whether to truncate the source and
278+
target sequences to the maximum length allowed by
279+
the promptsource library. Defaults to False.
280+
highlight_variables (bool, optional): whether to highlight the
281+
variables in the source and target sequences with special
282+
html tags. Defaults to False.
283+
return_multiple_targets (bool, optional): whether to return
284+
a list of target sequences for each sample. Defaults to False.
285+
If the template returns multiple targets, but this argument
286+
is set to False, then only the first target is returned.
287+
extra_variables (Optional[Dict[str, Any]], optional): a dictionary
288+
of extra variables that will be passed to the promptsource
289+
template. Defaults to None.
290+
"""
154291
template = Template(
155292
jinja=jinja,
156293
name=(name or self.name),

src/smashed/recipes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from .collators import CollatorRecipe, SlowCollatorRecipe
22
from .prompting import PromptingRecipe
3+
from .promptsource import PromptsourceRecipe
34

45
__all__ = [
56
"CollatorRecipe",
67
"PromptingRecipe",
8+
"PromptsourceRecipe",
79
"SlowCollatorRecipe",
810
]

src/smashed/recipes/collators.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
22

3+
import torch
34
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
45

56
from ..base import BaseRecipe, SingleBaseMapper
@@ -46,9 +47,13 @@ def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
4647

4748
return collated_batch
4849

49-
def get_tensorizer(self) -> Python2TorchMapper:
50+
def get_tensorizer(
51+
self,
52+
field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
53+
device: Optional[Union[torch.device, str]] = None,
54+
) -> Python2TorchMapper:
5055
# this turns lists of ints/floats into tensors
51-
return Python2TorchMapper()
56+
return Python2TorchMapper(field_cast_map=field_cast_map, device=device)
5257

5358
def get_batcher(self, keep_last: bool) -> FixedBatchSizeMapper:
5459
# the collator already receives the "right" number of samples
@@ -66,10 +71,14 @@ def __init__(
6671
pad_to_length: Optional[Union[int, Sequence[int]]] = None,
6772
fields_pad_ids: Optional[Mapping[str, int]] = None,
6873
unk_fields_pad_id: Optional[int] = None,
74+
field_cast_map: Optional[Mapping[str, Union[str, torch.dtype]]] = None,
75+
device: Optional[Union[torch.device, str]] = None,
6976
) -> None:
7077
super().__init__(do_not_collate=do_not_collate)
7178

72-
self.chain(self.get_tensorizer())
79+
self.chain(
80+
self.get_tensorizer(field_cast_map=field_cast_map, device=device)
81+
)
7382
self.chain(self.get_batcher(keep_last=keep_last))
7483

7584
if tokenizer:

0 commit comments

Comments
 (0)