Skip to content

Commit 6f7ee75

Browse files
committed
MAINT: signal: convert lp2{lp,bp,hp}_zpk to be array api compatible
1 parent d2e491c commit 6f7ee75

File tree

2 files changed

+108
-64
lines changed

2 files changed

+108
-64
lines changed

scipy/signal/_filter_design.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,7 +2794,7 @@ def _relative_degree(z, p):
27942794
"""
27952795
Return relative degree of transfer function from zeros and poles
27962796
"""
2797-
degree = len(p) - len(z)
2797+
degree = p.shape[0] - z.shape[0]
27982798
if degree < 0:
27992799
raise ValueError("Improper transfer function. "
28002800
"Must have at least as many poles as zeros.")
@@ -2941,8 +2941,12 @@ def lp2lp_zpk(z, p, k, wo=1.0):
29412941
>>> lp2lp_zpk(z, p, k, wo)
29422942
( array([2.8, 0.8]), array([2. , 5.2]), 0.8)
29432943
"""
2944-
z = atleast_1d(z)
2945-
p = atleast_1d(p)
2944+
xp = array_namespace(z, p)
2945+
2946+
z, p = map(xp.asarray, (z, p))
2947+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
2948+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
2949+
29462950
wo = float(wo) # Avoid int wraparound
29472951

29482952
degree = _relative_degree(z, p)
@@ -3018,8 +3022,12 @@ def lp2hp_zpk(z, p, k, wo=1.0):
30183022
array([-0.6 , -0.15]),
30193023
8.5)
30203024
"""
3021-
z = atleast_1d(z)
3022-
p = atleast_1d(p)
3025+
xp = array_namespace(z, p)
3026+
3027+
z, p = map(xp.asarray, (z, p))
3028+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
3029+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
3030+
30233031
wo = float(wo)
30243032

30253033
degree = _relative_degree(z, p)
@@ -3030,10 +3038,10 @@ def lp2hp_zpk(z, p, k, wo=1.0):
30303038
p_hp = wo / p
30313039

30323040
# If lowpass had zeros at infinity, inverting moves them to origin.
3033-
z_hp = append(z_hp, zeros(degree))
3041+
z_hp = xp.concat((z_hp, xp.zeros(degree)))
30343042

30353043
# Cancel out gain change caused by inversion
3036-
k_hp = k * real(prod(-z) / prod(-p))
3044+
k_hp = k * xp.real(xp.prod(-z) / xp.prod(-p))
30373045

30383046
return z_hp, p_hp, k_hp
30393047

@@ -3104,8 +3112,12 @@ def lp2bp_zpk(z, p, k, wo=1.0, bw=1.0):
31043112
array([1.04996339e+02+0.j, -1.60167736e-03+0.j, 3.66108003e-03+0.j,
31053113
-2.39998398e+02+0.j]), 0.8)
31063114
"""
3107-
z = atleast_1d(z)
3108-
p = atleast_1d(p)
3115+
xp = array_namespace(z, p)
3116+
3117+
z, p = map(xp.asarray, (z, p))
3118+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
3119+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
3120+
31093121
wo = float(wo)
31103122
bw = float(bw)
31113123

@@ -3116,17 +3128,17 @@ def lp2bp_zpk(z, p, k, wo=1.0, bw=1.0):
31163128
p_lp = p * bw/2
31173129

31183130
# Square root needs to produce complex result, not NaN
3119-
z_lp = z_lp.astype(complex)
3120-
p_lp = p_lp.astype(complex)
3131+
z_lp = xp.astype(z_lp, xp.complex128)
3132+
p_lp = xp.astype(p_lp, xp.complex128)
31213133

31223134
# Duplicate poles and zeros and shift from baseband to +wo and -wo
3123-
z_bp = concatenate((z_lp + sqrt(z_lp**2 - wo**2),
3124-
z_lp - sqrt(z_lp**2 - wo**2)))
3125-
p_bp = concatenate((p_lp + sqrt(p_lp**2 - wo**2),
3126-
p_lp - sqrt(p_lp**2 - wo**2)))
3135+
z_bp = xp.concat((z_lp + xp.sqrt(z_lp**2 - wo**2),
3136+
z_lp - xp.sqrt(z_lp**2 - wo**2)))
3137+
p_bp = xp.concat((p_lp + xp.sqrt(p_lp**2 - wo**2),
3138+
p_lp - xp.sqrt(p_lp**2 - wo**2)))
31273139

31283140
# Move degree zeros to origin, leaving degree zeros at infinity for BPF
3129-
z_bp = append(z_bp, zeros(degree))
3141+
z_bp = xp.concat((z_bp, xp.zeros(degree)))
31303142

31313143
# Cancel out gain change from frequency scaling
31323144
k_bp = k * bw**degree
@@ -3199,8 +3211,12 @@ def lp2bs_zpk(z, p, k, wo=1.0, bw=1.0):
31993211
array([14.2681928 +0.j, -0.02506281+0.j, 0.01752149+0.j, -9.97493719+0.j]),
32003212
-12.857142857142858)
32013213
"""
3202-
z = atleast_1d(z)
3203-
p = atleast_1d(p)
3214+
xp = array_namespace(z, p)
3215+
3216+
z, p = map(xp.asarray, (z, p))
3217+
z = xpx.atleast_nd(z, ndim=1, xp=xp)
3218+
p = xpx.atleast_nd(p, ndim=1, xp=xp)
3219+
32043220
wo = float(wo)
32053221
bw = float(bw)
32063222

