2424
2525from torchmetrics .metric import Metric
2626from torchmetrics .utilities import rank_zero_warn
27- from torchmetrics .utilities .data import _flatten_dict , allclose
27+ from torchmetrics .utilities .data import _flatten , _flatten_dict , allclose
2828from torchmetrics .utilities .imports import _MATPLOTLIB_AVAILABLE
2929from torchmetrics .utilities .plot import _AX_TYPE , _PLOT_OUT_TYPE , plot_single_or_multi_val
3030
@@ -90,7 +90,9 @@ class name as key for the output dict.
9090 due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
9191 states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
9292 reference and a copy of states are instead returned in this case (reference will be reestablished on the next
93- call to ``update``).
93+ call to ``update``). Do note that for the time being that if you are manually specifying compute groups in
94+ nested collections, these are not compatible with the compute groups of the parent collection and will be
95+ overridden.
9496
9597 .. important::
9698 Metric collections can be nested at initialization (see last example) but the output of the collection will
@@ -192,7 +194,6 @@ class name of the metric:
192194 """
193195
194196 _modules : dict [str , Metric ] # type: ignore[assignment]
195- _groups : Dict [int , List [str ]]
196197 __jit_unused_properties__ : ClassVar [list [str ]] = ["metric_state" ]
197198
198199 def __init__ (
@@ -210,7 +211,7 @@ def __init__(
210211 self ._enable_compute_groups = compute_groups
211212 self ._groups_checked : bool = False
212213 self ._state_is_copy : bool = False
213-
214+ self . _groups : Dict [ int , list [ str ]] = {}
214215 self .add_metrics (metrics , * additional_metrics )
215216
216217 @property
@@ -338,7 +339,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
338339 of just passed by reference
339340
340341 """
341- if not self ._state_is_copy :
342+ if not self ._state_is_copy and self . _groups_checked :
342343 for cg in self ._groups .values ():
343344 m0 = getattr (self , cg [0 ])
344345 for i in range (1 , len (cg )):
@@ -495,7 +496,6 @@ def add_metrics(
495496 "Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
496497 f" previous, but got { metrics } "
497498 )
498-
499499 self ._groups_checked = False
500500 if self ._enable_compute_groups :
501501 self ._init_compute_groups ()
@@ -518,9 +518,15 @@ def _init_compute_groups(self) -> None:
518518 f"Input { metric } in `compute_groups` argument does not match a metric in the collection."
519519 f" Please make sure that { self ._enable_compute_groups } matches { self .keys (keep_base = True )} "
520520 )
521+ # add metrics not specified in compute groups as their own group
522+ already_in_group = _flatten (self ._groups .values ()) # type: ignore
523+ counter = len (self ._groups )
524+ for k in self .keys (keep_base = True ):
525+ if k not in already_in_group :
526+ self ._groups [counter ] = [k ] # type: ignore
527+ counter += 1
521528 self ._groups_checked = True
522529 else :
523- # Initialize all metrics as their own compute group
524530 self ._groups = {i : [str (k )] for i , k in enumerate (self .keys (keep_base = True ))}
525531
526532 @property
0 commit comments