Skip to content

Commit 9919d93

Browse files
committed
feat: Added matrix multiplication integration (@)
1 parent f6b25a7 commit 9919d93

File tree

7 files changed

+157
-1
lines changed

7 files changed

+157
-1
lines changed

nada_algebra/array.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,20 @@ def __truediv__(
155155
return NadaArray(self.inner / Integer(other))
156156
return NadaArray(self.inner / other)
157157

158+
def __matmul__(self, other: "NadaArray") -> "NadaArray":
159+
"""
160+
Perform matrix multiplication with another NadaArray.
161+
162+
Args:
163+
other (NadaArray): The NadaArray to perform matrix multiplication with.
164+
165+
Returns:
166+
NadaArray: A new NadaArray representing the result of matrix multiplication.
167+
"""
168+
if isinstance(other, NadaArray):
169+
return NadaArray(self.inner @ other.inner)
170+
171+
158172
def dot(self, other: "NadaArray") -> "NadaArray":
159173
"""
160174
Compute the dot product between two NadaArray objects.

tests/matrix_multiplication/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Matrix Multiplication Tutorial
2+
3+
This tutorial shows how to efficiently program a matrix multiplication in Nada using Nada Algebra.
4+
5+
```python
6+
from nada_dsl import *
7+
# Step 0: Nada Algebra is imported with this line
8+
import nada_algebra as na
9+
10+
11+
def nada_main():
12+
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
13+
parties = na.parties(3)
14+
15+
# Step 2: Party0 creates an array of dimension (3 x 3) with name "A"
16+
a = na.array([3, 3], parties[0], "A")
17+
18+
# Step 3: Party1 creates an array of dimension (3 x 3) with name "B"
19+
b = na.array([3, 3], parties[1], "B")
20+
21+
# Step 4: The result is of computing the dot product between the two which is another (3 x 3) matrix
22+
result = a @ b
23+
24+
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
25+
return result.output(parties[1], "my_output")
26+
27+
```
28+
29+
0. We import Nada algebra using `import nada_algebra as na`.
30+
1. We create an array of parties, with our wrapper using `parties = na.parties(3)` which creates an array of parties named: `Party0`, `Party1` and `Party2`.
31+
2. We create our input array `a` with `na.array([3], parties[0], "A")`, meaning our array will have dimension 3, `Party0` will be in charge of giving its inputs and the name of the variable is `"A"`.
32+
3. We create our input array `b` with `na.array([3], parties[1], "B")`, meaning our array will have dimension 3, `Party1` will be in charge of giving its inputs and the name of the variable is `"B"`.
33+
4. Then, we use the `dot` function to compute the dot product like `a.dot(b)`, which will encompass all the functionality.
34+
5. Finally, we use Nada Algebra to produce the outputs of the array like: `result.output(parties[2], "my_output")` establishing that the output party will be `Party2`and the name of the output variable will be `my_output`.
35+
# How to run the tutorial.
36+
37+
1. First, we need to compile the nada program running: `nada build`.
38+
2. Then, we can test our program is running with: `nada test`.
39+
40+
Inspecting `tests/dot-product.yml`, we see how the inputs for the file are two vectors of 3s:
41+
42+
$ A = (3, 3, 3), B = (3, 3, 3)$
43+
44+
$A \times B = 3 \cdot 3 + 3 \cdot 3 + 3 \cdot 3 = 27$
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "matrix_multiplication"
2+
version = "0.1.0"
3+
authors = [""]
4+
5+
[[programs]]
6+
path = "src/main.py"
7+
prime_size = 128
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from nada_dsl import *
2+
3+
# Step 0: Nada Algebra is imported with this line
4+
import nada_algebra as na
5+
6+
7+
def nada_main():
8+
# Step 1: We use Nada Algebra wrapper to create "Party0", "Party1" and "Party2"
9+
parties = na.parties(3)
10+
11+
# Step 2: Party0 creates an array of dimension (3 x 3) with name "A"
12+
a = na.array([3, 3], parties[0], "A")
13+
14+
# Step 3: Party1 creates an array of dimension (3 x 3) with name "B"
15+
b = na.array([3, 3], parties[1], "B")
16+
17+
# Step 4: The result is of computing the dot product between the two which is another (3 x 3) matrix
18+
result = a @ b
19+
20+
# Step 5: We can use result.output() to produce the output for Party2 and variable name "my_output"
21+
return result.output(parties[1], "my_output")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This directory is kept purposely, so that no compilation errors arise.
2+
# Ignore everything in this directory
3+
*
4+
# Except this file
5+
!.gitignore
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
---
2+
program: main
3+
inputs:
4+
secrets:
5+
A_0_0:
6+
SecretInteger: "1"
7+
A_0_1:
8+
SecretInteger: "2"
9+
A_0_2:
10+
SecretInteger: "3"
11+
A_1_0:
12+
SecretInteger: "4"
13+
A_1_1:
14+
SecretInteger: "5"
15+
A_1_2:
16+
SecretInteger: "6"
17+
A_2_0:
18+
SecretInteger: "7"
19+
A_2_1:
20+
SecretInteger: "8"
21+
A_2_2:
22+
SecretInteger: "9"
23+
B_0_0:
24+
SecretInteger: "1"
25+
B_0_1:
26+
SecretInteger: "2"
27+
B_0_2:
28+
SecretInteger: "3"
29+
B_1_0:
30+
SecretInteger: "4"
31+
B_1_1:
32+
SecretInteger: "5"
33+
B_1_2:
34+
SecretInteger: "6"
35+
B_2_0:
36+
SecretInteger: "7"
37+
B_2_1:
38+
SecretInteger: "8"
39+
B_2_2:
40+
SecretInteger: "9"
41+
public_variables: {}
42+
expected_outputs:
43+
my_output_0_0:
44+
SecretInteger: "30"
45+
my_output_0_1:
46+
SecretInteger: "36"
47+
my_output_0_2:
48+
SecretInteger: "42"
49+
my_output_1_0:
50+
SecretInteger: "66"
51+
my_output_1_1:
52+
SecretInteger: "81"
53+
my_output_1_2:
54+
SecretInteger: "96"
55+
my_output_2_0:
56+
SecretInteger: "102"
57+
my_output_2_1:
58+
SecretInteger: "126"
59+
my_output_2_2:
60+
SecretInteger: "150"

tests/test_all.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414
"hstack",
1515
"vstack",
1616
"reveal",
17+
"matrix_multiplication",
1718
]
1819

19-
TESTS = ["tests/" + test for test in TESTS]
20+
EXAMPLES = [
21+
22+
]
23+
24+
TESTS = ["tests/" + test for test in TESTS] + ["examples/" + test for test in EXAMPLES]
2025

2126

2227
@pytest.fixture(params=TESTS)

0 commit comments

Comments
 (0)