@@ -167,6 +167,10 @@ def _eval_by_dataset(
167
167
for metric_name , metric_func in _validator .metric .items ():
168
168
# NOTE: compute metric with entire output and label
169
169
metric_dict = metric_func (all_output , all_label )
170
+ assert metric_name not in metric_dict_group , (
171
+ f"Metric name({ metric_name } ) already exists, please ensure all metric "
172
+ "names are unique over all validators."
173
+ )
170
174
metric_dict_group [metric_name ] = {
171
175
k : float (v ) for k , v in metric_dict .items ()
172
176
}
@@ -215,7 +219,6 @@ def _eval_by_batch(
215
219
num_samples = _get_dataset_length (_validator .data_loader )
216
220
217
221
loss_dict = misc .Prettydefaultdict (float )
218
- metric_dict_group : Dict [str , Dict [str , float ]] = misc .PrettyOrderedDict ()
219
222
reader_tic = time .perf_counter ()
220
223
batch_tic = time .perf_counter ()
221
224
for iter_id , batch in enumerate (_validator .data_loader , start = 1 ):
@@ -251,9 +254,12 @@ def _eval_by_batch(
251
254
252
255
# collect batch metric
253
256
for metric_name , metric_func in _validator .metric .items ():
257
+ assert metric_name not in metric_dict_group , (
258
+ f"Metric name({ metric_name } ) already exists, please ensure all metric "
259
+ "names are unique over all validators."
260
+ )
261
+ metric_dict_group [metric_name ] = misc .Prettydefaultdict (list )
254
262
metric_dict = metric_func (output_dict , label_dict )
255
- if metric_name not in metric_dict_group :
256
- metric_dict_group [metric_name ] = misc .Prettydefaultdict (list )
257
263
for var_name , metric_value in metric_dict .items ():
258
264
metric_dict_group [metric_name ][var_name ].append (
259
265
metric_value
@@ -284,9 +290,9 @@ def _eval_by_batch(
284
290
# concatenate all metric and discard metric of padded sample(s)
285
291
for metric_name , metric_dict in metric_dict_group .items ():
286
292
for var_name , metric_value in metric_dict .items ():
287
- # NOTE: concat all metric(scalars) into metric vector
293
+ # NOTE: concat single metric(scalar) list into metric vector
288
294
metric_value = paddle .concat (metric_value )[:num_samples ]
289
- # NOTE: compute metric via averaging metric vector ,
295
+ # NOTE: compute metric via averaging metric over all samples ,
290
296
# this might be not general for certain evaluation case
291
297
metric_value = float (metric_value .mean ())
292
298
metric_dict_group [metric_name ][var_name ] = metric_value
0 commit comments