1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import warnings
1615
1716import pytest
1817import torch
2524 MulticlassRecall ,
2625)
2726from torchmetrics .regression import MeanAbsoluteError , MeanSquaredError
28- from torchmetrics .utilities .imports import _TORCHMETRICS_GREATER_EQUAL_1_6
2927from torchmetrics .wrappers import ClasswiseWrapper , MetricTracker , MultioutputWrapper
3028from unittests ._helpers import seed_all
3129
@@ -154,8 +152,8 @@ def test_tracker(base_metric, metric_input, maximize):
154152@pytest .mark .parametrize (
155153 "base_metric" ,
156154 [
157- MulticlassConfusionMatrix (3 ),
158- MetricCollection ([MulticlassConfusionMatrix (3 ), MulticlassAccuracy (3 )]),
155+ pytest . param ( MulticlassConfusionMatrix (3 ), id = "Multiclass-confusion-matrix" ),
156+ pytest . param ( MetricCollection ([MulticlassConfusionMatrix (3 ), MulticlassAccuracy (3 )]), id = "Metric-collection" ),
159157 ],
160158)
161159def test_best_metric_for_not_well_defined_metric_collection (base_metric ):
@@ -165,7 +163,7 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric):
165163 warning and return None.
166164
167165 """
168- tracker = MetricTracker (base_metric )
166+ tracker = MetricTracker (base_metric , maximize = True )
169167 for _ in range (3 ):
170168 tracker .increment ()
171169 for _ in range (5 ):
@@ -207,7 +205,7 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric):
207205)
208206def test_metric_tracker_and_collection_multioutput (input_to_tracker , assert_type ):
209207 """Check that MetricTracker support wrapper inputs and nested structures."""
210- tracker = MetricTracker (input_to_tracker )
208+ tracker = MetricTracker (input_to_tracker , maximize = False )
211209 for _ in range (5 ):
212210 tracker .increment ()
213211 for _ in range (5 ):
@@ -226,22 +224,6 @@ def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type
226224 assert which_epoch is None
227225
228226
229- def test_tracker_futurewarning ():
230- """Check that future warning is raised for the maximize argument.
231-
232- Also to make sure that we remove it in future versions of TM.
233-
234- """
235- if _TORCHMETRICS_GREATER_EQUAL_1_6 :
236- # Check that for future versions that we remove the warning
237- with warnings .catch_warnings ():
238- warnings .simplefilter ("error" )
239- MetricTracker (MeanSquaredError (), maximize = True )
240- else :
241- with pytest .warns (FutureWarning , match = "The default value for `maximize` will be changed from `True` to.*" ):
242- MetricTracker (MeanSquaredError (), maximize = True )
243-
244-
245227@pytest .mark .parametrize (
246228 "base_metric" ,
247229 [
0 commit comments