Skip to content

Commit 9791ee6

Browse files
authored
fix: rational dot product for 1-dim matrices (#27)
* fix: new way of doing matrix multiplication with rationals * chore: reformated and added code comments * fix: linting
1 parent 406c34c commit 9791ee6

11 files changed

+314
-70
lines changed

nada_algebra/array.py

Lines changed: 10 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Integer,
2020
UnsignedInteger,
2121
)
22+
2223
from nada_algebra.types import (
2324
Rational,
2425
SecretRational,
@@ -27,6 +28,9 @@
2728
secret_rational,
2829
get_log_scale,
2930
)
31+
32+
from nada_algebra.context import UnsafeArithmeticSession
33+
3034
from nada_algebra.utils import copy_metadata
3135

3236

@@ -270,59 +274,11 @@ def rational_matmul(self, other: "NadaArray") -> "NadaArray":
270274
Returns:
271275
NadaArray: A new NadaArray representing the result of matrix multiplication.
272276
"""
273-
return NadaArray(NadaArray.rational_matmul_recursive(self, other))
274-
275-
@staticmethod
276-
def rational_matmul_recursive(A: "NadaArray", B: "NadaArray") -> "NadaArray":
277-
"""
278-
Perform matrix multiplication with another NadaArray when both have Rational Numbers.
279-
It improves the number of truncations to be needed to the resulting matrix dimensions mxp.
280-
281-
Args:
282-
other (NadaArray): The NadaArray to perform matrix multiplication with.
283-
284-
Returns:
285-
NadaArray: A new NadaArray representing the result of matrix multiplication.
286-
"""
287-
# We check that both have same number of dimensions.
288-
if A.ndim != B.ndim:
289-
raise ValueError(
290-
f"Matrices are not aligned for multiplication: {A.inner.shape} and {B.inner.shape}"
277+
with UnsafeArithmeticSession():
278+
return NadaArray(np.array(self.inner @ other.inner)).apply(
279+
lambda x: x.rescale_down()
291280
)
292281

293-
# Since both have the same number of dimensions, we now check if they are 2D matrices.
294-
# If they are not, they will pass this check and execute normally.
295-
# Otherwise, we will do matrix contraction (i.e., compute matrix multiplication dimension by dimension).
296-
if A.ndim > 2:
297-
a = [
298-
NadaArray.rational_matmul_recursive(A[i], B[i])
299-
for i in range(A.shape[0])
300-
] # We remove one dimension here.
301-
return np.array(a)
302-
303-
# Get the dimensions of the matrices
304-
(m, n) = A.shape
305-
(n_, p) = B.shape
306-
307-
if n != n_:
308-
raise ValueError(
309-
f"Matrices are not aligned for multiplication: {A.inner.shape} and {B.inner.shape}"
310-
)
311-
312-
# Initialize the result matrix C with zeros
313-
C = np.zeros((m, p), dtype=object)
314-
315-
# Perform matrix multiplication
316-
for i in range(m):
317-
for j in range(p):
318-
for k in range(n):
319-
if k == 0:
320-
C[i][j] = A[i][k].mul_no_rescale(B[k][j])
321-
else:
322-
C[i][j] += A[i][k].mul_no_rescale(B[k][j])
323-
C[i][j] = C[i][j].rescale_down()
324-
return C
325-
326282
def __matmul__(self, other: Any) -> "NadaArray":
327283
"""
328284
Perform matrix multiplication with another NadaArray.
@@ -357,6 +313,9 @@ def dot(self, other: "NadaArray") -> "NadaArray":
357313
Returns:
358314
NadaArray: A new NadaArray representing the dot product result.
359315
"""
316+
if self.is_rational or other.is_rational:
317+
return self.rational_matmul(other)
318+
360319
return NadaArray(self.inner.dot(other.inner))
361320

362321
def hstack(self, other: "NadaArray") -> "NadaArray":

nada_algebra/context.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from nada_algebra.types import SecretRational, Rational, _NadaRational
2+
3+
4+
class UnsafeArithmeticSession:
5+
"""
6+
A context manager that temporarily modifies the behavior of arithmetic operations
7+
for Rational and SecretRational types, disabling rescaling for multiplication and division.
8+
9+
Attributes:
10+
mul_rational (function): Original __mul__ method of Rational.
11+
mul_secret_rational (function): Original __mul__ method of SecretRational.
12+
truediv_rational (function): Original __truediv__ method of Rational.
13+
truediv_secret_rational (function): Original __truediv__ method of SecretRational.
14+
"""
15+
16+
def __init__(self):
17+
"""
18+
Initializes the UnsafeArithmeticSession by storing the original
19+
multiplication and division methods of Rational and SecretRational.
20+
"""
21+
self.mul_rational = Rational.__mul__
22+
self.mul_secret_rational = SecretRational.__mul__
23+
24+
self.truediv_rational = Rational.__truediv__
25+
self.truediv_secret_rational = SecretRational.__truediv__
26+
27+
def __enter__(self):
28+
"""
29+
Enters the context, temporarily replacing the multiplication and division
30+
methods of Rational and SecretRational to disable rescaling.
31+
"""
32+
33+
def mul_no_rescale_wrapper(self: Rational, other: _NadaRational):
34+
"""
35+
Wrapper for Rational.__mul__ that disables rescaling.
36+
37+
Args:
38+
self (Rational): The Rational instance.
39+
other (_NadaRational): The other operand.
40+
41+
Returns:
42+
Rational: Result of the multiplication without rescaling.
43+
"""
44+
return Rational.mul_no_rescale(self, other, ignore_scale=True)
45+
46+
def secret_mul_no_rescale_wrapper(self: SecretRational, other: _NadaRational):
47+
"""
48+
Wrapper for SecretRational.__mul__ that disables rescaling.
49+
50+
Args:
51+
self (SecretRational): The SecretRational instance.
52+
other (_NadaRational): The other operand.
53+
54+
Returns:
55+
SecretRational: Result of the multiplication without rescaling.
56+
"""
57+
return SecretRational.mul_no_rescale(self, other, ignore_scale=True)
58+
59+
def divide_no_rescale_wrapper(self: Rational, other: _NadaRational):
60+
"""
61+
Wrapper for Rational.__truediv__ that disables rescaling.
62+
63+
Args:
64+
self (Rational): The Rational instance.
65+
other (_NadaRational): The other operand.
66+
67+
Returns:
68+
Rational: Result of the division without rescaling.
69+
"""
70+
return Rational.divide_no_rescale(self, other, ignore_scale=True)
71+
72+
def secret_divide_no_rescale_wrapper(
73+
self: SecretRational, other: _NadaRational
74+
):
75+
"""
76+
Wrapper for SecretRational.__truediv__ that disables rescaling.
77+
78+
Args:
79+
self (SecretRational): The SecretRational instance.
80+
other (_NadaRational): The other operand.
81+
82+
Returns:
83+
SecretRational: Result of the division without rescaling.
84+
"""
85+
return SecretRational.divide_no_rescale(self, other, ignore_scale=True)
86+
87+
Rational.__mul__ = mul_no_rescale_wrapper
88+
SecretRational.__mul__ = secret_mul_no_rescale_wrapper
89+
Rational.__truediv__ = divide_no_rescale_wrapper
90+
SecretRational.__truediv__ = secret_divide_no_rescale_wrapper
91+
92+
def __exit__(self, exc_type, exc_val, exc_tb):
93+
"""
94+
Exits the context, restoring the original multiplication and division methods
95+
of Rational and SecretRational.
96+
97+
Args:
98+
exc_type (type): Exception type if an exception occurred, else None.
99+
exc_val (Exception): Exception instance if an exception occurred, else None.
100+
exc_tb (traceback): Traceback object if an exception occurred, else None.
101+
"""
102+
# Restore the original __mul__ method
103+
Rational.__mul__ = self.mul_rational
104+
SecretRational.__mul__ = self.mul_secret_rational
105+
106+
# Restore the original __truediv__ method
107+
Rational.__truediv__ = self.truediv_rational
108+
SecretRational.__truediv__ = self.truediv_secret_rational

nada_algebra/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Additional special data types"""
22

