Skip to content

Commit 10b9ffe

Browse files
Various small bug fixes (#12)
Various small bug fixes
1 parent cf823f1 commit 10b9ffe

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

nada_algebra/array.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,24 @@ def __mul__(self, other: _NadaOperand) -> "NadaArray":
163163
return NadaArray(self.inner * Integer(other))
164164
return NadaArray(self.inner * other)
165165

166+
def __pow__(self, other: int) -> "NadaArray":
167+
"""Raises NadaArray to a power.
168+
169+
Args:
170+
other (int): Power value.
171+
172+
Returns:
173+
NadaArray: Result NadaArray.
174+
"""
175+
if not isinstance(other, int):
176+
raise TypeError(
177+
"Cannot raise `NadaArray` to power of type `%s`" % type(other).__name__
178+
)
179+
result = self.copy()
180+
for _ in range(other - 1):
181+
result = result * result
182+
return result
183+
166184
def __truediv__(self, other: _NadaOperand) -> "NadaArray":
167185
"""
168186
Perform element-wise division with broadcasting.
@@ -189,8 +207,11 @@ def __matmul__(self, other: "NadaArray") -> "NadaArray":
189207
Returns:
190208
NadaArray: A new NadaArray representing the result of matrix multiplication.
191209
"""
192-
if isinstance(other, NadaArray):
193-
return NadaArray(self.inner @ other.inner)
210+
return NadaArray(self.inner @ other.inner)
211+
212+
@property
213+
def ndim(self) -> int:
214+
return len(self.shape)
194215

195216
def dot(self, other: "NadaArray") -> "NadaArray":
196217
"""
@@ -295,10 +316,17 @@ def output_array(array: np.ndarray, party: Party, prefix: str) -> list:
295316
),
296317
):
297318
return [Output(array, f"{prefix}_0", party)]
319+
elif isinstance(array, (Rational, SecretRational)):
320+
return [Output(array.value, f"{prefix}_0", party)]
298321

299322
if len(array.shape) == 1:
300323
return [
301-
Output(array[i], f"{prefix}_{i}", party) for i in range(array.shape[0])
324+
(
325+
Output(array[i].value, f"{prefix}_{i}", party)
326+
if isinstance(array[i], (Rational, SecretRational))
327+
else Output(array[i], f"{prefix}_{i}", party)
328+
)
329+
for i in range(array.shape[0])
302330
]
303331
return [
304332
v

nada_algebra/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def from_number(
9090
Returns:
9191
Rational: Instantiated wrapper around number.
9292
"""
93-
if (value is None) or (value == np.nan):
93+
if value is None:
9494
raise ValueError("Cannot convert `%s` to Rational." % value)
9595

9696
value = value.item() if isinstance(value, np.floating) else value
@@ -308,13 +308,13 @@ def rescale(value: NadaType, scale: UnsignedInteger, direction: str) -> NadaType
308308
try:
309309
return value << scale
310310
except:
311-
return value * (1 << scale)
311+
return value * Integer(1 << scale)
312312
elif direction == "down":
313313
# TODO: remove try block when rshift implemented for every NadaType
314314
try:
315315
return value >> scale
316316
except:
317-
return value / (1 << scale)
317+
return value / Integer(1 << scale)
318318

319319
raise ValueError(
320320
'Invalid scaling direction `%s`. Expected "up" or "down"' % direction

0 commit comments

Comments
 (0)