Skip to content

Commit 7e9ab95

Browse files
peterjliu and copybara-github
authored and committed
Add CNN/DailyMail version with target sentences separated by newline.
PiperOrigin-RevId: 282696546
1 parent aa4178f commit 7e9ab95

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

tensorflow_datasets/summarization/cnn_dailymail.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,16 @@
7373

7474
_HIGHLIGHTS = 'highlights'
7575
_ARTICLE = 'article'
76+
_SUPPORTED_VERSIONS = [
77+
tfds.core.Version('0.0.2', experiments={tfds.core.Experiment.S3: False}),
78+
# Same data as 0.0.2
79+
tfds.core.Version('1.0.0',
80+
'New split API (https://tensorflow.org/datasets/splits)')]
81+
82+
# Having the model predict newline separators makes it easier to evaluate
83+
# using summary-level ROUGE.
84+
_DEFAULT_VERSION = tfds.core.Version('2.0.0',
85+
'Separate target sentences with newline.')
7686

7787

7888
class CnnDailymailConfig(tfds.core.BuilderConfig):
@@ -92,13 +102,8 @@ def __init__(self, text_encoder_config=None, **kwargs):
92102
# 1.0.0: S3 (new shuffling, sharding and slicing mechanism).
93103
# 0.0.2: Initial version.
94104
super(CnnDailymailConfig, self).__init__(
95-
version=tfds.core.Version(
96-
'0.0.2', experiments={tfds.core.Experiment.S3: False}),
97-
supported_versions=[
98-
tfds.core.Version(
99-
'1.0.0',
100-
'New split API (https://tensorflow.org/datasets/splits)'),
101-
],
105+
version=_DEFAULT_VERSION,
106+
supported_versions=_SUPPORTED_VERSIONS,
102107
**kwargs)
103108
self.text_encoder_config = (
104109
text_encoder_config or tfds.features.text.TextEncoderConfig())
@@ -168,7 +173,7 @@ def _read_text_file(text_file):
168173
return lines
169174

170175

171-
def _get_art_abs(story_file):
176+
def _get_art_abs(story_file, tfds_version):
172177
"""Get abstract (highlights) and article from a story file path."""
173178
# Based on https://github.com/abisee/cnn-dailymail/blob/master/
174179
# make_datafiles.py
@@ -207,16 +212,16 @@ def fix_missing_period(line):
207212
# Make article into a single string
208213
article = ' '.join(article_lines)
209214

210-
# Make abstract into a single string, putting <s> and </s> tags around
211-
# the sentences.
212-
abstract = ' '.join(highlights)
215+
if tfds_version >= '2.0.0':
216+
abstract = '\n'.join(highlights)
217+
else:
218+
abstract = ' '.join(highlights)
213219

214220
return article, abstract
215221

216222

217223
class CnnDailymail(tfds.core.GeneratorBasedBuilder):
218224
"""CNN/DailyMail non-anonymized summarization dataset."""
219-
# 0.0.2 is like 0.0.1 but without special tokens <s> and </s>.
220225
BUILDER_CONFIGS = [
221226
CnnDailymailConfig(
222227
name='plain_text',
@@ -291,7 +296,7 @@ def _split_generators(self, dl_manager):
291296

292297
def _generate_examples(self, files):
293298
for p in files:
294-
article, highlights = _get_art_abs(p)
299+
article, highlights = _get_art_abs(p, self.version)
295300
if not article or not highlights:
296301
continue
297302
fname = os.path.basename(p)

tensorflow_datasets/summarization/cnn_dailymail_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tempfile
2222

2323
from tensorflow_datasets import testing
24+
import tensorflow_datasets.public_api as tfds
2425
from tensorflow_datasets.summarization import cnn_dailymail
2526

2627
_STORY_FILE = b"""Some article.
@@ -55,14 +56,20 @@ def test_get_art_abs(self):
5556
with tempfile.NamedTemporaryFile(delete=True) as f:
5657
f.write(_STORY_FILE)
5758
f.flush()
58-
article, abstract = cnn_dailymail._get_art_abs(f.name)
59+
article, abstract = cnn_dailymail._get_art_abs(
60+
f.name, tfds.core.Version('1.0.0'))
5961
self.assertEqual('some article. this is some article text.', article)
6062
# This is a bit weird, but the original code at
6163
# https://github.com/abisee/cnn-dailymail/ adds space before period
6264
# for abstracts and we retain this behavior.
6365
self.assertEqual('highlight text . highlight two . highlight three .',
6466
abstract)
6567

68+
article, abstract = cnn_dailymail._get_art_abs(f.name,
69+
tfds.core.Version('2.0.0'))
70+
self.assertEqual('highlight text .\nhighlight two .\nhighlight three .',
71+
abstract)
72+
6673

6774
class CnnDailymailS3Test(CnnDailymailTest):
6875
VERSION = 'experimental_latest'

0 commit comments

Comments
 (0)