Skip to content

Commit 17a2ec0

Browse files
Feature/custom party names (#50)
* Add overwriteable party names
1 parent 07e03a8 commit 17a2ec0

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

nada_numpy/funcs.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
and manipulation of arrays and party objects.
44
"""
55

6-
from typing import Any, Callable, List, Sequence, Tuple, Union
6+
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
77

88
import numpy as np
99
from nada_dsl import (Boolean, Integer, Output, Party, PublicInteger,
@@ -60,18 +60,30 @@
6060
]
6161

6262

63-
def parties(num: int, prefix: str = "Party") -> List[Party]:
63+
def parties(num: int, party_names: Optional[List[str]] = None) -> List[Party]:
6464
"""
6565
Create a list of Party objects with specified names.
6666
6767
Args:
6868
num (int): The number of parties to create.
69-
prefix (str, optional): The prefix to use for party names. Defaults to "Party".
69+
party_names (List[str], optional): Party names to use. Defaults to None.
70+
71+
Raises:
72+
ValueError: Raised when incorrect number of party names is supplied.
7073
7174
Returns:
72-
List[Party]: A list of Party objects with names in the format "{prefix}{i}".
75+
List[Party]: A list of Party objects.
7376
"""
74-
return [Party(name=f"{prefix}{i}") for i in range(num)]
77+
if party_names is None:
78+
party_names = [f"Party{i}" for i in range(num)]
79+
80+
if len(party_names) != num:
81+
num_supplied_parties = len(party_names)
82+
raise ValueError(
83+
f"Incorrect number of party names. Expected {num}, received {num_supplied_parties}"
84+
)
85+
86+
return [Party(name=party_name) for party_name in party_names]
7587

7688

7789
def __from_numpy(arr: np.ndarray, nada_type: NadaCleartextNumber) -> List:

0 commit comments

Comments
 (0)