1
1
import math
2
2
import datetime
3
- from deepdiff . base import BaseProtocol
3
+ from typing import TYPE_CHECKING , Callable , Protocol , Any
4
4
from deepdiff .deephash import DeepHash
5
5
from deepdiff .helper import (
6
6
DELTA_VIEW , numbers , strings , add_to_frozen_set , not_found , only_numbers , np , np_float64 , time_to_seconds ,
7
7
cartesian_product_numpy , np_ndarray , np_array_factory , get_homogeneous_numpy_compatible_type_of_seq , dict_ ,
8
8
CannotCompare )
9
9
from collections .abc import Mapping , Iterable
10
10
11
+ if TYPE_CHECKING :
12
+ from deepdiff .diff import DeepDiffProtocol
11
13
12
- DISTANCE_CALCS_NEEDS_CACHE = "Distance calculation can not happen once the cache is purged. Try with _cache='keep'"
14
+ class DistanceProtocol (DeepDiffProtocol , Protocol ):
15
+ hashes : dict
16
+ deephash_parameters : dict
17
+ iterable_compare_func : Callable | None
18
+ math_epsilon : float
19
+ cutoff_distance_for_pairs : float
20
+
21
+ def __get_item_rough_length (self , item , parent :str = "root" ) -> float :
22
+ ...
13
23
24
+ def _to_delta_dict (
25
+ self ,
26
+ directed : bool = True ,
27
+ report_repetition_required : bool = True ,
28
+ always_include_values : bool = False ,
29
+ ) -> dict :
30
+ ...
14
31
32
+ def __calculate_item_deephash (self , item : Any ) -> None :
33
+ ...
15
34
16
35
17
- class DistanceMixin (BaseProtocol ):
18
36
19
- def _get_rough_distance (self ):
37
+ DISTANCE_CALCS_NEEDS_CACHE = "Distance calculation can not happen once the cache is purged. Try with _cache='keep'"
38
+
39
+
40
+ class DistanceMixin :
41
+
42
+ def _get_rough_distance (self : "DistanceProtocol" ):
20
43
"""
21
44
Gives a numeric value for the distance of t1 and t2 based on how many operations are needed to convert
22
45
one to the other.
@@ -51,7 +74,7 @@ def _get_rough_distance(self):
51
74
52
75
return diff_length / (t1_len + t2_len )
53
76
54
- def __get_item_rough_length (self , item , parent = 'root' ):
77
+ def __get_item_rough_length (self : "DistanceProtocol" , item , parent = 'root' ):
55
78
"""
56
79
Get the rough length of an item.
57
80
It is used as a part of calculating the rough distance between objects.
@@ -69,7 +92,7 @@ def __get_item_rough_length(self, item, parent='root'):
69
92
length = DeepHash .get_key (self .hashes , key = item , default = None , extract_index = 1 )
70
93
return length
71
94
72
- def __calculate_item_deephash (self , item ) :
95
+ def __calculate_item_deephash (self : "DistanceProtocol" , item : Any ) -> None :
73
96
DeepHash (
74
97
item ,
75
98
hashes = self .hashes ,
@@ -79,8 +102,7 @@ def __calculate_item_deephash(self, item):
79
102
)
80
103
81
104
def _precalculate_distance_by_custom_compare_func (
82
- self , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
83
-
105
+ self : "DistanceProtocol" , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
84
106
pre_calced_distances = dict_ ()
85
107
for added_hash in hashes_added :
86
108
for removed_hash in hashes_removed :
@@ -99,7 +121,7 @@ def _precalculate_distance_by_custom_compare_func(
99
121
return pre_calced_distances
100
122
101
123
def _precalculate_numpy_arrays_distance (
102
- self , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
124
+ self : "DistanceProtocol" , hashes_added , hashes_removed , t1_hashtable , t2_hashtable , _original_type ):
103
125
104
126
# We only want to deal with 1D arrays.
105
127
if isinstance (t2_hashtable [next (iter (hashes_added ))].item , (np_ndarray , list )):
@@ -203,7 +225,7 @@ def _get_numbers_distance(num1, num2, max_=1, use_log_scale=False, log_scale_sim
203
225
return 0
204
226
if use_log_scale :
205
227
distance = logarithmic_distance (num1 , num2 )
206
- if distance < logarithmic_distance :
228
+ if distance < 0 :
207
229
return 0
208
230
return distance
209
231
if not isinstance (num1 , float ):
@@ -246,7 +268,7 @@ def numpy_apply_log_keep_sign(array, offset=MATH_LOG_OFFSET):
246
268
return signed_log_values
247
269
248
270
249
- def logarithmic_similarity (a : numbers , b : numbers , threshold : float = 0.1 ):
271
+ def logarithmic_similarity (a : numbers , b : numbers , threshold : float = 0.1 ) -> float :
250
272
"""
251
273
A threshold of 0.1 translates to about 10.5% difference.
252
274
A threshold of 0.5 translates to about 65% difference.
@@ -255,7 +277,7 @@ def logarithmic_similarity(a: numbers, b: numbers, threshold: float=0.1):
255
277
return logarithmic_distance (a , b ) < threshold
256
278
257
279
258
- def logarithmic_distance (a : numbers , b : numbers ):
280
+ def logarithmic_distance (a : numbers , b : numbers ) -> float :
259
281
# Apply logarithm to the absolute values and consider the sign
260
282
a = float (a )
261
283
b = float (b )
0 commit comments