33
import warnings
4+
from functools import partial
45
import numpy as np
56

67
import nada_dsl as dsl

tests/nada-tests/nada-project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,8 @@ prime_size = 128
152152

153153
[[programs]]
154154
path = "src/matrix_multiplication_rational_multidim.py"
155+
prime_size = 128
156+
157+
[[programs]]
158+
path = "src/dot_product_rational.py"
155159
prime_size = 128
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from nada_dsl import *
2+
import nada_algebra as na
3+
4+
5+
def nada_main():
6+
parties = na.parties(3)
7+
8+
a = na.array((3,), parties[0], "A", na.SecretRational)
9+
b = na.array((3,), parties[1], "B", na.SecretRational)
10+
c = na.ones((3,), na.Rational)
11+
12+
result = a.dot(b)
13+
14+
result_b = a @ b
15+
16+
result_c = a.dot(c)
17+
18+
result_d = a @ c
19+
20+
return (
21+
result.output(parties[1], "my_output_a")
22+
+ result_b.output(parties[1], "my_output_b")
23+
+ result_c.output(parties[1], "my_output_c")
24+
+ result_d.output(parties[1], "my_output_d")
25+
)
Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
from nada_dsl import *
2-
3-
# Step 0: Nada Algebra is imported with this line
42
import nada_algebra as na
53

64

75
def nada_main():
8-
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
96
parties = na.parties(3)
107

11-
# Step 2: Party0 creates an array of dimension (3 x 3) with name "A"
128
a = na.array([3, 3], parties[0], "A", na.SecretRational)
139

