|
13 | 13 | SecretUnsignedInteger, UnsignedInteger)
|
14 | 14 |
|
15 | 15 | from nada_numpy.context import UnsafeArithmeticSession
|
16 |
| -from nada_numpy.nada_typing import (NadaBoolean, NadaCleartextType, |
17 |
| - NadaInteger, NadaRational, |
18 |
| - NadaUnsignedInteger) |
| 16 | +from nada_numpy.nada_typing import (AnyNadaType, NadaBoolean, |
| 17 | + NadaCleartextType, NadaInteger, |
| 18 | + NadaRational, NadaUnsignedInteger) |
19 | 19 | from nada_numpy.types import (Rational, SecretRational, fxp_abs, get_log_scale,
|
20 | 20 | public_rational, rational, secret_rational, sign)
|
21 | 21 | from nada_numpy.utils import copy_metadata
|
@@ -390,6 +390,119 @@ def __imatmul__(self, other: Any) -> "NadaArray":
|
390 | 390 | """
|
391 | 391 | return self.matmul(other)
|
392 | 392 |
|
| 393 | + def __comparison_operator( |
| 394 | + self, value: Union["NadaArray", "AnyNadaType", np.ndarray], operator: Callable |
| 395 | + ) -> "NadaArray": |
| 396 | + """ |
| 397 | + Perform element-wise comparison with broadcasting. |
| 398 | +
|
| 399 | + NOTE: Specially for __eq__ and __ne__ operators, the result expected is bool. |
| 400 | + If we don't define this method, the result will be a NadaArray with bool outputs. |
| 401 | +
|
| 402 | + Args: |
| 403 | + value (Any): The object to compare. |
| 404 | + operator (str): The comparison operator. |
| 405 | +
|
| 406 | + Returns: |
| 407 | + NadaArray: A new NadaArray representing the element-wise comparison result. |
| 408 | + """ |
| 409 | + if isinstance(value, NadaArray): |
| 410 | + value = value.inner |
| 411 | + if isinstance( |
| 412 | + value, |
| 413 | + ( |
| 414 | + SecretInteger, |
| 415 | + Integer, |
| 416 | + SecretUnsignedInteger, |
| 417 | + UnsignedInteger, |
| 418 | + SecretRational, |
| 419 | + Rational, |
| 420 | + ), |
| 421 | + ): |
| 422 | + return self.apply(lambda x: operator(x, value)) |
| 423 | + |
| 424 | + if isinstance(value, np.ndarray): |
| 425 | + if len(self.inner) != len(value): |
| 426 | + raise ValueError("Arrays must have the same length") |
| 427 | + return NadaArray( |
| 428 | + np.array([operator(x, y) for x, y in zip(self.inner, value)]) |
| 429 | + ) |
| 430 | + |
| 431 | + raise ValueError(f"Unsupported type: {type(value)}") |
| 432 | + |
| 433 | + def __eq__(self, value: Any) -> "NadaArray": # type: ignore |
| 434 | + """ |
| 435 | + Perform equality comparison with broadcasting. |
| 436 | +
|
| 437 | + Args: |
| 438 | + value (object): The object to compare. |
| 439 | +
|
| 440 | + Returns: |
| 441 | + NadaArray: A boolean representing the element-wise equality comparison result. |
| 442 | + """ |
| 443 | + return self.__comparison_operator(value, lambda x, y: x == y) |
| 444 | + |
| 445 | + def __ne__(self, value: Any) -> "NadaArray": # type: ignore |
| 446 | + """ |
| 447 | + Perform inequality comparison with broadcasting. |
| 448 | +
|
| 449 | + Args: |
| 450 | + value (object): The object to compare. |
| 451 | +
|
| 452 | + Returns: |
| 453 | + NadaArray: A boolean array representing the element-wise inequality comparison result. |
| 454 | + """ |
| 455 | + return self.__comparison_operator(value, lambda x, y: ~(x == y)) |
| 456 | + |
| 457 | + def __lt__(self, value: Any) -> "NadaArray": |
| 458 | + """ |
| 459 | + Perform less than comparison with broadcasting. |
| 460 | +
|
| 461 | + Args: |
| 462 | + value (object): The object to compare. |
| 463 | +
|
| 464 | + Returns: |
| 465 | + NadaArray: A boolean array representing the element-wise less than comparison result. |
| 466 | + """ |
| 467 | + return self.__comparison_operator(value, lambda x, y: x < y) |
| 468 | + |
| 469 | + def __le__(self, value: Any) -> "NadaArray": |
| 470 | + """ |
| 471 | + Perform less than or equal comparison with broadcasting. |
| 472 | +
|
| 473 | + Args: |
| 474 | + value (object): The object to compare. |
| 475 | +
|
| 476 | + Returns: |
| 477 | + NadaArray: A boolean array representing |
| 478 | + the element-wise less or equal thancomparison result. |
| 479 | + """ |
| 480 | + return self.__comparison_operator(value, lambda x, y: x <= y) |
| 481 | + |
| 482 | + def __gt__(self, value: Any) -> "NadaArray": |
| 483 | + """ |
| 484 | + Perform greater than comparison with broadcasting. |
| 485 | +
|
| 486 | + Args: |
| 487 | + value (object): The object to compare. |
| 488 | +
|
| 489 | + Returns: |
| 490 | + NadaArray: A boolean array representing the element-wise greater than comparison result. |
| 491 | + """ |
| 492 | + return self.__comparison_operator(value, lambda x, y: x > y) |
| 493 | + |
| 494 | + def __ge__(self, value: Any) -> "NadaArray": |
| 495 | + """ |
| 496 | + Perform greater than or equal comparison with broadcasting. |
| 497 | +
|
| 498 | + Args: |
| 499 | + value (object): The object to compare. |
| 500 | +
|
| 501 | + Returns: |
| 502 | + NadaArray: A boolean representing the element-wise greater or equal than comparison. |
| 503 | + """ |
| 504 | + return self.__comparison_operator(value, lambda x, y: x >= y) |
| 505 | + |
393 | 506 | def dot(self, other: "NadaArray") -> "NadaArray":
|
394 | 507 | """
|
395 | 508 | Compute the dot product between two NadaArray objects.
|
|
0 commit comments