Skip to content

Commit cc43efa

Browse files
[Fix] Fix eval (#931)
* update develop mkdocs * allow alias for mike * fix and refine eval.py * fix
1 parent 70e5e8b commit cc43efa

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

ppsci/solver/eval.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ def _eval_by_dataset(
167167
for metric_name, metric_func in _validator.metric.items():
168168
# NOTE: compute metric with entire output and label
169169
metric_dict = metric_func(all_output, all_label)
170+
assert metric_name not in metric_dict_group, (
171+
f"Metric name({metric_name}) already exists, please ensure all metric "
172+
"names are unique over all validators."
173+
)
170174
metric_dict_group[metric_name] = {
171175
k: float(v) for k, v in metric_dict.items()
172176
}
@@ -215,7 +219,6 @@ def _eval_by_batch(
215219
num_samples = _get_dataset_length(_validator.data_loader)
216220

217221
loss_dict = misc.Prettydefaultdict(float)
218-
metric_dict_group: Dict[str, Dict[str, float]] = misc.PrettyOrderedDict()
219222
reader_tic = time.perf_counter()
220223
batch_tic = time.perf_counter()
221224
for iter_id, batch in enumerate(_validator.data_loader, start=1):
@@ -251,9 +254,12 @@ def _eval_by_batch(
251254

252255
# collect batch metric
253256
for metric_name, metric_func in _validator.metric.items():
257+
assert metric_name not in metric_dict_group, (
258+
f"Metric name({metric_name}) already exists, please ensure all metric "
259+
"names are unique over all validators."
260+
)
261+
metric_dict_group[metric_name] = misc.Prettydefaultdict(list)
254262
metric_dict = metric_func(output_dict, label_dict)
255-
if metric_name not in metric_dict_group:
256-
metric_dict_group[metric_name] = misc.Prettydefaultdict(list)
257263
for var_name, metric_value in metric_dict.items():
258264
metric_dict_group[metric_name][var_name].append(
259265
metric_value
@@ -284,9 +290,9 @@ def _eval_by_batch(
284290
# concatenate all metric and discard metric of padded sample(s)
285291
for metric_name, metric_dict in metric_dict_group.items():
286292
for var_name, metric_value in metric_dict.items():
287-
# NOTE: concat all metric(scalars) into metric vector
293+
# NOTE: concat single metric(scalar) list into metric vector
288294
metric_value = paddle.concat(metric_value)[:num_samples]
289-
# NOTE: compute metric via averaging metric vector,
295+
# NOTE: compute metric via averaging metric over all samples,
290296
# this might be not general for certain evaluation case
291297
metric_value = float(metric_value.mean())
292298
metric_dict_group[metric_name][var_name] = metric_value

0 commit comments

Comments
 (0)