1
- from typing import Any , Dict , Optional , cast
1
+ from itertools import chain
2
+ from typing import Any , Dict , List , Optional , Tuple , cast
2
3
3
4
from necessary import Necessary , necessary
4
5
@@ -30,6 +31,35 @@ def __init__(
30
31
return_multiple_targets : bool = False ,
31
32
extra_variables : Optional [Dict [str , Any ]] = None ,
32
33
):
34
+ """Uses a promptsource template to generate source and target sequence;
35
+ in the returned dictionary of samples, the source sequence is stored
36
+ under the key `source_field_name` and the target sequence is stored
37
+ under the key `target_field_name`. If the template does not contain
38
+ the control sequence `|||`, then no target sequence is generated.
39
+ Args:
40
+ template (promptsource.templates.Template): the promptsource
41
+ template to use.
42
+ source_field_name (str, optional): the name of the field in the
43
+ returned dictionary of samples that will contain the source
44
+ sequence. Defaults to "source".
45
+ target_field_name (str, optional): the name of the field in the
46
+ returned dictionary of samples that will contain the target
47
+ sequence. Defaults to "target".
48
+ truncate (bool, optional): whether to truncate the source and
49
+ target sequences to the maximum length allowed by
50
+ the promptsource library. Defaults to False.
51
+ highlight_variables (bool, optional): whether to highlight the
52
+ variables in the source and target sequences with special
53
+ html tags. Defaults to False.
54
+ return_multiple_targets (bool, optional): whether to return
55
+ a list of target sequences for each sample. Defaults to False.
56
+ If the template returns multiple targets, but this argument
57
+ is set to False, then only the first target is returned.
58
+ extra_variables (Optional[Dict[str, Any]], optional): a dictionary
59
+ of extra variables that will be passed to the promptsource
60
+ template. Defaults to None.
61
+ """
62
+
33
63
self .template = template
34
64
self .truncate = truncate
35
65
self .highlight_vars = highlight_variables
@@ -44,23 +74,65 @@ def __init__(
44
74
45
75
# abstract syntax tree for the jinja template; we will use it
46
76
# to find all fields that are required by the template
47
- ast = Environment ().parse (self .template .jinja )
48
- input_fields = sorted (
49
- var_name
50
- for var_name in meta .find_undeclared_variables (ast )
51
- if var_name not in self .extra_vars
52
- )
53
77
54
78
output_fields = [self .src_fld_name ]
55
79
if "|||" in self .template .jinja :
56
80
output_fields .append (self .tgt_fld_name )
57
81
82
+ input_src_fields , input_tgt_fields = self .approximate_input_fields
58
83
super ().__init__ (
59
- input_fields = input_fields , output_fields = output_fields
84
+ input_fields = set (input_src_fields + input_tgt_fields ),
85
+ output_fields = output_fields ,
86
+ )
87
+
88
def _approximate_input_fields(self, jinja_txt: str) -> List[str]:
    """Return the sorted variable names a jinja snippet leaves undeclared.

    Args:
        jinja_txt (str): the jinja text to parse.

    Returns:
        List[str]: the names reported by `meta.find_undeclared_variables`
            for the parsed snippet, minus any name that is present in
            `self.extra_vars`, in sorted order.
    """
    parsed = Environment().parse(jinja_txt)
    undeclared = meta.find_undeclared_variables(parsed)

    # names found in `self.extra_vars` are not reported as input fields
    field_names = [
        name for name in undeclared if name not in self.extra_vars
    ]
    field_names.sort()
    return field_names
95
+
96
@property
def approximate_input_fields(self) -> Tuple[List[str], List[str]]:
    """Fields the source and target parts of the template appear to need.

    The jinja text is split on the `|||` separator: the first piece is
    the source template and every remaining piece is a target template.
    The result is approximate because nested variables are ignored.

    Returns:
        Tuple[List[str], List[str]]: the fields found in the source
            piece, followed by the sorted, de-duplicated fields found
            across all target pieces (empty when there is no `|||`).
    """
    src_part, *tgt_parts = self.template.jinja.split("|||")
    src_fields = self._approximate_input_fields(src_part)

    # pool the fields of every target piece; the set drops repeats
    pooled_tgt_fields = set()
    for part in tgt_parts:
        pooled_tgt_fields.update(self._approximate_input_fields(part))

    return src_fields, sorted(pooled_tgt_fields)
112
+
113
+ def _approximate_text_from_template (self , txt : str ) -> str :
114
+ return "" .join (part .split ("}}" )[- 1 ] for part in txt .split ("{{" ))
115
+
116
@property
def approximate_prompt_text(self) -> Tuple[str, List[str]]:
    """The prompt text with its variables stripped out.

    The jinja text is split on the `|||` separator: the first piece is
    the source template and every remaining piece is a target template.
    Each piece is passed through `_approximate_text_from_template`.
    The result is approximate because not every variable may be removed.

    Returns:
        Tuple[str, List[str]]: the stripped source text, followed by a
            list with one stripped text per target piece (empty when
            there is no `|||`).
    """
    src_part, *tgt_parts = self.template.jinja.split("|||")
    strip_variables = self._approximate_text_from_template

    src_text = strip_variables(src_part)
    tgt_texts = list(map(strip_variables, tgt_parts))
    return src_text, tgt_texts
129
+
130
@property
def has_target(self) -> bool:
    """Whether the template's jinja text contains the `|||` separator."""
    separator_at = self.template.jinja.find("|||")
    return separator_at != -1
61
133
62
134
def __getstate__ (self ) -> dict :
63
- """We need to serialize the template using yaml so the hash for this
135
+ """We need to serialize the template using yaml so the hash for this
64
136
mapper is consistent across runs."""
65
137
out = super ().__getstate__ ()
66
138
out ["__dict__" ]["template" ] = yaml .dump (self .template )
@@ -113,6 +185,37 @@ def __init__(
113
185
return_multiple_targets : bool = False ,
114
186
extra_variables : Optional [Dict [str , Any ]] = None ,
115
187
):
188
+ """Use one of the existing promptsource templates to generate
189
+ source and target sequences for a dataset. See the promptsource
190
+ repository for a list of available templates:
191
+ https://github.com/bigscience-workshop/promptsource
192
+
193
+ Args:
194
+ dataset_name (str): the name of the dataset to use.
195
+ template_name (str): the name of the template to use.
196
+ subset_name (Optional[str], optional): the name of the subset
197
+ to use. Defaults to None.
198
+ source_field_name (str, optional): the name of the field in the
199
+ returned dictionary of samples that will contain the source
200
+ sequence. Defaults to "source".
201
+ target_field_name (str, optional): the name of the field in the
202
+ returned dictionary of samples that will contain the target
203
+ sequence. Defaults to "target".
204
+ truncate (bool, optional): whether to truncate the source and
205
+ target sequences to the maximum length allowed by
206
+ the promptsource library. Defaults to False.
207
+ highlight_variables (bool, optional): whether to highlight the
208
+ variables in the source and target sequences with special
209
+ html tags. Defaults to False.
210
+ return_multiple_targets (bool, optional): whether to return
211
+ a list of target sequences for each sample. Defaults to False.
212
+ If the template returns multiple targets, but this argument
213
+ is set to False, then only the first target is returned.
214
+ extra_variables (Optional[Dict[str, Any]], optional): a dictionary
215
+ of extra variables that will be passed to the promptsource
216
+ template. Defaults to None.
217
+ """
218
+
116
219
# DatasetTemplates is not well annotated, so though subset_name
117
220
# is optional, it is annotated as `str`, so we need to cast it.
118
221
subset_name = cast (str , subset_name )
@@ -151,6 +254,40 @@ def __init__(
151
254
return_multiple_targets : bool = False ,
152
255
extra_variables : Optional [Dict [str , Any ]] = None ,
153
256
):
257
+ """Use a custom jinja template to obtain a template from the
258
+ promptsource library. See the jinja documentation for a list of
259
+ language features and syntax: https://jinja.palletsprojects.com/
260
+
261
+ Args:
262
+ jinja (str): the jinja template to use. The template can access
263
+ the data in each sample; the name of fields in the datasets
264
+ are available as variables in the template.
265
+ name (Optional[str], optional): the name of the template. Defaults
266
+ to None.
267
+ reference (Optional[str], optional): the reference for the
268
+ template. Defaults to None.
269
+ metadata (Optional["Template.Metadata"], optional): the metadata
270
+ for the template. Defaults to None.
271
+ source_field_name (str, optional): the name of the field in the
272
+ returned dictionary of samples that will contain the source
273
+ sequence. Defaults to "source".
274
+ target_field_name (str, optional): the name of the field in the
275
+ returned dictionary of samples that will contain the target
276
+ sequence. Defaults to "target".
277
+ truncate (bool, optional): whether to truncate the source and
278
+ target sequences to the maximum length allowed by
279
+ the promptsource library. Defaults to False.
280
+ highlight_variables (bool, optional): whether to highlight the
281
+ variables in the source and target sequences with special
282
+ html tags. Defaults to False.
283
+ return_multiple_targets (bool, optional): whether to return
284
+ a list of target sequences for each sample. Defaults to False.
285
+ If the template returns multiple targets, but this argument
286
+ is set to False, then only the first target is returned.
287
+ extra_variables (Optional[Dict[str, Any]], optional): a dictionary
288
+ of extra variables that will be passed to the promptsource
289
+ template. Defaults to None.
290
+ """
154
291
template = Template (
155
292
jinja = jinja ,
156
293
name = (name or self .name ),
0 commit comments