Skip to content

Commit e2403a3

Browse files
authored
feat: added rational-specific matrix multiplication (#25)
* feat: added rational specific matrix multiplication * feat: added support for multidimensional rational matrices for multiplications * fix: float_from_rational allow input log_scale * fix: various small fixes
1 parent 879ef1f commit e2403a3

8 files changed

+296
-0
lines changed

nada_algebra/array.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,75 @@ def matmul(self, other: "NadaArray") -> "NadaArray":
243243
NadaArray: A new NadaArray representing the result of matrix multiplication.
244244
"""
245245
if isinstance(other, NadaArray):
246+
if self.is_rational or other.is_rational:
247+
return self.rational_matmul(other)
246248
return NadaArray(np.array(self.inner @ other.inner))
247249
return NadaArray(np.array(self.inner @ other))
248250

251+
def rational_matmul(self, other: "NadaArray") -> "NadaArray":
252+
"""
253+
Perform matrix multiplication with another NadaArray when both have Rational Numbers.
254+
It improves the number of truncations to be needed to the resulting matrix dimensions mxp.
255+
256+
Args:
257+
other (NadaArray): The NadaArray to perform matrix multiplication with.
258+
259+
Returns:
260+
NadaArray: A new NadaArray representing the result of matrix multiplication.
261+
"""
262+
return NadaArray(NadaArray.rational_matmul_recursive(self, other))
263+
264+
@staticmethod
265+
def rational_matmul_recursive(A: "NadaArray", B: "NadaArray") -> "NadaArray":
266+
"""
267+
Perform matrix multiplication with another NadaArray when both have Rational Numbers.
268+
It improves the number of truncations to be needed to the resulting matrix dimensions mxp.
269+
270+
Args:
271+
other (NadaArray): The NadaArray to perform matrix multiplication with.
272+
273+
Returns:
274+
NadaArray: A new NadaArray representing the result of matrix multiplication.
275+
"""
276+
# We check that both have same number of dimensions.
277+
if A.ndim != B.ndim:
278+
raise ValueError(
279+
f"Matrices are not aligned for multiplication: {A.inner.shape} and {B.inner.shape}"
280+
)
281+
282+
# Since both have the same number of dimensions, we now check if they are 2D matrices.
283+
# If they are not, they will pass this check and execute normally.
284+
# Otherwise, we will do matrix contraction (i.e., compute matrix multiplication dimension by dimension).
285+
if A.ndim > 2:
286+
a = [
287+
NadaArray.rational_matmul_recursive(A[i], B[i])
288+
for i in range(A.shape[0])
289+
] # We remove one dimension here.
290+
return np.array(a)
291+
292+
# Get the dimensions of the matrices
293+
(m, n) = A.shape
294+
(n_, p) = B.shape
295+
296+
if n != n_:
297+
raise ValueError(
298+
f"Matrices are not aligned for multiplication: {A.inner.shape} and {B.inner.shape}"
299+
)
300+
301+
# Initialize the result matrix C with zeros
302+
C = np.zeros((m, p), dtype=object)
303+
304+
# Perform matrix multiplication
305+
for i in range(m):
306+
for j in range(p):
307+
for k in range(n):
308+
if k == 0:
309+
C[i][j] = A[i][k].mul_no_rescale(B[k][j])
310+
else:
311+
C[i][j] += A[i][k].mul_no_rescale(B[k][j])
312+
C[i][j] = C[i][j].rescale_down()
313+
return C
314+
249315
def __matmul__(self, other: Any) -> "NadaArray":
250316
"""
251317
Perform matrix multiplication with another NadaArray.

nada_algebra/client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,19 @@ def secret_rational(value: Union[float, int]) -> SecretInteger:
135135
int: The integer representation of the input value.
136136
"""
137137
return SecretInteger(__rational(value))
138+
139+
140+
def float_from_rational(value: int, log_scale: int = None) -> float:
141+
"""
142+
Returns the float representation of the given rational value.
143+
144+
Args:
145+
value (int): The output Rational value to convert.
146+
log_scale (int, optional): The log scale to use for conversion. Defaults to None.
147+
148+
Returns:
149+
float: The float representation of the input value.
150+
"""
151+
if log_scale is not None:
152+
log_scale = get_log_scale()
153+
return value / (1 << log_scale)

tests/nada-tests/nada-project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,11 @@ prime_size = 128
145145
[[programs]]
146146
path = "src/array_statistics.py"
147147
prime_size = 128
148+
149+
[[programs]]
150+
path = "src/matrix_multiplication_rational.py"
151+
prime_size = 128
152+
153+
[[programs]]
154+
path = "src/matrix_multiplication_rational_multidim.py"
155+
prime_size = 128
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from nada_dsl import *
2+
3+
# Step 0: Nada Algebra is imported with this line
4+
import nada_algebra as na
5+
6+
7+
def nada_main():
8+
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
9+
parties = na.parties(3)
10+
11+
# Step 2: Party0 creates an array of dimension (3 x 3) with name "A"
12+
a = na.array([3, 3], parties[0], "A", na.SecretRational)
13+
14+
# Step 3: Party1 creates an array of dimension (3 x 3) with name "B"
15+
b = na.array([3, 3], parties[1], "B", na.SecretRational)
16+
17+
# Step 4: The result is of computing the dot product between the two which is another (3 x 3) matrix
18+
result = a @ b
19+
20+
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
21+
return result.output(parties[1], "my_output")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from nada_dsl import *
2+
3+
# Step 0: Nada Algebra is imported with this line
4+
import nada_algebra as na
5+
6+
7+
def nada_main():
8+
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
9+
parties = na.parties(3)
10+
11+
# Step 2: Party0 creates an array of dimension (2 x 1 x 1 x 2 x 2 x 2) with name "A"
12+
a = na.array([2, 1, 1, 2, 2, 2], parties[0], "A", na.SecretRational)
13+
14+
# Step 3: Party1 creates an array of dimension (2 x 1 x 1 x 2 x 2 x 2) with name "B"
15+
b = na.array([2, 1, 1, 2, 2, 2], parties[1], "B", na.SecretRational)
16+
17+
# Step 4: The result is of computing the dot product between the two which is another (2 x 1 x 1 x 2 x 2 x 2) matrix
18+
result = a @ b
19+
20+
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
21+
return result.output(parties[1], "my_output")
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
---
2+
program: matrix_multiplication_rational
3+
inputs:
4+
secrets:
5+
A_0_0:
6+
SecretInteger: "65536"
7+
A_0_1:
8+
SecretInteger: "131072"
9+
A_0_2:
10+
SecretInteger: "196608"
11+
A_1_0:
12+
SecretInteger: "262144"
13+
A_1_1:
14+
SecretInteger: "327680"
15+
A_1_2:
16+
SecretInteger: "393216"
17+
A_2_0:
18+
SecretInteger: "458752"
19+
A_2_1:
20+
SecretInteger: "524288"
21+
A_2_2:
22+
SecretInteger: "589824"
23+
B_0_0:
24+
SecretInteger: "65536"
25+
B_0_1:
26+
SecretInteger: "131072"
27+
B_0_2:
28+
SecretInteger: "196608"
29+
B_1_0:
30+
SecretInteger: "262144"
31+
B_1_1:
32+
SecretInteger: "327680"
33+
B_1_2:
34+
SecretInteger: "393216"
35+
B_2_0:
36+
SecretInteger: "458752"
37+
B_2_1:
38+
SecretInteger: "524288"
39+
B_2_2:
40+
SecretInteger: "589824"
41+
public_variables: {}
42+
expected_outputs:
43+
my_output_0_0:
44+
SecretInteger: "1966080"
45+
my_output_0_1:
46+
SecretInteger: "2359296"
47+
my_output_0_2:
48+
SecretInteger: "2752512"
49+
my_output_1_0:
50+
SecretInteger: "4325376"
51+
my_output_1_1:
52+
SecretInteger: "5308416"
53+
my_output_1_2:
54+
SecretInteger: "6291456"
55+
my_output_2_0:
56+
SecretInteger: "6684672"
57+
my_output_2_1:
58+
SecretInteger: "8257536"
59+
my_output_2_2:
60+
SecretInteger: "9830400"
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
---
2+
program: matrix_multiplication_rational_multidim
3+
inputs:
4+
secrets:
5+
A_0_0_0_0_0_0:
6+
SecretInteger: "65536"
7+
A_0_0_0_0_0_1:
8+
SecretInteger: "131072"
9+
A_0_0_0_0_1_0:
10+
SecretInteger: "196608"
11+
A_0_0_0_0_1_1:
12+
SecretInteger: "262144"
13+
A_0_0_0_1_0_0:
14+
SecretInteger: "65536"
15+
A_0_0_0_1_0_1:
16+
SecretInteger: "131072"
17+
A_0_0_0_1_1_0:
18+
SecretInteger: "196608"
19+
A_0_0_0_1_1_1:
20+
SecretInteger: "262144"
21+
A_1_0_0_0_0_0:
22+
SecretInteger: "65536"
23+
A_1_0_0_0_0_1:
24+
SecretInteger: "131072"
25+
A_1_0_0_0_1_0:
26+
SecretInteger: "196608"
27+
A_1_0_0_0_1_1:
28+
SecretInteger: "262144"
29+
A_1_0_0_1_0_0:
30+
SecretInteger: "65536"
31+
A_1_0_0_1_0_1:
32+
SecretInteger: "131072"
33+
A_1_0_0_1_1_0:
34+
SecretInteger: "196608"
35+
A_1_0_0_1_1_1:
36+
SecretInteger: "262144"
37+
B_0_0_0_0_0_0:
38+
SecretInteger: "65536"
39+
B_0_0_0_0_0_1:
40+
SecretInteger: "131072"
41+
B_0_0_0_0_1_0:
42+
SecretInteger: "196608"
43+
B_0_0_0_0_1_1:
44+
SecretInteger: "262144"
45+
B_0_0_0_1_0_0:
46+
SecretInteger: "65536"
47+
B_0_0_0_1_0_1:
48+
SecretInteger: "131072"
49+
B_0_0_0_1_1_0:
50+
SecretInteger: "196608"
51+
B_0_0_0_1_1_1:
52+
SecretInteger: "262144"
53+
B_1_0_0_0_0_0:
54+
SecretInteger: "65536"
55+
B_1_0_0_0_0_1:
56+
SecretInteger: "131072"
57+
B_1_0_0_0_1_0:
58+
SecretInteger: "196608"
59+
B_1_0_0_0_1_1:
60+
SecretInteger: "262144"
61+
B_1_0_0_1_0_0:
62+
SecretInteger: "65536"
63+
B_1_0_0_1_0_1:
64+
SecretInteger: "131072"
65+
B_1_0_0_1_1_0:
66+
SecretInteger: "196608"
67+
B_1_0_0_1_1_1:
68+
SecretInteger: "262144"
69+
public_variables: {}
70+
expected_outputs:
71+
my_output_0_0_0_0_0_0:
72+
SecretInteger: "458752"
73+
my_output_0_0_0_0_0_1:
74+
SecretInteger: "655360"
75+
my_output_0_0_0_0_1_0:
76+
SecretInteger: "983040"
77+
my_output_0_0_0_0_1_1:
78+
SecretInteger: "1441792"
79+
my_output_0_0_0_1_0_0:
80+
SecretInteger: "458752"
81+
my_output_0_0_0_1_0_1:
82+
SecretInteger: "655360"
83+
my_output_0_0_0_1_1_0:
84+
SecretInteger: "983040"
85+
my_output_0_0_0_1_1_1:
86+
SecretInteger: "1441792"
87+
my_output_1_0_0_0_0_0:
88+
SecretInteger: "458752"
89+
my_output_1_0_0_0_0_1:
90+
SecretInteger: "655360"
91+
my_output_1_0_0_0_1_0:
92+
SecretInteger: "983040"
93+
my_output_1_0_0_0_1_1:
94+
SecretInteger: "1441792"
95+
my_output_1_0_0_1_0_0:
96+
SecretInteger: "458752"
97+
my_output_1_0_0_1_0_1:
98+
SecretInteger: "655360"
99+
my_output_1_0_0_1_1_0:
100+
SecretInteger: "983040"
101+
my_output_1_0_0_1_1_1:
102+
SecretInteger: "1441792"

tests/test_all.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
"array_attributes",
3838
"functional_operations",
3939
"array_statistics",
40+
"matrix_multiplication_rational",
41+
"matrix_multiplication_rational_multidim",
4042
# Not supported yet
4143
# "unsigned_matrix_inverse",
4244
# "private_inverse"

0 commit comments

Comments
 (0)