Skip to content

Commit ac4dc01

Browse files
authored
feat: Added get_item, get_attr and shape (#5)
1 parent bcdea7e commit ac4dc01

File tree

30 files changed

+462
-11
lines changed

30 files changed

+462
-11
lines changed

nada_algebra/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""This is the __init__.py module"""
22

3+
from nada_algebra.array import NadaArray
34
from nada_algebra.funcs import *

nada_algebra/array.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,46 @@ class NadaArray:
3131

3232
inner: np.ndarray
3333

34+
def __getitem__(self, item):
35+
"""
36+
Get an item from the array.
37+
38+
Args:
39+
item: The item to retrieve.
40+
41+
Returns:
42+
NadaArray: A new NadaArray representing the retrieved item.
43+
"""
44+
if len(self.inner.shape) == 1:
45+
return self.inner[item]
46+
return NadaArray(self.inner[item])
47+
48+
def __setitem__(self, key, value):
49+
"""
50+
Set an item in the array.
51+
52+
Args:
53+
key: The key to set.
54+
value: The value to set.
55+
"""
56+
if isinstance(value, NadaArray):
57+
# print("NadaArray")
58+
self.inner[key] = value.inner
59+
else:
60+
self.inner[key] = value
61+
62+
def __getattr__(self, name: str):
63+
"""
64+
Get an attribute from the array.
65+
66+
Args:
67+
name (str): The attribute name.
68+
69+
Returns:
70+
NadaArray: A new NadaArray representing the retrieved attribute.
71+
"""
72+
return getattr(self.inner, name)
73+
3474
def __add__(
3575
self,
3676
other: Union[
@@ -268,7 +308,17 @@ def output_array(array: np.ndarray, party: Party, prefix: str) -> list:
268308
Returns:
269309
list: A list of Output objects.
270310
"""
271-
if isinstance(array, (SecretInteger, SecretUnsignedInteger)):
311+
if isinstance(
312+
array,
313+
(
314+
SecretInteger,
315+
SecretUnsignedInteger,
316+
PublicInteger,
317+
PublicUnsignedInteger,
318+
Integer,
319+
UnsignedInteger,
320+
),
321+
):
272322
return [Output(array, f"{prefix}_0", party)]
273323

274324
if len(array.shape) == 1:

nada_algebra/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
secret and public variable integers and generating named party objects and input dictionaries.
44
"""
55

6+
from typing import Union
67
from py_nillion_client import (
78
SecretInteger,
89
SecretUnsignedInteger,
@@ -29,12 +30,12 @@ def parties(num: int, prefix: str = "Party") -> list:
2930
def array(
3031
arr: np.ndarray,
3132
prefix: str,
32-
nada_type: (
33-
type(SecretInteger)
34-
| type(SecretUnsignedInteger)
35-
| type(PublicVariableInteger)
36-
| type(PublicVariableUnsignedInteger)
37-
) = SecretInteger,
33+
nada_type: Union[
34+
SecretInteger,
35+
SecretUnsignedInteger,
36+
PublicVariableInteger,
37+
PublicVariableUnsignedInteger,
38+
] = SecretInteger,
3839
) -> dict:
3940
"""
4041
Recursively generates a dictionary of Nillion input objects for each element

nada_algebra/funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ def output(arr: NadaArray, party: Party, prefix: str):
141141
Returns:
142142
list: A list of Output objects.
143143
"""
144-
return arr.output(party, prefix)
144+
return NadaArray.output_array(arr, party, prefix)

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/gauss_jordan/nada-project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "matrix_inverse"
2+
version = "0.1.0"
3+
authors = [""]
4+
5+
[[programs]]
6+
path = "src/main.py"
7+
prime_size = 64

tests/gauss_jordan/src/main.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from nada_dsl import *
2+
import numpy as np
3+
4+
import nada_algebra as na
5+
from nada_algebra.array import NadaArray
6+
7+
# from nada_crypto import random_lu_matrix, public_modular_inverse
8+
9+
LOG_SCALE = 16
10+
SCALE = 1 << LOG_SCALE
11+
PRIME_64 = 18446744072637906947
12+
PRIME_128 = 340282366920938463463374607429104828419
13+
PRIME_256 = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF98C00003
14+
PRIME = PRIME_64
15+
16+
17+
def public_modular_inverse(
18+
value: Integer | UnsignedInteger, modulo: int
19+
) -> PublicUnsignedInteger | UnsignedInteger:
20+
"""
21+
Calculates the modular inverse of a value with respect to a prime modulus.
22+
23+
Args:
24+
`value`: The value for which the modular inverse is to be calculated.
25+
`modulo`: The prime modulo with respect to which the modular inverse is to be calculated.
26+
27+
Returns:
28+
The modular inverse of the value with respect to the modulo.
29+
30+
Raises:
31+
Exception: If the input type is not a `PublicUnsignedInteger` or `UnsignedInteger`.
32+
"""
33+
return value ** UnsignedInteger(modulo - 2)
34+
35+
36+
def gauss_jordan_zn(mat: na.NadaArray, modulo: int):
37+
"""
38+
Perform Gauss-Jordan elimination on Z_n on a given matrix.
39+
40+
Parameters:
41+
- `matrix` (numpy.ndarray): The input matrix to perform Gauss-Jordan elimination on.
42+
- `modulo` (int): The modulo representing the field `Z_n`
43+
44+
Returns:
45+
numpy.ndarray: The reduced row echelon form of the input matrix.
46+
"""
47+
48+
# Make a copy of the matrix to avoid modifying the original
49+
rows = mat.inner.shape[0]
50+
cols = mat.inner.shape[1]
51+
52+
# Forward elimination
53+
for i in range(rows):
54+
# Find pivot row
55+
pivot_row = i
56+
while pivot_row < rows and (mat[pivot_row][i] == UnsignedInteger(0)) is Boolean(
57+
True
58+
):
59+
pivot_row += 1
60+
61+
# Swap pivot row with current row
62+
mat[[i, pivot_row]] = mat[[pivot_row, i]]
63+
64+
# Scale pivot row to have leading 1
65+
diagonal_element = mat[i][i]
66+
pivot_inv = public_modular_inverse(diagonal_element, modulo)
67+
68+
mat[i] = mat[i] * pivot_inv
69+
70+
# Perform row operations to eliminate entries below pivot
71+
for j in range(i + 1, rows):
72+
factor = mat[j][i]
73+
mat[j] -= mat[i] * factor
74+
75+
# Backward elimination
76+
for i in range(rows - 1, -1, -1):
77+
for j in range(i - 1, -1, -1):
78+
factor = mat[j][i]
79+
mat[j] -= mat[i] * factor
80+
81+
return mat
82+
83+
84+
def nada_main():
85+
parties = na.parties(3)
86+
87+
A = na.array([3, 3], parties[0], "A", nada_type=SecretUnsignedInteger)
88+
89+
A = A.reveal()
90+
A_inv = gauss_jordan_zn(A, PRIME)
91+
outputs = na.output(A_inv, parties[2], "my_output")
92+
93+
return outputs

tests/gauss_jordan/target/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This directory is kept purposely, so that no compilation errors arise.
2+
# Ignore everything in this directory
3+
*
4+
# Except this file
5+
!.gitignore
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
---
2+
program: main
3+
inputs:
4+
secrets:
5+
A_0_0:
6+
SecretUnsignedInteger: "2"
7+
A_0_1:
8+
SecretUnsignedInteger: "4"
9+
A_0_2:
10+
SecretUnsignedInteger: "6"
11+
A_1_0:
12+
SecretUnsignedInteger: "1"
13+
A_1_1:
14+
SecretUnsignedInteger: "3"
15+
A_1_2:
16+
SecretUnsignedInteger: "5"
17+
A_2_0:
18+
SecretUnsignedInteger: "3"
19+
A_2_1:
20+
SecretUnsignedInteger: "1"
21+
A_2_2:
22+
SecretUnsignedInteger: "2"
23+
public_variables:
24+
B:
25+
UnsignedInteger: "456"
26+
expected_outputs:
27+
my_output_0_0:
28+
UnsignedInteger: "1"
29+
my_output_0_1:
30+
UnsignedInteger: "0"
31+
my_output_0_2:
32+
UnsignedInteger: "0"
33+
my_output_1_0:
34+
UnsignedInteger: "0"
35+
my_output_1_1:
36+
UnsignedInteger: "1"
37+
my_output_1_2:
38+
UnsignedInteger: "0"
39+
my_output_2_0:
40+
UnsignedInteger: "0"
41+
my_output_2_1:
42+
UnsignedInteger: "0"
43+
my_output_2_2:
44+
UnsignedInteger: "1"

tests/get_attr/nada-project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "get_item"
2+
version = "0.1.0"
3+
authors = [""]
4+
5+
[[programs]]
6+
path = "src/main.py"
7+
prime_size = 128

0 commit comments

Comments
 (0)