Skip to content

Commit 4a8bd49

Browse files
Updated to camel case. (ecmwf#445)
* Updated to camel case. * Fixed formatting.
1 parent ee5f709 commit 4a8bd49

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

src/weathergen/utils/train_logger.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import json
1212
import logging
1313
import math
14+
import re
1415
import time
1516
import traceback
1617
from dataclasses import dataclass
@@ -88,12 +89,10 @@ def add_train(self, samples, lr, loss_avg, stddev_avg, perf_gpu=0.0, perf_mem=0.
8889
log_vals += [lr]
8990

9091
for i_obs, st in enumerate(self.cf.streams):
91-
st_name = _clean_name(st["name"])
9292
for j, (lf_name, _) in enumerate(self.cf.loss_fcts):
93-
lf_name = _clean_name(lf_name)
94-
metrics[_key_loss(st_name, lf_name)] = loss_avg[j, i_obs]
93+
metrics[_key_loss(st["name"], lf_name)] = loss_avg[j, i_obs]
9594
if len(stddev_avg) > 0:
96-
metrics[_key_stddev(st_name)] = stddev_avg[i_obs]
95+
metrics[_key_stddev(st["name"])] = stddev_avg[i_obs]
9796
log_vals += [loss_avg[j, i_obs]]
9897
if len(stddev_avg) > 0:
9998
for i_obs, _rt in enumerate(self.cf.streams):
@@ -377,17 +376,33 @@ def clean_df(df, columns: list[str] | None):
377376
return df
378377

379378

380-
def _clean_name(n: str) -> str:
381-
"""Cleans the stream name to only retain alphanumeric characters"""
382-
return "".join([c for c in n if c.isalnum()])
379+
def clean_name(s, regex=r"[a-zA-Z0-9]+"):
380+
"""
381+
Convert a string to CamelCase (every word capitalized, including the first).
382+
Characters not matched by the given regular expression are dropped and act as word separators.
383+
384+
:param s: Input string
385+
:param regex: Regular expression defining valid characters in words
386+
(defaults to alpha-numeric)
387+
:return: CamelCase version of the input string (e.g. "era5 t2m" -> "Era5T2m")
388+
"""
389+
re_pattern = re.compile(regex)
390+
words = re_pattern.findall(s)
391+
392+
if not words:
393+
return ""
394+
395+
camel_cased = "".join(word.capitalize() for word in words)
396+
397+
return camel_cased
383398

384399

385400
def _key_loss(st_name: str, lf_name: str) -> str:
386-
st_name = _clean_name(st_name)
387-
lf_name = _clean_name(lf_name)
401+
st_name = clean_name(st_name)
402+
lf_name = clean_name(lf_name)
388403
return f"stream.{st_name}.loss_{lf_name}.loss_avg"
389404

390405

391406
def _key_stddev(st_name: str) -> str:
392-
st_name = _clean_name(st_name)
407+
st_name = clean_name(st_name)
393408
return f"stream.{st_name}.stddev_avg"

0 commit comments

Comments
 (0)