14-
# Step 3: Party1 creates an array of dimension (3 x 3) with name "B"
1510
b = na.array([3, 3], parties[1], "B", na.SecretRational)
1611

17-
# Step 4: The result is of computing the dot product between the two which is another (3 x 3) matrix
18-
result = a @ b
12+
c = na.array((3,), parties[2], "C", na.SecretRational)
13+
14+
d = na.ones([3, 3], na.Rational)
15+
16+
result_a = a @ b
17+
18+
result_b = a @ c
19+
20+
result_c = a @ d
21+
22+
result_d = d @ a
1923

20-
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
21-
return result.output(parties[1], "my_output")
24+
return (
25+
result_a.output(parties[1], "my_output")
26+
+ result_b.output(parties[1], "my_output_b")
27+
+ result_c.output(parties[1], "my_output_c")
28+
+ result_d.output(parties[1], "my_output_d")
29+
)
Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
from nada_dsl import *
2-
3-
# Step 0: Nada Algebra is imported with this line
42
import nada_algebra as na
53

64

75
def nada_main():
8-
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
96
parties = na.parties(3)
107

11-
# Step 2: Party0 creates an array of dimension (2 x 1 x 1 x 2 x 2 x 2) with name "A"
128
a = na.array([2, 1, 1, 2, 2, 2], parties[0], "A", na.SecretRational)
139

14-
# Step 3: Party1 creates an array of dimension (2 x 1 x 1 x 2 x 2 x 2) with name "B"
1510
b = na.array([2, 1, 1, 2, 2, 2], parties[1], "B", na.SecretRational)
1611

17-
# Step 4: The result is of computing the dot product between the two which is another (2 x 1 x 1 x 2 x 2 x 2) matrix
18-
result = a @ b
12+
c = na.array((2,), parties[2], "C", na.SecretRational)
13+
14+
d = na.ones([2, 1, 1, 2, 2, 2], na.Rational)
15+
16+
result_a = a @ b
17+
18+
result_b = a @ c
19+
20+
result_c = a @ d
21+
22+
result_d = d @ a
1923

20-
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
21-
return result.output(parties[1], "my_output")
24+
return (
25+
result_a.output(parties[1], "my_output")
26+
+ result_b.output(parties[1], "my_output_b")
27+
+ result_c.output(parties[1], "my_output_c")
28+
+ result_d.output(parties[1], "my_output_d")
29+
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
---
2+
program: dot_product_rational
3+
inputs:
4+
secrets:
5+
A_0:
6+
SecretInteger: "65536"
7+
A_1:
8+
SecretInteger: "131072"
9+
A_2:
10+
SecretInteger: "196608"
11+
B_0:
12+
SecretInteger: "65536"
13+
B_1:
14+
SecretInteger: "131072"
15+
B_2:
16+
SecretInteger: "196608"
17+
public_variables: {}
18+
expected_outputs:
19+
my_output_a_0:
20+
SecretInteger: "917504"
21+
my_output_b_0:
22+
SecretInteger: "917504"
23+
my_output_c_0:
24+
SecretInteger: "393216"
25+
my_output_d_0:
26+
SecretInteger: "393216"

tests/nada-tests/tests/matrix_multiplication_rational.yaml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ inputs:
3838
SecretInteger: "524288"
3939
B_2_2:
4040
SecretInteger: "589824"
41+
C_0:
42+
SecretInteger: "65536"
43+
C_1:
44+
SecretInteger: "0"
45+
C_2:
46+
SecretInteger: "0"
4147
public_variables: {}
4248
expected_outputs:
4349
my_output_0_0:
@@ -58,3 +64,33 @@ expected_outputs:
5864
SecretInteger: "8257536"
5965
my_output_2_2:
6066
SecretInteger: "9830400"
67+
my_output_b_0:
68+
SecretInteger: "65536"
69+
my_output_b_1:
70+
SecretInteger: "262144"
71+
my_output_b_2:
72+
SecretInteger: "458752"
73+
my_output_c_0_0:
74+
SecretInteger: "393216"
75+
my_output_c_0_1:
76+
SecretInteger: "393216"
77+
my_output_c_0_2:
78+
SecretInteger: "393216"
79+
my_output_c_1_0:
80+
SecretInteger: "983040"
81+
my_output_c_1_1:
82+
SecretInteger: "983040"
83+
my_output_c_1_2:
84+
SecretInteger: "983040"
85+
my_output_c_2_0:
86+
SecretInteger: "1572864"
87+
my_output_c_2_1:
88+
SecretInteger: "1572864"
89+
my_output_c_2_2:
90+
SecretInteger: "1572864"
91+
my_output_d_0_0:
92+
SecretInteger: "786432"
93+
my_output_d_1_0:
94+
SecretInteger: "786432"
95+
my_output_d_2_0:
96+
SecretInteger: "786432"

0 commit comments

Comments
 (0)