Skip to content

Commit 52ba3c5

Browse files
adarob authored and copybara-github committed
Include all answers in SQUAD. Requires removal of text encoder configs.
PiperOrigin-RevId: 249327504
1 parent e0cad3d commit 52ba3c5

File tree

1 file changed

+19
-67
lines changed

1 file changed

+19
-67
lines changed

tensorflow_datasets/text/squad.py

Lines changed: 19 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -53,17 +53,13 @@ class SquadConfig(tfds.core.BuilderConfig):
5353
"""BuilderConfig for SQUAD."""
5454

5555
@api_utils.disallow_positional_args
56-
def __init__(self, text_encoder_config=None, **kwargs):
56+
def __init__(self, **kwargs):
5757
"""BuilderConfig for SQUAD.
5858
5959
Args:
60-
text_encoder_config: `tfds.features.text.TextEncoderConfig`, configuration
61-
for the `tfds.features.text.TextEncoder` used for the features feature.
6260
**kwargs: keyword arguments forwarded to super.
6361
"""
6462
super(SquadConfig, self).__init__(**kwargs)
65-
self.text_encoder_config = (
66-
text_encoder_config or tfds.features.text.TextEncoderConfig())
6763

6864

6965
class Squad(tfds.core.GeneratorBasedBuilder):
@@ -75,51 +71,29 @@ class Squad(tfds.core.GeneratorBasedBuilder):
7571
BUILDER_CONFIGS = [
7672
SquadConfig(
7773
name="plain_text",
78-
version="0.0.1",
74+
version="0.1.0",
7975
description="Plain text",
8076
),
81-
SquadConfig(
82-
name="bytes",
83-
version="0.0.1",
84-
description=("Uses byte-level text encoding with "
85-
"`tfds.features.text.ByteTextEncoder`"),
86-
text_encoder_config=tfds.features.text.TextEncoderConfig(
87-
encoder=tfds.features.text.ByteTextEncoder()),
88-
),
89-
SquadConfig(
90-
name="subwords8k",
91-
version="0.0.1",
92-
description=("Uses `tfds.features.text.SubwordTextEncoder` with 8k "
93-
"vocab size"),
94-
text_encoder_config=tfds.features.text.TextEncoderConfig(
95-
encoder_cls=tfds.features.text.SubwordTextEncoder,
96-
vocab_size=2**13),
97-
),
98-
SquadConfig(
99-
name="subwords32k",
100-
version="0.0.2",
101-
description=("Uses `tfds.features.text.SubwordTextEncoder` with "
102-
"32k vocab size"),
103-
text_encoder_config=tfds.features.text.TextEncoderConfig(
104-
encoder_cls=tfds.features.text.SubwordTextEncoder,
105-
vocab_size=2**15),
106-
),
10777
]
10878

10979
def _info(self):
11080
return tfds.core.DatasetInfo(
11181
builder=self,
11282
description=_DESCRIPTION,
11383
features=tfds.features.FeaturesDict({
84+
"id":
85+
tf.string,
86+
"title":
87+
tfds.features.Text(),
11488
"context":
115-
tfds.features.Text(
116-
encoder_config=self.builder_config.text_encoder_config),
89+
tfds.features.Text(),
11790
"question":
118-
tfds.features.Text(
119-
encoder_config=self.builder_config.text_encoder_config),
120-
"first_answer":
121-
tfds.features.Text(
122-
encoder_config=self.builder_config.text_encoder_config),
91+
tfds.features.Text(),
92+
"answers":
93+
tfds.features.Sequence({
94+
"text": tfds.features.Text(),
95+
"answer_start": tf.int32,
96+
}),
12397
}),
12498
# No default supervised_keys (as we have to pass both question
12599
# and context as input).
@@ -128,28 +102,13 @@ def _info(self):
128102
citation=_CITATION,
129103
)
130104

131-
def _vocab_text_gen(self, filepath):
132-
for ex in self._generate_examples(filepath):
133-
# "first_answer" is a substring of "context" so not need to add it here
134-
yield " ".join([ex["question"], ex["context"]])
135-
136105
def _split_generators(self, dl_manager):
137106
urls_to_download = {
138107
"train": os.path.join(self._URL, self._TRAINING_FILE),
139108
"dev": os.path.join(self._URL, self._DEV_FILE)
140109
}
141110
downloaded_files = dl_manager.download_and_extract(urls_to_download)
142111

143-
# Generate shared vocabulary
144-
# maybe_build_from_corpus uses SubwordTextEncoder if that's configured
145-
self.info.features["context"].maybe_build_from_corpus(
146-
self._vocab_text_gen(downloaded_files["train"]))
147-
encoder = self.info.features["context"].encoder
148-
# Use maybe_set_encoder because the encoder may have been restored from
149-
# package data.
150-
self.info.features["question"].maybe_set_encoder(encoder)
151-
self.info.features["first_answer"].maybe_set_encoder(encoder)
152-
153112
return [
154113
tfds.core.SplitGenerator(
155114
name=tfds.Split.TRAIN,
@@ -167,10 +126,7 @@ def _generate_examples(self, filepath):
167126
with tf.io.gfile.GFile(filepath) as f:
168127
squad = json.load(f)
169128
for article in squad["data"]:
170-
if "title" in article:
171-
title = article["title"].strip()
172-
else:
173-
title = ""
129+
title = article.get("title", "").strip()
174130
for paragraph in article["paragraphs"]:
175131
context = paragraph["context"].strip()
176132
for qa in paragraph["qas"]:
@@ -182,17 +138,13 @@ def _generate_examples(self, filepath):
182138

183139
# Features currently used are "context", "question", and "answers".
184140
# Others are extracted here for the ease of future expansions.
185-
example = {
141+
yield {
186142
"title": title,
187143
"context": context,
188144
"question": question,
189145
"id": id_,
190-
"answer_starts": answer_starts,
191-
"answers": answers,
192-
}
193-
yield {
194-
"question": example["question"],
195-
# TODO(b/121176753): return all the answers.
196-
"first_answer": example["answers"][0],
197-
"context": example["context"]
146+
"answers": {
147+
"answer_start": answer_starts,
148+
"text": answers,
149+
},
198150
}

0 commit comments

Comments
 (0)