@@ -3211,21 +3227,21 @@ def lp2bs_zpk(z, p, k, wo=1.0, bw=1.0):
32113227
p_hp = (bw/2) / p
32123228

32133229
# Square root needs to produce complex result, not NaN
3214-
z_hp = z_hp.astype(complex)
3215-
p_hp = p_hp.astype(complex)
3230+
z_hp = xp.astype(z_hp, xp.complex128)
3231+
p_hp = xp.astype(p_hp, xp.complex128)
32163232

32173233
# Duplicate poles and zeros and shift from baseband to +wo and -wo
3218-
z_bs = concatenate((z_hp + sqrt(z_hp**2 - wo**2),
3219-
z_hp - sqrt(z_hp**2 - wo**2)))
3220-
p_bs = concatenate((p_hp + sqrt(p_hp**2 - wo**2),
3221-
p_hp - sqrt(p_hp**2 - wo**2)))
3234+
z_bs = xp.concat((z_hp + xp.sqrt(z_hp**2 - wo**2),
3235+
z_hp - xp.sqrt(z_hp**2 - wo**2)))
3236+
p_bs = xp.concat((p_hp + xp.sqrt(p_hp**2 - wo**2),
3237+
p_hp - xp.sqrt(p_hp**2 - wo**2)))
32223238

32233239
# Move any zeros that were at infinity to the center of the stopband
3224-
z_bs = append(z_bs, full(degree, +1j*wo))
3225-
z_bs = append(z_bs, full(degree, -1j*wo))
3240+
z_bs = xp.concat((z_bs, xp.full(degree, +1j*wo)))
3241+
z_bs = xp.concat((z_bs, xp.full(degree, -1j*wo)))
32263242

32273243
# Cancel out gain change caused by inversion
3228-
k_bs = k * real(prod(-z) / prod(-p))
3244+
k_bs = k * xp.real(xp.prod(-z) / xp.prod(-p))
32293245

32303246
return z_bs, p_bs, k_bs
32313247

scipy/signal/tests/test_filter_design.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import cmath
23
import warnings
34

45
from itertools import product
@@ -1491,25 +1492,42 @@ def test_fs_validation(self):
14911492
bilinear(b, a, fs=None)
14921493

14931494

1495+
def _sort_cmplx(arr, xp):
1496+
# xp.sort is undefined for complex dtypes. Here we only need some
1497+
# consistent way to sort a complex array, including equal magnitude elements.
1498+
arr = xp.asarray(arr)
1499+
if xp.isdtype(arr.dtype, 'complex floating'):
1500+
sorter = abs(arr) + xp.real(arr) + xp.imag(arr)**3
1501+
else:
1502+
sorter = arr
1503+
1504+
idxs = xp.argsort(sorter)
1505+
return arr[idxs]
1506+
1507+
14941508
class TestLp2lp_zpk:
14951509

1496-
def test_basic(self):
1497-
z = []
1498-
p = [(-1+1j)/np.sqrt(2), (-1-1j)/np.sqrt(2)]
1510+
def test_basic(self, xp):
1511+
z = xp.asarray([])
1512+
p = xp.asarray([(-1+1j) / math.sqrt(2), (-1-1j) / math.sqrt(2)])
14991513
k = 1
15001514
z_lp, p_lp, k_lp = lp2lp_zpk(z, p, k, 5)
1501-
xp_assert_equal(z_lp, [])
1502-
xp_assert_close(sort(p_lp), sort(p)*5)
1503-
xp_assert_close(k_lp, 25.)
1515+
xp_assert_equal(z_lp, xp.asarray([]))
1516+
xp_assert_close(_sort_cmplx(p_lp, xp=xp), _sort_cmplx(p, xp=xp) * 5)
1517+
assert k_lp == 25.
15041518

15051519
# Pseudo-Chebyshev with both poles and zeros
1506-
z = [-2j, +2j]
1507-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1520+
z = xp.asarray([-2j, +2j])
1521+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
15081522
k = 3
15091523
z_lp, p_lp, k_lp = lp2lp_zpk(z, p, k, 20)
1510-
xp_assert_close(sort(z_lp), sort([-40j, +40j]))
1511-
xp_assert_close(sort(p_lp), sort([-15, -10-10j, -10+10j]))
1512-
xp_assert_close(k_lp, 60.)
1524+
xp_assert_close(
1525+
_sort_cmplx(z_lp, xp=xp), _sort_cmplx([-40j, +40j], xp=xp)
1526+
)
1527+
xp_assert_close(
1528+
_sort_cmplx(p_lp, xp=xp), _sort_cmplx([-15, -10-10j, -10+10j], xp=xp)
1529+
)
1530+
assert k_lp == 60.
15131531

