Skip to content

Commit 830a663

Browse files
adarobcopybara-github
authored andcommitted
Add glue/mnli configs that include only the matched or mismatched test/validation sets.
PiperOrigin-RevId: 258498603
1 parent f7d1650 commit 830a663

File tree

1 file changed

+79
-66
lines changed

1 file changed

+79
-66
lines changed

tensorflow_datasets/text/glue.py

Lines changed: 79 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,42 @@
4343
_MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
4444
_MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
4545

46+
_MNLI_BASE_KWARGS = dict(
47+
text_features={
48+
"premise": "sentence1",
49+
"hypothesis": "sentence2",
50+
},
51+
label_classes=["entailment", "neutral", "contradiction"],
52+
label_column="gold_label",
53+
data_url="https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
54+
data_dir="MNLI",
55+
citation="""\
56+
@InProceedings{N18-1101,
57+
author = "Williams, Adina
58+
and Nangia, Nikita
59+
and Bowman, Samuel",
60+
title = "A Broad-Coverage Challenge Corpus for
61+
Sentence Understanding through Inference",
62+
booktitle = "Proceedings of the 2018 Conference of
63+
the North American Chapter of the
64+
Association for Computational Linguistics:
65+
Human Language Technologies, Volume 1 (Long
66+
Papers)",
67+
year = "2018",
68+
publisher = "Association for Computational Linguistics",
69+
pages = "1112--1122",
70+
location = "New Orleans, Louisiana",
71+
url = "http://aclweb.org/anthology/N18-1101"
72+
}
73+
@article{bowman2015large,
74+
title={A large annotated corpus for learning natural language inference},
75+
author={Bowman, Samuel R and Angeli, Gabor and Potts, Christopher and Manning, Christopher D},
76+
journal={arXiv preprint arXiv:1508.05326},
77+
year={2015}
78+
}""",
79+
url="http://www.nyu.edu/projects/bowman/multinli/",
80+
train_shards=2)
81+
4682

4783
class GlueConfig(tfds.core.BuilderConfig):
4884
"""BuilderConfig for GLUE."""
@@ -219,40 +255,19 @@ class Glue(tfds.core.GeneratorBasedBuilder):
219255
We use the standard test set, for which we obtained private labels from the authors, and evaluate
220256
on both the matched (in-domain) and mismatched (cross-domain) section. We also use and recommend
221257
the SNLI corpus as 550k examples of auxiliary training data.""",
222-
text_features={
223-
"premise": "sentence1",
224-
"hypothesis": "sentence2",
225-
},
226-
label_classes=["entailment", "neutral", "contradiction"],
227-
label_column="gold_label",
228-
data_url="https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
229-
data_dir="MNLI",
230-
citation="""\
231-
@InProceedings{N18-1101,
232-
author = "Williams, Adina
233-
and Nangia, Nikita
234-
and Bowman, Samuel",
235-
title = "A Broad-Coverage Challenge Corpus for
236-
Sentence Understanding through Inference",
237-
booktitle = "Proceedings of the 2018 Conference of
238-
the North American Chapter of the
239-
Association for Computational Linguistics:
240-
Human Language Technologies, Volume 1 (Long
241-
Papers)",
242-
year = "2018",
243-
publisher = "Association for Computational Linguistics",
244-
pages = "1112--1122",
245-
location = "New Orleans, Louisiana",
246-
url = "http://aclweb.org/anthology/N18-1101"
247-
}
248-
@article{bowman2015large,
249-
title={A large annotated corpus for learning natural language inference},
250-
author={Bowman, Samuel R and Angeli, Gabor and Potts, Christopher and Manning, Christopher D},
251-
journal={arXiv preprint arXiv:1508.05326},
252-
year={2015}
253-
}""",
254-
url="http://www.nyu.edu/projects/bowman/multinli/",
255-
train_shards=2),
258+
**_MNLI_BASE_KWARGS),
259+
GlueConfig(
260+
name="mnli_mismatched",
261+
description="""\
262+
The mismatched validation and test splits from MNLI.
263+
See the "mnli" BuilderConfig for additional information.""",
264+
**_MNLI_BASE_KWARGS),
265+
GlueConfig(
266+
name="mnli_matched",
267+
description="""\
268+
The matched validation and test splits from MNLI.
269+
See the "mnli" BuilderConfig for additional information.""",
270+
**_MNLI_BASE_KWARGS),
256271
GlueConfig(
257272
name="qnli",
258273
description="""\
@@ -414,38 +429,23 @@ def _split_generators(self, dl_manager):
414429
if self.builder_config.name == "mnli":
415430
return [
416431
train_split,
417-
tfds.core.SplitGenerator(
418-
name="validation_matched",
419-
num_shards=1,
420-
gen_kwargs={
421-
"data_file": os.path.join(data_dir, "dev_matched.tsv"),
422-
"split": "dev",
423-
"mrpc_files": None,
424-
}),
425-
tfds.core.SplitGenerator(
426-
name="validation_mismatched",
427-
num_shards=1,
428-
gen_kwargs={
429-
"data_file": os.path.join(data_dir, "dev_mismatched.tsv"),
430-
"split": "dev",
431-
"mrpc_files": None,
432-
}),
433-
tfds.core.SplitGenerator(
434-
name="test_matched",
435-
num_shards=1,
436-
gen_kwargs={
437-
"data_file": os.path.join(data_dir, "test_matched.tsv"),
438-
"split": "test",
439-
"mrpc_files": None,
440-
}),
441-
tfds.core.SplitGenerator(
442-
name="test_mismatched",
443-
num_shards=1,
444-
gen_kwargs={
445-
"data_file": os.path.join(data_dir, "test_mismatched.tsv"),
446-
"split": "test",
447-
"mrpc_files": None,
448-
}),
432+
_mnli_split_generator(
433+
"validation_matched", data_dir, "dev", matched=True),
434+
_mnli_split_generator(
435+
"validation_mismatched", data_dir, "dev", matched=False),
436+
_mnli_split_generator("test_matched", data_dir, "test", matched=True),
437+
_mnli_split_generator(
438+
"test_mismatched", data_dir, "test", matched=False)
439+
]
440+
elif self.builder_config.name == "mnli_matched":
441+
return [
442+
_mnli_split_generator("validation", data_dir, "dev", matched=True),
443+
_mnli_split_generator("test", data_dir, "test", matched=True)
444+
]
445+
elif self.builder_config.name == "mnli_mismatched":
446+
return [
447+
_mnli_split_generator("validation", data_dir, "dev", matched=False),
448+
_mnli_split_generator("test", data_dir, "test", matched=False)
449449
]
450450
else:
451451
return [
@@ -547,3 +547,16 @@ def _generate_example_mrpc_files(self, mrpc_files, split):
547547
"label": int(row["Quality"]),
548548
"idx": n,
549549
}
550+
551+
552+
def _mnli_split_generator(name, data_dir, split, matched):
553+
return tfds.core.SplitGenerator(
554+
name=name,
555+
num_shards=1,
556+
gen_kwargs={
557+
"data_file": os.path.join(
558+
data_dir,
559+
"%s_%s.tsv" % (split, "matched" if matched else "mismatched")),
560+
"split": split,
561+
"mrpc_files": None,
562+
})

0 commit comments

Comments
 (0)