Skip to content

Commit bb1015f

Browse files
authored
fix: revert Numpy version 2.0.0 incompatibility with Prophet (#55)
* fix: revert Numpy version 2.0.0 incompatibility with Prophet
1 parent fbacdd2 commit bb1015f

7 files changed

+224
-204
lines changed

nada_numpy/array.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# pylint:disable=too-many-lines
77

8-
from typing import Any, Callable, Optional, Sequence, Union, get_args
8+
from typing import Any, Callable, Optional, Sequence, Union, get_args, overload
99

1010
import numpy as np
1111
from nada_dsl import (Boolean, Input, Integer, Output, Party, PublicInteger,
@@ -870,6 +870,28 @@ def item(self, *args, **kwargs):
870870
return NadaArray(result)
871871
return result
872872

873+
@overload
874+
def itemset(self, value: Any): ...
875+
@overload
876+
def itemset(self, item: Any, value: Any): ...
877+
878+
# pylint:disable=missing-function-docstring
879+
@copy_metadata(np.ndarray.itemset)
880+
def itemset(self, *args, **kwargs):
881+
value = None
882+
if len(args) == 1:
883+
value = args[0]
884+
elif len(args) == 2:
885+
value = args[1]
886+
else:
887+
value = kwargs["value"]
888+
889+
_check_type_compatibility(value, self.dtype)
890+
result = self.inner.itemset(*args, **kwargs)
891+
if isinstance(result, np.ndarray):
892+
return NadaArray(result)
893+
return result
894+
873895
# pylint:disable=missing-function-docstring
874896
@copy_metadata(np.ndarray.prod)
875897
def prod(self, *args, **kwargs):

poetry.lock

Lines changed: 194 additions & 199 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
[tool.poetry]
22
name = "nada-numpy"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Nada-Numpy is a Python library designed for algebraic operations on NumPy-like array objects on top of Nada DSL and Nillion Network."
55
authors = ["José Cabrero-Holgueras <jose.cabrero@nillion.com>"]
66
readme = "README.md"
77

88
[tool.poetry.dependencies]
99
python = "^3.10"
10-
numpy = "^2.0.0"
10+
numpy = "^1.26.4"
1111
nada-dsl = "^0.4.0"
1212
py-nillion-client = "^0.4.0"
1313
nillion-python-helpers = "^0.2.3"

tests/nada-tests/src/supported_operations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def nada_main():
5656

5757
f = na.array([1], parties[0], "F", SecretInteger)
5858
f.fill(Integer(40))
59+
f.itemset(0, f.item(0) + Integer(2))
5960
with pytest.raises(Exception):
6061
f.itemset(0, na.rational(1))
6162
f = f.tolist()[0]

tests/nada-tests/src/supported_operations_return_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def nada_main():
8484
check_array(f)
8585
f.fill(Integer(40))
8686
check_array(f)
87+
f.itemset(0, f.item(0) + Integer(2))
88+
check_array(f)
8789
assert isinstance(f.tolist(), list)
8890
f = f.tolist()[0] # Not an array
8991

tests/nada-tests/tests/supported_operations.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,4 @@ expected_outputs:
8888
out_4:
8989
SecretInteger: 21
9090
out_5:
91-
Integer: 40
91+
Integer: 42

tests/nada-tests/tests/supported_operations_return_types.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,4 @@ expected_outputs:
8888
out_4:
8989
SecretInteger: 21
9090
out_5:
91-
Integer: 40
91+
Integer: 42

0 commit comments

Comments
 (0)