Skip to content

Commit 1de1169

Browse files
committed
MAINT: signal: bilinear_zpk array API
1 parent 6f7ee75 commit 1de1169

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

scipy/signal/_filter_design.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,8 +2864,11 @@ def bilinear_zpk(z, p, k, fs):
28642864
>>> plt.ylabel('Amplitude [dB]')
28652865
>>> plt.grid(True)
28662866
"""
2867-
z = atleast_1d(z)
2868-
p = atleast_1d(p)
2867+
xp = array_namespace(z, p)
2868+
2869+
z, p = map(xp.asarray, (z, p))
2870+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
2871+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
28692872

28702873
fs = _validate_fs(fs, allow_none=False)
28712874

@@ -2878,10 +2881,10 @@ def bilinear_zpk(z, p, k, fs):
28782881
p_z = (fs2 + p) / (fs2 - p)
28792882

28802883
# Any zeros that were at infinity get moved to the Nyquist frequency
2881-
z_z = append(z_z, -ones(degree))
2884+
z_z = xp.concat((z_z, -xp.ones(degree)))
28822885

28832886
# Compensate for gain change
2884-
k_z = k * real(prod(fs2 - z) / prod(fs2 - p))
2887+
k_z = k * xp.real(xp.prod(fs2 - z) / xp.prod(fs2 - p))
28852888

28862889
return z_z, p_z, k_z
28872890

scipy/signal/tests/test_filter_design.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,19 +1614,25 @@ def test_basic(self, xp):
16141614

16151615
class TestBilinear_zpk:
16161616

1617-
def test_basic(self):
1618-
z = [-2j, +2j]
1619-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1617+
def test_basic(self, xp):
1618+
z = xp.asarray([-2j, +2j])
1619+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
16201620
k = 3
16211621

16221622
z_d, p_d, k_d = bilinear_zpk(z, p, k, 10)
16231623

1624-
xp_assert_close(sort(z_d), sort([(20-2j)/(20+2j), (20+2j)/(20-2j),
1625-
-1]))
1626-
xp_assert_close(sort(p_d), sort([77/83,
1627-
(1j/2 + 39/2) / (41/2 - 1j/2),
1628-
(39/2 - 1j/2) / (1j/2 + 41/2), ]))
1629-
xp_assert_close(k_d, 9696/69803)
1624+
xp_assert_close(
1625+
_sort_cmplx(z_d, xp=xp),
1626+
_sort_cmplx([(20-2j) / (20+2j), (20+2j) / (20-2j), -1], xp=xp)
1627+
)
1628+
xp_assert_close(
1629+
_sort_cmplx(p_d, xp=xp),
1630+
_sort_cmplx(
1631+
[77/83, (1j/2 + 39/2) / (41/2 - 1j/2), (39/2 - 1j/2) / (1j/2 + 41/2)],
1632+
xp=xp
1633+
)
1634+
)
1635+
assert math.isclose(k_d, 9696/69803)
16301636

16311637

16321638
class TestPrototypeType:

0 commit comments

Comments
 (0)