|
11 | 11 | import json |
12 | 12 | import logging |
13 | 13 | import math |
| 14 | +import re |
14 | 15 | import time |
15 | 16 | import traceback |
16 | 17 | from dataclasses import dataclass |
@@ -88,12 +89,10 @@ def add_train(self, samples, lr, loss_avg, stddev_avg, perf_gpu=0.0, perf_mem=0. |
88 | 89 | log_vals += [lr] |
89 | 90 |
|
90 | 91 | for i_obs, st in enumerate(self.cf.streams): |
91 | | - st_name = _clean_name(st["name"]) |
92 | 92 | for j, (lf_name, _) in enumerate(self.cf.loss_fcts): |
93 | | - lf_name = _clean_name(lf_name) |
94 | | - metrics[_key_loss(st_name, lf_name)] = loss_avg[j, i_obs] |
| 93 | + metrics[_key_loss(st["name"], lf_name)] = loss_avg[j, i_obs] |
95 | 94 | if len(stddev_avg) > 0: |
96 | | - metrics[_key_stddev(st_name)] = stddev_avg[i_obs] |
| 95 | + metrics[_key_stddev(st["name"])] = stddev_avg[i_obs] |
97 | 96 | log_vals += [loss_avg[j, i_obs]] |
98 | 97 | if len(stddev_avg) > 0: |
99 | 98 | for i_obs, _rt in enumerate(self.cf.streams): |
@@ -377,17 +376,33 @@ def clean_df(df, columns: list[str] | None): |
377 | 376 | return df |
378 | 377 |
|
379 | 378 |
|
380 | | -def _clean_name(n: str) -> str: |
381 | | - """Cleans the stream name to only retain alphanumeric characters""" |
382 | | - return "".join([c for c in n if c.isalnum()]) |
| 379 | +def clean_name(s, regex=r"[a-zA-Z0-9]+"): |
| 380 | + """ |
| 381 | + Convert a string to CamelCase (every word capitalized, then joined). |
| 382 | + Characters not matched by the given regular expression are omitted. |
| 383 | +
|
| 384 | + :param s: Input string |
| 385 | + :param regex: Regular expression defining valid characters in words |
| 386 | + (defaults to alpha-numeric) |
| 387 | + :return: CamelCase version of the input string |
| 388 | + """ |
| 389 | + re_pattern = re.compile(regex) |
| 390 | + words = re_pattern.findall(s) |
| 391 | + |
| 392 | + if not words: |
| 393 | + return "" |
| 394 | + |
| 395 | + camel_cased = "".join(word.capitalize() for word in words) |
| 396 | + |
| 397 | + return camel_cased |
383 | 398 |
|
384 | 399 |
|
385 | 400 | def _key_loss(st_name: str, lf_name: str) -> str: |
386 | | - st_name = _clean_name(st_name) |
387 | | - lf_name = _clean_name(lf_name) |
| 401 | + st_name = clean_name(st_name) |
| 402 | + lf_name = clean_name(lf_name) |
388 | 403 | return f"stream.{st_name}.loss_{lf_name}.loss_avg" |
389 | 404 |
|
390 | 405 |
|
391 | 406 | def _key_stddev(st_name: str) -> str: |
392 | | - st_name = _clean_name(st_name) |
| 407 | + st_name = clean_name(st_name) |
393 | 408 | return f"stream.{st_name}.stddev_avg" |
0 commit comments