@@ -66,14 +66,7 @@ def _split_generators(self, dl_manager):
         ),
     ]

-  def _build_pcollection(self, pipeline, num_examples):
-    """Generate examples as dicts."""
-    examples = (
-        pipeline
-        | beam.Create(range(num_examples))
-        | beam.Map(_gen_example)
-    )
-
+  def _compute_metadata(self, examples, num_examples):
     self.info.metadata["label_sum_%d" % num_examples] = (
         examples
         | beam.Map(lambda x: x[1]["label"])
@@ -83,6 +76,14 @@ def _build_pcollection(self, pipeline, num_examples):
         | beam.Map(lambda x: x[1]["id"])
         | beam.CombineGlobally(beam.combiners.MeanCombineFn()))

+  def _build_pcollection(self, pipeline, num_examples):
+    """Generate examples as dicts."""
+    examples = (
+        pipeline
+        | beam.Create(range(num_examples))
+        | beam.Map(_gen_example)
+    )
+    self._compute_metadata(examples, num_examples)
     return examples


@@ -94,6 +95,36 @@ def _gen_example(x):
   })


+class CommonPipelineDummyBeamDataset(DummyBeamDataset):
+
+  def _split_generators(self, dl_manager, pipeline):
+    del dl_manager
+
+    examples = (
+        pipeline
+        | beam.Create(range(1000))
+        | beam.Map(_gen_example)
+    )
+
+    return [
+        splits_lib.SplitGenerator(
+            name=splits_lib.Split.TRAIN,
+            gen_kwargs=dict(examples=examples, num_examples=1000),
+        ),
+        splits_lib.SplitGenerator(
+            name=splits_lib.Split.TEST,
+            gen_kwargs=dict(examples=examples, num_examples=725),
+        ),
+    ]
+
+  def _build_pcollection(self, pipeline, examples, num_examples):
+    """Generate examples as dicts."""
+    del pipeline
+    examples |= beam.Filter(lambda x: x[0] < num_examples)
+    self._compute_metadata(examples, num_examples)
+    return examples
+
+
 class FaultyS3DummyBeamDataset(DummyBeamDataset):

   VERSION = utils.Version("1.0.0")
@@ -107,24 +138,24 @@ def test_download_prepare_raise(self):
     with self.assertRaisesWithPredicateMatch(ValueError, "no Beam Runner"):
       builder.download_and_prepare()

-  def _assertBeamGeneration(self, dl_config):
+  def _assertBeamGeneration(self, dl_config, dataset_cls, dataset_name):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
-      builder = DummyBeamDataset(data_dir=tmp_dir)
+      builder = dataset_cls(data_dir=tmp_dir)
       builder.download_and_prepare(download_config=dl_config)

-      data_dir = os.path.join(tmp_dir, "dummy_beam_dataset", "1.0.0")
+      data_dir = os.path.join(tmp_dir, dataset_name, "1.0.0")
       self.assertEqual(data_dir, builder._data_dir)

       # Check number of shards
       self._assertShards(
           data_dir,
-          pattern="dummy_beam_dataset-test.tfrecord-{:05}-of-{:05}",
+          pattern="%s-test.tfrecord-{:05}-of-{:05}" % dataset_name,
           # Liquid sharding is not guaranteed to always use the same number.
           num_shards=builder.info.splits["test"].num_shards,
       )
       self._assertShards(
           data_dir,
-          pattern="dummy_beam_dataset-train.tfrecord-{:05}-of-{:05}",
+          pattern="%s-train.tfrecord-{:05}-of-{:05}" % dataset_name,
           num_shards=1,
       )
@@ -177,7 +208,11 @@ def test_download_prepare(self):
     dl_config = self._get_dl_config_if_need_to_run()
     if not dl_config:
       return
-    self._assertBeamGeneration(dl_config)
+    self._assertBeamGeneration(
+        dl_config, DummyBeamDataset, "dummy_beam_dataset")
+    self._assertBeamGeneration(
+        dl_config, CommonPipelineDummyBeamDataset,
+        "common_pipeline_dummy_beam_dataset")


 if __name__ == "__main__":
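
For reference, a minimal sketch of how a Beam-based builder such as the new CommonPipelineDummyBeamDataset might be exercised outside this test harness. The data_dir path and the DirectRunner choice are illustrative assumptions, not part of this diff; only the shared-pipeline behavior (TFDS handing the pipeline to _split_generators so both splits reuse one beam.Create) comes from the change above.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import tensorflow_datasets as tfds

# Illustrative assumption: any configured Beam runner works; DirectRunner
# keeps the sketch self-contained and runnable on one machine.
dl_config = tfds.download.DownloadConfig(
    beam_options=PipelineOptions(["--runner=DirectRunner"]))

# Hypothetical data_dir; the builder writes its shards under
# <data_dir>/common_pipeline_dummy_beam_dataset/1.0.0.
builder = CommonPipelineDummyBeamDataset(data_dir="/tmp/tfds")
builder.download_and_prepare(download_config=dl_config)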