Skip to content

Commit 3f09e7c

Browse files
Remove unnecessary and type-problematic _numop methods from nlinalg Ops
1 parent 2368ed3 commit 3f09e7c

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

aesara/tensor/nlinalg.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,6 @@ class Eig(Op):
237237
238238
"""
239239

240-
_numop = staticmethod(np.linalg.eig)
241240
__props__: Union[Tuple, Tuple[str]] = ()
242241

243242
def make_node(self, x):
@@ -250,7 +249,7 @@ def make_node(self, x):
250249
def perform(self, node, inputs, outputs):
251250
(x,) = inputs
252251
(w, v) = outputs
253-
w[0], v[0] = (z.astype(x.dtype) for z in self._numop(x))
252+
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
254253

255254
def infer_shape(self, fgraph, node, shapes):
256255
n = shapes[0][0]
@@ -266,7 +265,6 @@ class Eigh(Eig):
266265
267266
"""
268267

269-
_numop = staticmethod(np.linalg.eigh)
270268
__props__ = ("UPLO",)
271269

272270
def __init__(self, UPLO="L"):
@@ -281,15 +279,15 @@ def make_node(self, x):
281279
# LAPACK. Rather than trying to reproduce the (rather
282280
# involved) logic, we just probe linalg.eigh with a trivial
283281
# input.
284-
w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
282+
w_dtype = np.linalg.eigh([[np.dtype(x.dtype).type()]])[0].dtype.name
285283
w = vector(dtype=w_dtype)
286284
v = matrix(dtype=w_dtype)
287285
return Apply(self, [x], [w, v])
288286

289287
def perform(self, node, inputs, outputs):
290288
(x,) = inputs
291289
(w, v) = outputs
292-
w[0], v[0] = self._numop(x, self.UPLO)
290+
w[0], v[0] = np.linalg.eigh(x, self.UPLO)
293291

294292
def grad(self, inputs, g_outputs):
295293
r"""The gradient function should return
@@ -412,7 +410,6 @@ class QRFull(Op):
412410
413411
"""
414412

415-
_numop = staticmethod(np.linalg.qr)
416413
__props__ = ("mode",)
417414

418415
def __init__(self, mode):
@@ -444,7 +441,7 @@ def make_node(self, x):
444441
def perform(self, node, inputs, outputs):
445442
(x,) = inputs
446443
assert x.ndim == 2, "The input of qr function should be a matrix."
447-
res = self._numop(x, self.mode)
444+
res = np.linalg.qr(x, self.mode)
448445
if self.mode != "r":
449446
outputs[0][0], outputs[1][0] = res
450447
else:
@@ -513,7 +510,6 @@ class SVD(Op):
513510
"""
514511

515512
# See doc in the docstring of the function just after this class.
516-
_numop = staticmethod(np.linalg.svd)
517513
__props__ = ("full_matrices", "compute_uv")
518514

519515
def __init__(self, full_matrices=True, compute_uv=True):
@@ -541,10 +537,10 @@ def perform(self, node, inputs, outputs):
541537
assert x.ndim == 2, "The input of svd function should be a matrix."
542538
if self.compute_uv:
543539
u, s, vt = outputs
544-
u[0], s[0], vt[0] = self._numop(x, self.full_matrices, self.compute_uv)
540+
u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
545541
else:
546542
(s,) = outputs
547-
s[0] = self._numop(x, self.full_matrices, self.compute_uv)
543+
s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
548544

549545
def infer_shape(self, fgraph, node, shapes):
550546
(x_shape,) = shapes
@@ -696,7 +692,6 @@ class TensorInv(Op):
696692
Aesara utilization of numpy.linalg.tensorinv;
697693
"""
698694

699-
_numop = staticmethod(np.linalg.tensorinv)
700695
__props__ = ("ind",)
701696

702697
def __init__(self, ind=2):
@@ -710,7 +705,7 @@ def make_node(self, a):
710705
def perform(self, node, inputs, outputs):
711706
(a,) = inputs
712707
(x,) = outputs
713-
x[0] = self._numop(a, self.ind)
708+
x[0] = np.linalg.tensorinv(a, self.ind)
714709

715710
def infer_shape(self, fgraph, node, shapes):
716711
sp = shapes[0][self.ind :] + shapes[0][: self.ind]
@@ -756,7 +751,6 @@ class TensorSolve(Op):
756751
757752
"""
758753

759-
_numop = staticmethod(np.linalg.tensorsolve)
760754
__props__ = ("axes",)
761755

762756
def __init__(self, axes=None):
@@ -770,12 +764,9 @@ def make_node(self, a, b):
770764
return Apply(self, [a, b], [x])
771765

772766
def perform(self, node, inputs, outputs):
773-
(
774-
a,
775-
b,
776-
) = inputs
767+
(a, b) = inputs
777768
(x,) = outputs
778-
x[0] = self._numop(a, b, self.axes)
769+
x[0] = np.linalg.tensorsolve(a, b, self.axes)
779770

780771

781772
def tensorsolve(a, b, axes=None):

0 commit comments

Comments (0)