Skip to content

Commit 14c2982

Browse files
committed
[Deprecation] Deprecated str input.
The `comparison_classes` parameter of the `EvaluationMetrics.embedding_concordance` method will no longer accept `str` input. Typing has been changed, and a deprecation warning is raised.
1 parent 4fc2ce2 commit 14c2982

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

CytofDR/evaluation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from annoy import AnnoyIndex
1111
import itertools
1212
from typing import Optional, Any, Union, List, Tuple, Callable
13+
import warnings
1314

1415

1516
class EvaluationMetrics():
@@ -448,7 +449,7 @@ def embedding_concordance(embedding: "np.ndarray",
448449
labels_embedding: "np.ndarray",
449450
comparison_file: Union["np.ndarray", List["np.ndarray"]],
450451
comparison_labels: Union["np.ndarray", List["np.ndarray"]],
451-
comparison_classes: Optional[Union[str, List[str]]]=None,
452+
comparison_classes: Optional[List[str]]=None,
452453
method: str = "emd"
453454
) -> Union[float, str]:
454455
"""Concordance between two embeddings.
@@ -470,12 +471,16 @@ def embedding_concordance(embedding: "np.ndarray",
470471
:param labels_embedding: Labels for all obervations in the embedding.
471472
:param comparison_file: The second embedding.
472473
:param comparison_labels: The labels for all observations in the comparison embedding.
473-
:param comparison_classes: Which classes in labels to compare. If ``None``, all overlapping labels used, optional
474+
:param comparison_classes: Which classes in labels to compare. At least two classes need
475+
to be provided for this to work; otherwise, `NA` will be returned. If ``None``, all
476+
overlapping labels used, optional
474477
:param method: "emd" or "cluster_distance", defaults to "emd"
475478
476479
:return: The score or "NA"
477480
478481
.. Note:: When there is no overlapping labels, "NA" is automatically returned as ``str``.
482+
.. Deprecation Notice:: Passing in `str` for the `comparison_classes` parameter is deprecated
483+
and will be removed in futrue versions.
479484
"""
480485

481486
if not isinstance(comparison_file, list):
@@ -484,6 +489,7 @@ def embedding_concordance(embedding: "np.ndarray",
484489
comparison_labels = [comparison_labels]
485490
if not isinstance(comparison_classes, list) and comparison_classes is not None:
486491
comparison_classes = [comparison_classes]
492+
warnings.warn("Passing in a non-list parameter is deprecated. Use a list instead.", DeprecationWarning, stacklevel=2)
487493

488494
method = method.lower()
489495

tests/test_evaluation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ def test_embedding_concordance_comparison_labels(self):
167167
assert score >= 0
168168

169169

170+
def test_embedding_concordance_comparison_classes_type(self):
171+
with pytest.warns(DeprecationWarning):
172+
score: Union[float, str] = evaluation.EvaluationMetrics.embedding_concordance(self.embedding, self.embedding_labels,
173+
self.comparison_data, self.comparison_labels, "a",
174+
method="emd")
175+
assert isinstance(score, str)
176+
assert score == "NA"
177+
178+
170179
def test_embedding_concordance_na(self):
171180
score: Union[float, str] = evaluation.EvaluationMetrics.embedding_concordance(self.embedding, self.embedding_labels,
172181
self.comparison_data, self.comparison_labels, ["e"])

0 commit comments

Comments
 (0)