Skip to content

Commit 0b860d8

Browse files
committed
Fix docstring and unit test.
1 parent 5fd01f4 commit 0b860d8

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

scipy/signal/_fir_filter_design.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def firwin(numtaps, cutoff, *, width=None, window='hamming', pass_zero=True,
320320
See Also
321321
--------
322322
firwin2
323+
firwin_2d
323324
firls
324325
minimum_phase
325326
remez
@@ -1283,7 +1284,7 @@ def minimum_phase(h: np.ndarray,
12831284
return h_minimum[:n_out]
12841285

12851286

1286-
def firwin_2d(hsize, window, *, fc=None, fs=2, circular=False,
1287+
def firwin_2d(hsize, window, *, fc=None, fs=2, circular=False,
12871288
pass_zero=True, scale=True):
12881289
"""
12891290
2D FIR filter design using the window method.
@@ -1315,25 +1316,25 @@ def firwin_2d(hsize, window, *, fc=None, fs=2, circular=False,
13151316
fs : float, optional
13161317
The sampling frequency of the signal. Default is 2.
13171318
circular : bool, optional
1318-
Whether to create a circularly symmetric 2-D window. Default is False.
1319-
pass_zero : This parameter is passed to the `firwin` function for each
1320-
scalar frequency axis.
1319+
Whether to create a circularly symmetric 2-D window. Default is ``False``.
1320+
pass_zero : {True, False, 'bandpass', 'lowpass', 'highpass', 'bandstop'}, optional
1321+
This parameter is directly passed to `firwin` for each scalar frequency axis.
13211322
Hence, if ``True``, the DC gain, i.e., the gain at frequency (0, 0), is 1.
13221323
If ``False``, the DC gain is 0 at frequency (0, 0) if `circular` is ``True``.
1323-
If `circular` is ``False`` the frequencies (0, f1) and (f0, 0) will have gain 0.
1324+
If `circular` is ``False`` the frequencies (0, f1) and (f0, 0) will
1325+
have gain 0.
13241326
It can also be a string argument for the desired filter type
13251327
(equivalent to ``btype`` in IIR design functions).
13261328
scale : bool, optional
1327-
This parameter is passed to the `firwin` function for
1328-
each scalar frequency axis.
1329+
This parameter is directly passed to `firwin` for each scalar frequency axis.
13291330
Set to ``True`` to scale the coefficients so that the frequency
13301331
response is exactly unity at a certain frequency on one frequency axis.
13311332
That frequency is either:
13321333
13331334
- 0 (DC) if the first passband starts at 0 (i.e. pass_zero is ``True``)
1334-
- `fs/2` (the Nyquist frequency) if the first passband ends at
1335-
`fs/2` (i.e the filter is a single band highpass filter);
1336-
center of first passband otherwise
1335+
- `fs`/2 (the Nyquist frequency) if the first passband ends at `fs`/2
1336+
(i.e., the filter is a single band highpass filter);
1337+
center of first passband otherwise
13371338
13381339
Returns
13391340
-------
@@ -1343,52 +1344,54 @@ def firwin_2d(hsize, window, *, fc=None, fs=2, circular=False,
13431344
Raises
13441345
------
13451346
ValueError
1346-
If `hsize` and `window` are not 2-element tuples or lists.
1347-
If `cutoff` is None when `circular` is True.
1348-
If `cutoff` is outside the range [0, fs / 2] and `circular` is False.
1349-
If any of the elements in `window` are not recognized.
1347+
- If `hsize` and `window` are not 2-element tuples or lists.
1348+
- If `cutoff` is None when `circular` is True.
1349+
- If `cutoff` is outside the range [0, `fs`/2] and `circular` is ``False``.
1350+
- If any of the elements in `window` are not recognized.
13501351
RuntimeError
13511352
If `firwin` fails to converge when designing the filter.
13521353
13531354
See Also
13541355
--------
1355-
scipy.signal.firwin, scipy.signal.get_window
1356+
firwin: FIR filter design using the window method for 1d arrays.
1357+
get_window: Return a window of a given length and type.
13561358
13571359
Examples
13581360
--------
1359-
Generate a 5x5 low-pass filter with cutoff frequency 0.1.
1361+
Generate a 5x5 low-pass filter with cutoff frequency 0.1:
13601362
13611363
>>> import numpy as np
13621364
>>> from scipy.signal import get_window
1363-
>>> from scipy.signal import fwind1
1365+
>>> from scipy.signal import firwin_2d
13641366
>>> hsize = (5, 5)
13651367
>>> window = (("kaiser", 5.0), ("kaiser", 5.0))
13661368
>>> fc = 0.1
1367-
>>> filter_2d = fwind1(hsize, window, fc=fc)
1369+
>>> filter_2d = firwin_2d(hsize, window, fc=fc)
13681370
>>> filter_2d
13691371
array([[0.00025366, 0.00401662, 0.00738617, 0.00401662, 0.00025366],
13701372
[0.00401662, 0.06360159, 0.11695714, 0.06360159, 0.00401662],
13711373
[0.00738617, 0.11695714, 0.21507283, 0.11695714, 0.00738617],
13721374
[0.00401662, 0.06360159, 0.11695714, 0.06360159, 0.00401662],
13731375
[0.00025366, 0.00401662, 0.00738617, 0.00401662, 0.00025366]])
13741376
1375-
Generate a circularly symmetric 5x5 low-pass filter with Hamming window.
1377+
Generate a circularly symmetric 5x5 low-pass filter with Hamming window:
13761378
1377-
>>> filter_2d = fwind1((5, 5), 'hamming', fc=fc, circular=True)
1379+
>>> filter_2d = firwin_2d((5, 5), 'hamming', fc=fc, circular=True)
13781380
>>> filter_2d
13791381
array([[-0.00020354, -0.00020354, -0.00020354, -0.00020354, -0.00020354],
13801382
[-0.00020354, 0.01506844, 0.09907658, 0.01506844, -0.00020354],
13811383
[-0.00020354, 0.09907658, -0.00020354, 0.09907658, -0.00020354],
13821384
[-0.00020354, 0.01506844, 0.09907658, 0.01506844, -0.00020354],
13831385
[-0.00020354, -0.00020354, -0.00020354, -0.00020354, -0.00020354]])
13841386
1385-
Plotting the generated 2D filters (optional).
1387+
Generate Plots comparing the product of two 1d filters with a circular
1388+
symmetric filter:
13861389
13871390
>>> import matplotlib.pyplot as plt
13881391
>>> hsize, fc = (50, 50), 0.05
13891392
>>> window = (("kaiser", 5.0), ("kaiser", 5.0))
1390-
>>> filter0_2d = fwind1(hsize, window, fc=fc)
1391-
>>> filter1_2d = fwind1((50, 50), 'hamming', fc=fc, circular=True)
1393+
>>> filter0_2d = firwin_2d(hsize, window, fc=fc)
1394+
>>> filter1_2d = firwin_2d((50, 50), 'hamming', fc=fc, circular=True)
13921395
...
13931396
>>> fg, (ax0, ax1) = plt.subplots(1, 2, tight_layout=True, figsize=(6.5, 3.5))
13941397
>>> ax0.set_title("Product of 2 Windows")

scipy/signal/tests/test_fir_filter_design.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytest import raises as assert_raises
88
import pytest
99

10-
from scipy.fft import fft, fft2, fftshift
10+
from scipy.fft import fft, fft2
1111
from scipy.special import sinc
1212
from scipy.signal import kaiser_beta, kaiser_atten, kaiserord, \
1313
firwin, firwin2, freqz, remez, firls, minimum_phase, \
@@ -694,18 +694,23 @@ def test_impulse_response(self):
694694
xp_assert_close(response[16:47, 16:47], expected_response, rtol=1e-5)
695695

696696
def test_frequency_response(self):
697+
"""Compare 1d and 2d frequency response. """
697698
hsize = (31, 31)
698-
window = ("hamming", "hamming")
699+
windows = ("hamming", "hamming")
699700
fc = 0.4
700-
taps = firwin_2d(hsize, window, fc=fc)
701+
taps_1d = firwin(numtaps=hsize[0], cutoff=fc, window=windows[0])
702+
taps_2d = firwin_2d(hsize, windows, fc=fc)
701703

702-
freq_response = fftshift(fft2(taps))
704+
f_resp_1d = fft(taps_1d)
705+
f_resp_2d = fft2(taps_2d)
703706

704-
magnitude = np.abs(freq_response)
705-
assert xp_assert_close(magnitude.max(), 1.0, atol=0.01), (
706-
f"Max magnitude is {magnitude.max()}"
707-
)
708-
assert magnitude.min() >= 0.0, f"Min magnitude is {magnitude.min()}"
707+
xp_assert_close(f_resp_2d[0, :], f_resp_1d,
708+
err_msg='DC Gain at (0, f1) is not unity!')
709+
xp_assert_close(f_resp_2d[:, 0], f_resp_1d,
710+
err_msg='DC Gain at (f0, 0) is not unity!')
711+
xp_assert_close(f_resp_2d, np.outer(f_resp_1d, f_resp_1d),
712+
atol=np.finfo(f_resp_2d.dtype).resolution,
713+
err_msg='2d frequency response is not product of 1d responses')
709714

710715
def test_symmetry(self):
711716
hsize = (51, 51)
@@ -745,7 +750,7 @@ def test_known_result(self):
745750
col_filter = firwin(hsize[1], cutoff=fc, window=window, fs=fs)
746751
known_result = np.outer(row_filter, col_filter)
747752

748-
taps = firwin_2d(hsize, (window, window), fc)
753+
taps = firwin_2d(hsize, (window, window), fc=fc)
749754
assert taps.shape == known_result.shape, (
750755
f"Shape mismatch: {taps.shape} vs {known_result.shape}"
751756
)

0 commit comments

Comments
 (0)