15141532
def test_fs_validation(self):
15151533
z = [-2j, +2j]
@@ -1525,23 +1543,27 @@ def test_fs_validation(self):
15251543

15261544
class TestLp2hp_zpk:
15271545

1528-
def test_basic(self):
1529-
z = []
1530-
p = [(-1+1j)/np.sqrt(2), (-1-1j)/np.sqrt(2)]
1546+
def test_basic(self, xp):
1547+
z = xp.asarray([])
1548+
p = xp.asarray([(-1+1j) / math.sqrt(2), (-1-1j) / math.sqrt(2)])
15311549
k = 1
15321550

15331551
z_hp, p_hp, k_hp = lp2hp_zpk(z, p, k, 5)
1534-
xp_assert_equal(z_hp, np.asarray([0.0, 0.0]))
1535-
xp_assert_close(sort(p_hp), sort(p)*5)
1536-
xp_assert_close(k_hp, 1.0)
1552+
xp_assert_equal(z_hp, xp.asarray([0.0, 0.0]))
1553+
xp_assert_close(_sort_cmplx(p_hp, xp=xp), _sort_cmplx(p, xp=xp) * 5)
1554+
assert math.isclose(k_hp, 1.0)
15371555

1538-
z = [-2j, +2j]
1539-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1556+
z = xp.asarray([-2j, +2j])
1557+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
15401558
k = 3
15411559
z_hp, p_hp, k_hp = lp2hp_zpk(z, p, k, 6)
1542-
xp_assert_close(sort(z_hp), sort([-3j, 0, +3j]))
1543-
xp_assert_close(sort(p_hp), sort([-8, -6-6j, -6+6j]))
1544-
xp_assert_close(k_hp, 32.0)
1560+
xp_assert_close(
1561+
_sort_cmplx(z_hp, xp=xp), _sort_cmplx([-3j, 0, +3j], xp=xp)
1562+
)
1563+
xp_assert_close(
1564+
_sort_cmplx(p_hp, xp=xp), _sort_cmplx([-8, -6-6j, -6+6j], xp=xp)
1565+
)
1566+
assert k_hp == 32.0
15451567

15461568

15471569
class TestLp2bp_zpk:
@@ -1563,25 +1585,31 @@ def test_basic(self):
15631585

15641586
class TestLp2bs_zpk:
15651587

1566-
def test_basic(self):
1567-
z = [-2j, +2j]
1568-
p = [-0.75, -0.5-0.5j, -0.5+0.5j]
1588+
def test_basic(self, xp):
1589+
z = xp.asarray([-2j, +2j])
1590+
p = xp.asarray([-0.75, -0.5-0.5j, -0.5+0.5j])
15691591
k = 3
15701592

15711593
z_bs, p_bs, k_bs = lp2bs_zpk(z, p, k, 35, 12)
15721594

1573-
xp_assert_close(sort(z_bs), sort([+35j, -35j,
1574-
+3j+sqrt(1234)*1j,
1575-
-3j+sqrt(1234)*1j,
1576-
+3j-sqrt(1234)*1j,
1577-
-3j-sqrt(1234)*1j]))
1578-
xp_assert_close(sort(p_bs), sort([+3j*sqrt(129) - 8,
1579-
-3j*sqrt(129) - 8,
1580-
(-6 + 6j) - sqrt(-1225 - 72j),
1581-
(-6 - 6j) - sqrt(-1225 + 72j),
1582-
(-6 + 6j) + sqrt(-1225 - 72j),
1583-
(-6 - 6j) + sqrt(-1225 + 72j), ]))
1584-
xp_assert_close(k_bs, 32.0)
1595+
xp_assert_close(
1596+
_sort_cmplx(z_bs, xp=xp),
1597+
_sort_cmplx([+35j, -35j,
1598+
+3j + math.sqrt(1234)*1j,
1599+
-3j + math.sqrt(1234)*1j,
1600+
+3j - math.sqrt(1234)*1j,
1601+
-3j - math.sqrt(1234)*1j], xp=xp)
1602+
)
1603+
xp_assert_close(
1604+
_sort_cmplx(p_bs, xp=xp),
1605+
_sort_cmplx([+3j*math.sqrt(129) - 8,
1606+
-3j*math.sqrt(129) - 8,
1607+
(-6 + 6j) - cmath.sqrt(-1225 - 72j),
1608+
(-6 - 6j) - cmath.sqrt(-1225 + 72j),
1609+
(-6 + 6j) + cmath.sqrt(-1225 - 72j),
1610+
(-6 - 6j) + cmath.sqrt(-1225 + 72j), ], xp=xp)
1611+
)
1612+
assert math.isclose(k_bs, 32.0)
15851613

15861614

15871615
class TestBilinear_zpk:

0 commit comments

Comments
 (0)