Skip to content

drf_sti speedup (#1) #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 154 additions & 103 deletions python/tools/drf_sti.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# The full license is in the LICENSE file, distributed with this software.
# ----------------------------------------------------------------------------
"""Create a spectral time intensity summary plot for a data set."""
"""Multithreaded variant by dsheen to handle much larger datasets smoothly 2025/06/04"""


import datetime
Expand All @@ -16,6 +17,7 @@
import sys
import time
import traceback
import multiprocessing

import dateutil
import digital_rf as drf
Expand Down Expand Up @@ -118,6 +120,137 @@ def __init__(self, opt):
ax = self.f.add_subplot(self.gridspec[n])
self.subplots.append(ax)

def generate_fft_stripe(self, start_sample):

samples_per_stripe = self.samples_per_stripe

if self.opt.verbose:
print(
"read vector : {0} {1} {2}".format(
self.channels, start_sample, samples_per_stripe
)
)

# if beamforming then load and beamform channels
if self.opt.beamform:
if self.opt.verbose:
print("beamform: {0}".format(self.opt.beamform))

if self.opt.beamform == "sum":
if len(self.opt.phases) != len(self.channels):
print(
"Number of phases must match number of channels for"
" beamforming."
)

data = np.zeros([samples_per_stripe], np.complex128)

for idx, c in enumerate(self.channels):
if self.opt.verbose:
print("beamform sum channel {0}".format(c))

channel = c
subchannel = int(self.subchannels[idx])

try:
dv = self.dio.read_vector(
start_sample,
samples_per_stripe,
channel,
subchannel,
)
except Exception:
# handle data gaps better
dv = np.empty(samples_per_stripe, np.complex64)
dv[:] = np.nan

dv = dv * np.exp(1j * np.deg2rad(float(self.opt.phases[idx])))

data = data + dv

else:
print("Unknown beamforming method {0}".format(self.opt.beamform))
return
else:
if self.opt.verbose:
print("Using channel {0}".format(self.channels[0]))
channel = self.channels[0]
subchannel = int(self.subchannels[0])

try:
data = self.dio.read_vector(
start_sample, samples_per_stripe, channel, subchannel
)
except IOError:
if self.opt.verbose:
print(
"IO Error for channel {0}:{1} start sample {2}".format(
channel,
subchannel,
start_sample,
)
)
# handle data gaps better
data = np.empty(samples_per_stripe, np.complex64)
data[:] = np.nan

if self.opt.decimation > 1:
data = scipy.signal.decimate(data, self.opt.decimation)
sample_freq = self.sr / self.opt.decimation
else:
sample_freq = self.sr

if self.opt.mean:
detrend_fn = "constant"
else:
detrend_fn = False
try:
freq_axis, psd_data = scipy.signal.welch(
data,
fs=float(sample_freq),
nperseg=self.opt.fft_bins,
detrend=detrend_fn,
scaling="spectrum",
return_onesided=False,
)
except Exception:
traceback.print_exc(file=sys.stdout)

sti_psd_data = np.real(
10.0 * np.log10(np.abs(scipy.fft.fftshift(psd_data)) + 1e-20)
) # 1e-20 is added to prevent divide by zero issues with logarithm

sti_time = start_sample / self.sr

return sti_psd_data, freq_axis, sti_time

def process_sti(self, start_samples):
# multithreaded sti processing

samples_per_stripe = self.samples_per_stripe

sti_psd_data = np.zeros([self.opt.fft_bins, len(start_samples)], np.float64)
sti_times = np.zeros([len(start_samples)], np.complex128)

if self.opt.num_processes == 0:
num_cores = multiprocessing.cpu_count()
else:
num_cores = np.minimum(multiprocessing.cpu_count(), self.opt.num_processes)

print(f"Using {num_cores} threads for STI calculation")

pool = multiprocessing.Pool()
pool = multiprocessing.Pool(processes=num_cores)

outputs = pool.map(self.generate_fft_stripe, start_samples)

for b in np.arange(len(start_samples), dtype=np.int_):
sti_psd_data[:, b] = outputs[b][0]
sti_times[b] = outputs[b][2]
freq_axis = outputs[b][1]

return [sti_psd_data, freq_axis, sti_times]

def plot(self):
"""Iterate over the data set and plot the STI into the subplot panels.

Expand Down Expand Up @@ -172,6 +305,9 @@ def plot(self):
samples_per_stripe = (
self.opt.fft_bins * self.opt.integration * self.opt.decimation
)

self.samples_per_stripe = samples_per_stripe

total_samples = blocks * samples_per_stripe

if total_samples > (et0 - st0):
Expand Down Expand Up @@ -209,111 +345,15 @@ def plot(self):
)

for p in np.arange(self.opt.frames):
sti_psd_data = np.zeros([self.opt.fft_bins, self.opt.length], np.float64)
sti_times = np.zeros([self.opt.length], np.complex128)

for b in np.arange(self.opt.length, dtype=np.int_):
if self.opt.verbose:
print(
"read vector : {0} {1} {2}".format(
self.channels, start_sample, samples_per_stripe
)
)

# if beamforming then load and beamform channels
if self.opt.beamform:
if self.opt.verbose:
print("beamform: {0}".format(self.opt.beamform))

if self.opt.beamform == "sum":
if len(self.opt.phases) != len(self.channels):
print(
"Number of phases must match number of channels for"
" beamforming."
)

data = np.zeros([samples_per_stripe], np.complex128)

for idx, c in enumerate(self.channels):
if self.opt.verbose:
print("beamform sum channel {0}".format(c))

channel = c
subchannel = int(self.subchannels[idx])

try:
dv = self.dio.read_vector(
start_sample,
samples_per_stripe,
channel,
subchannel,
)
except Exception:
# handle data gaps better
dv = np.empty(samples_per_stripe, np.complex64)
dv[:] = np.nan

dv = dv * np.exp(
1j * np.deg2rad(float(self.opt.phases[idx]))
)

data = data + dv

else:
print(
"Unknown beamforming method {0}".format(self.opt.beamform)
)
return
else:
if self.opt.verbose:
print("Using channel {0}".format(self.channels[0]))
channel = self.channels[0]
subchannel = int(self.subchannels[0])

try:
data = self.dio.read_vector(
start_sample, samples_per_stripe, channel, subchannel
)
except IOError:
if self.opt.verbose:
print(
"IO Error for channel {0}:{1} start sample {2}".format(
channel,
subchannel,
start_sample,
)
)
# handle data gaps better
data = np.empty(samples_per_stripe, np.complex64)
data[:] = np.nan

if self.opt.decimation > 1:
data = scipy.signal.decimate(data, self.opt.decimation)
sample_freq = self.sr / self.opt.decimation
else:
sample_freq = self.sr

if self.opt.mean:
detrend_fn = matplotlib.mlab.detrend_mean
else:
detrend_fn = matplotlib.mlab.detrend_none

try:
psd_data, freq_axis = matplotlib.mlab.psd(
data,
NFFT=self.opt.fft_bins,
Fs=float(sample_freq),
detrend=detrend_fn,
scale_by_freq=False,
)
except Exception:
traceback.print_exc(file=sys.stdout)

sti_psd_data[:, b] = np.real(10.0 * np.log10(np.abs(psd_data) + 1e-12))

sti_times[b] = start_sample / self.sr
start_samples = np.arange(
start_sample,
start_sample + stripe_stride * (self.opt.length - 1) + 1,
stripe_stride,
dtype=np.int_,
)

start_sample += stripe_stride
sti_psd_data, freq_axis, sti_times = self.process_sti(start_samples)

# Now Plot the Data
ax = self.subplots[p]
Expand Down Expand Up @@ -601,6 +641,17 @@ def parse_command_line():
" = needed due to argparse issue with negative numbers"
),
)
parser.add_argument(
"-P",
"--processes",
dest="num_processes",
default=1,
type=int,
help="""Number of processes to use for computing the STI.
If omitted defaults to 1 (single threaded).
setting processes to 0 will default to using a number of
processes equal to the number of available cpu cores""",
)
parser.add_argument(
"-o",
"--outname",
Expand Down
Loading