10
10
from annoy import AnnoyIndex
11
11
import itertools
12
12
from typing import Optional , Any , Union , List , Tuple , Callable
13
+ import warnings
13
14
14
15
15
16
class EvaluationMetrics ():
@@ -448,7 +449,7 @@ def embedding_concordance(embedding: "np.ndarray",
448
449
labels_embedding : "np.ndarray" ,
449
450
comparison_file : Union ["np.ndarray" , List ["np.ndarray" ]],
450
451
comparison_labels : Union ["np.ndarray" , List ["np.ndarray" ]],
451
- comparison_classes : Optional [Union [ str , List [str ] ]]= None ,
452
+ comparison_classes : Optional [List [str ]]= None ,
452
453
method : str = "emd"
453
454
) -> Union [float , str ]:
454
455
"""Concordance between two embeddings.
@@ -470,12 +471,16 @@ def embedding_concordance(embedding: "np.ndarray",
470
471
:param labels_embedding: Labels for all obervations in the embedding.
471
472
:param comparison_file: The second embedding.
472
473
:param comparison_labels: The labels for all observations in the comparison embedding.
473
- :param comparison_classes: Which classes in labels to compare. If ``None``, all overlapping labels used, optional
474
+ :param comparison_classes: Which classes in labels to compare. At least two classes need
475
+ to be provided for this to work; otherwise, `NA` will be returned. If ``None``, all
476
+ overlapping labels used, optional
474
477
:param method: "emd" or "cluster_distance", defaults to "emd"
475
478
476
479
:return: The score or "NA"
477
480
478
481
.. Note:: When there is no overlapping labels, "NA" is automatically returned as ``str``.
482
+ .. Deprecation Notice:: Passing in `str` for the `comparison_classes` parameter is deprecated
483
+ and will be removed in futrue versions.
479
484
"""
480
485
481
486
if not isinstance (comparison_file , list ):
@@ -484,6 +489,7 @@ def embedding_concordance(embedding: "np.ndarray",
484
489
comparison_labels = [comparison_labels ]
485
490
if not isinstance (comparison_classes , list ) and comparison_classes is not None :
486
491
comparison_classes = [comparison_classes ]
492
+ warnings .warn ("Passing in a non-list parameter is deprecated. Use a list instead." , DeprecationWarning , stacklevel = 2 )
487
493
488
494
method = method .lower ()
489
495
0 commit comments