Skip to content

Commit 11503d1

Browse files
yaozhaogoogle authored and copybara-github committed
update gigaword
PiperOrigin-RevId: 283906111
1 parent 286d7a8 commit 11503d1

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

tensorflow_datasets/summarization/gigaword.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@ class Gigaword(tfds.core.GeneratorBasedBuilder):
6969

7070
# 1.0.0 contains a bug that uses validation data as training data.
7171
# 1.1.0 Update to the correct train, validation and test data.
72-
VERSION = tfds.core.Version("1.1.0")
72+
# 1.2.0 Replace <unk> with UNK in train/val to be consistent with test.
73+
VERSION = tfds.core.Version("1.2.0")
7374

7475
def _info(self):
7576
return tfds.core.DatasetInfo(
@@ -93,27 +94,36 @@ def _split_generators(self, dl_manager):
9394
name=tfds.Split.TRAIN,
9495
gen_kwargs={
9596
"src_path": pattern % ("train", "src"),
96-
"tgt_path": pattern % ("train", "tgt")
97+
"tgt_path": pattern % ("train", "tgt"),
98+
"replace_unk": True,
9799
},
98100
),
99101
tfds.core.SplitGenerator(
100102
name=tfds.Split.VALIDATION,
101103
gen_kwargs={
102104
"src_path": pattern % ("dev", "src"),
103-
"tgt_path": pattern % ("dev", "tgt")
105+
"tgt_path": pattern % ("dev", "tgt"),
106+
"replace_unk": True,
104107
},
105108
),
106109
tfds.core.SplitGenerator(
107110
name=tfds.Split.TEST,
108111
gen_kwargs={
109112
"src_path": pattern % ("test", "src"),
110-
"tgt_path": pattern % ("test", "tgt")
113+
"tgt_path": pattern % ("test", "tgt"),
114+
"replace_unk": False,
111115
},
112116
),
113117
]
114118

115-
def _generate_examples(self, src_path=None, tgt_path=None):
119+
def _generate_examples(self, src_path=None, tgt_path=None, replace_unk=None):
116120
"""Yields examples."""
117121
with tf.io.gfile.GFile(src_path) as f_d, tf.io.gfile.GFile(tgt_path) as f_s:
118122
for i, (doc_text, sum_text) in enumerate(zip(f_d, f_s)):
119-
yield i, {_DOCUMENT: doc_text.strip(), _SUMMARY: sum_text.strip()}
123+
if replace_unk:
124+
yield i, {
125+
_DOCUMENT: doc_text.strip().replace("<unk>", "UNK"),
126+
_SUMMARY: sum_text.strip().replace("<unk>", "UNK")
127+
}
128+
else:
129+
yield i, {_DOCUMENT: doc_text.strip(), _SUMMARY: sum_text.strip()}

0 commit comments

Comments
 (0)