Skip to content

Commit d95eb5d

Browse files
authored
Feat: Add array comparison methods. (#60)
* feat: Improved tests for broadcasting * feat: added support for comparison operators over nada_arrays * chore: bump version for new release
1 parent dd112a0 commit d95eb5d

17 files changed

+651
-307
lines changed

nada_numpy/array.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
SecretUnsignedInteger, UnsignedInteger)
1414

1515
from nada_numpy.context import UnsafeArithmeticSession
16-
from nada_numpy.nada_typing import (NadaBoolean, NadaCleartextType,
17-
NadaInteger, NadaRational,
18-
NadaUnsignedInteger)
16+
from nada_numpy.nada_typing import (AnyNadaType, NadaBoolean,
17+
NadaCleartextType, NadaInteger,
18+
NadaRational, NadaUnsignedInteger)
1919
from nada_numpy.types import (Rational, SecretRational, fxp_abs, get_log_scale,
2020
public_rational, rational, secret_rational, sign)
2121
from nada_numpy.utils import copy_metadata
@@ -390,6 +390,119 @@ def __imatmul__(self, other: Any) -> "NadaArray":
390390
"""
391391
return self.matmul(other)
392392

393+
def __comparison_operator(
394+
self, value: Union["NadaArray", "AnyNadaType", np.ndarray], operator: Callable
395+
) -> "NadaArray":
396+
"""
397+
Perform element-wise comparison with broadcasting.
398+
399+
NOTE: Specially for __eq__ and __ne__ operators, the result expected is bool.
400+
If we don't define this method, the result will be a NadaArray with bool outputs.
401+
402+
Args:
403+
value (Any): The object to compare.
404+
operator (str): The comparison operator.
405+
406+
Returns:
407+
NadaArray: A new NadaArray representing the element-wise comparison result.
408+
"""
409+
if isinstance(value, NadaArray):
410+
value = value.inner
411+
if isinstance(
412+
value,
413+
(
414+
SecretInteger,
415+
Integer,
416+
SecretUnsignedInteger,
417+
UnsignedInteger,
418+
SecretRational,
419+
Rational,
420+
),
421+
):
422+
return self.apply(lambda x: operator(x, value))
423+
424+
if isinstance(value, np.ndarray):
425+
if len(self.inner) != len(value):
426+
raise ValueError("Arrays must have the same length")
427+
return NadaArray(
428+
np.array([operator(x, y) for x, y in zip(self.inner, value)])
429+
)
430+
431+
raise ValueError(f"Unsupported type: {type(value)}")
432+
433+
def __eq__(self, value: Any) -> "NadaArray": # type: ignore
434+
"""
435+
Perform equality comparison with broadcasting.
436+
437+
Args:
438+
value (object): The object to compare.
439+
440+
Returns:
441+
NadaArray: A boolean representing the element-wise equality comparison result.
442+
"""
443+
return self.__comparison_operator(value, lambda x, y: x == y)
444+
445+
def __ne__(self, value: Any) -> "NadaArray": # type: ignore
446+
"""
447+
Perform inequality comparison with broadcasting.
448+
449+
Args:
450+
value (object): The object to compare.
451+
452+
Returns:
453+
NadaArray: A boolean array representing the element-wise inequality comparison result.
454+
"""
455+
return self.__comparison_operator(value, lambda x, y: ~(x == y))
456+
457+
def __lt__(self, value: Any) -> "NadaArray":
458+
"""
459+
Perform less than comparison with broadcasting.
460+
461+
Args:
462+
value (object): The object to compare.
463+
464+
Returns:
465+
NadaArray: A boolean array representing the element-wise less than comparison result.
466+
"""
467+
return self.__comparison_operator(value, lambda x, y: x < y)
468+
469+
def __le__(self, value: Any) -> "NadaArray":
470+
"""
471+
Perform less than or equal comparison with broadcasting.
472+
473+
Args:
474+
value (object): The object to compare.
475+
476+
Returns:
477+
NadaArray: A boolean array representing
478+
the element-wise less or equal thancomparison result.
479+
"""
480+
return self.__comparison_operator(value, lambda x, y: x <= y)
481+
482+
def __gt__(self, value: Any) -> "NadaArray":
483+
"""
484+
Perform greater than comparison with broadcasting.
485+
486+
Args:
487+
value (object): The object to compare.
488+
489+
Returns:
490+
NadaArray: A boolean array representing the element-wise greater than comparison result.
491+
"""
492+
return self.__comparison_operator(value, lambda x, y: x > y)
493+
494+
def __ge__(self, value: Any) -> "NadaArray":
495+
"""
496+
Perform greater than or equal comparison with broadcasting.
497+
498+
Args:
499+
value (object): The object to compare.
500+
501+
Returns:
502+
NadaArray: A boolean representing the element-wise greater or equal than comparison.
503+
"""
504+
return self.__comparison_operator(value, lambda x, y: x >= y)
505+
393506
def dot(self, other: "NadaArray") -> "NadaArray":
394507
"""
395508
Compute the dot product between two NadaArray objects.

nada_numpy/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2929,7 +2929,6 @@ def _chebyshev_polynomials(x: _NadaRational, terms: int) -> np.ndarray:
29292929

29302930
# return polynomials
29312931

2932-
29332932
polynomials = [x]
29342933
y = rational(4) * x * x - rational(2)
29352934
z = y - rational(1)

0 commit comments

Comments
 (0)