Skip to content

Commit 27dfa3f

Browse files
committed
add utils and requirements
1 parent a8ffa98 commit 27dfa3f

File tree

13 files changed

+485
-0
lines changed

13 files changed

+485
-0
lines changed

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
numpy
2+
torchio
3+
torch
4+
os
5+
logging
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
import logging
3+
import torch
4+
import numpy as np
5+
import torchio as tio
6+
7+
8+
def load_ivim_subject(study_subject_path):
9+
"""
10+
loads torchio subject with IVIM data (signals and bvalues)
11+
Args:
12+
study_subject_path: path in which subject data is located
13+
14+
Returns:
15+
16+
"""
17+
subject_dict = {}
18+
19+
# find all files that match study path and subject id
20+
for file in os.listdir(study_subject_path):
21+
file_path = os.path.join(study_subject_path, file)
22+
logging.info(f'start loading data from {file_path}')
23+
24+
# Check file extension for image file
25+
if file_path[-2:] == "gz" or file_path[-2:] == "ii":
26+
27+
# load nifti image
28+
image = tio.Image(file_path)
29+
image.set_data(image.data.to(dtype=torch.float32))
30+
subject_dict['signals'] = image
31+
32+
# Check if file contains bvalues
33+
elif file_path[-2:] == "al":
34+
text_file = np.genfromtxt(file_path)
35+
bvals = np.array(text_file)
36+
subject_dict["xvals"] = tio.Image(tensor=torch.Tensor(np.reshape(bvals, (bvals.shape[0], 1, 1, 1))))
37+
38+
else:
39+
logging.info(f'skipping loading of file {file_path}, no appropriate file extension. ')
40+
41+
# Create subject
42+
if 'xvals' in subject_dict.keys() and 'signals' in subject_dict.keys():
43+
return tio.Subject(subject_dict)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
from torchio.transforms import Transform
5+
6+
7+
class AverageSignalsOfEqualXvals(Transform):
8+
9+
def __init__(self, **kwargs):
10+
super().__init__(**kwargs)
11+
12+
def apply_transform(self, subject):
13+
"""
14+
normalize signals
15+
Args:
16+
signals: signals array to normalize
17+
xvals: xval array
18+
19+
Returns:
20+
normalized_signals: normalized signals array
21+
22+
"""
23+
images_dict = subject.get_images_dict()
24+
signals = images_dict['signals'].numpy()
25+
xvals = np.squeeze(images_dict['xvals'].numpy())
26+
signals, xvals = self.average_signal_of_equal_xvals(signals, xvals)
27+
subject.add_image(torchio.Image(tensor=torch.Tensor(signals)), 'signals')
28+
subject.add_image(torchio.Image(tensor=torch.Tensor(np.reshape(xvals, (xvals.shape[0], 1, 1, 1)))), 'xvals')
29+
return subject
30+
31+
@staticmethod
32+
def average_signal_of_equal_xvals(signals, xvals):
33+
"""
34+
average the signal of equal xvals
35+
Args:
36+
signals: signal matrix [signals X xvals]
37+
xvals: array of xvals
38+
39+
Returns:
40+
averaged_signal_matrix: averaged signal matrix [signals X unique_xvals]
41+
unique_xval_arrays: unique xvals in averaged signal matrix
42+
43+
"""
44+
unique_xvals = np.unique(xvals)
45+
averaged_signals = np.zeros((unique_xvals.shape[0], *signals.shape[1:]))
46+
for xval_idx, unique_xval in enumerate(unique_xvals):
47+
averaged_signals[xval_idx, ...] = np.squeeze(np.mean(signals[np.where(xvals == unique_xval), ...], axis=1))
48+
return averaged_signals, unique_xvals
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
5+
from torchio.transforms import Transform
6+
7+
8+
class FlattenImageData(Transform):
9+
10+
def __init__(self, **kwargs):
11+
super().__init__(**kwargs)
12+
13+
def apply_transform(self, subject):
14+
"""
15+
flattens image data of signals image of subject
16+
Args:
17+
signals: signals array to normalize
18+
xvals: xval array
19+
20+
Returns:
21+
normalized_signals: normalized signals array
22+
23+
"""
24+
images_dict = subject.get_images_dict(include=self.include, exclude=self.exclude)
25+
for image_key, image in images_dict.items():
26+
flattened_array = self.flatten_image_data(image.numpy())
27+
subject.add_image(torchio.Image(tensor=torch.Tensor(np.reshape(flattened_array,
28+
(flattened_array.shape[0],
29+
flattened_array.shape[1], 1, 1)))),
30+
image_key)
31+
return subject
32+
33+
@staticmethod
34+
def flatten_image_data(signals):
35+
"""
36+
Flattens 4D array into 2D array
37+
Args:
38+
signals: signals array to normalize
39+
40+
Returns:
41+
normalized_signals: normalized signals array
42+
"""
43+
bvals, x, y, z = signals.shape
44+
signals_array = np.reshape(signals, (bvals, x * y * z))
45+
return signals_array
46+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
from torchio.transforms import Transform
5+
6+
7+
class NormalizeMaxSignal(Transform):
8+
9+
def __init__(self, **kwargs):
10+
super().__init__(**kwargs)
11+
12+
def apply_transform(self, subject):
13+
"""
14+
normalize signals Image of subject
15+
Args:
16+
subject: Subject
17+
18+
Returns:
19+
subject: Subject
20+
"""
21+
images_dict = subject.get_images_dict()
22+
signals = images_dict['signals'].numpy()
23+
signals = self.normalize_signals(signals)
24+
subject.add_image(torchio.Image(tensor=torch.Tensor(signals)), 'signals')
25+
return subject
26+
27+
@staticmethod
28+
def normalize_signals(signals):
29+
"""
30+
normalize signals
31+
Args:
32+
signals: signals array to normalize
33+
34+
Returns:
35+
normalized_signals: normalized signals array
36+
37+
"""
38+
maxsignal = np.nanmax(signals, axis=0)
39+
signals /= maxsignal
40+
41+
return signals
42+
43+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
from torchio.transforms import Transform
5+
6+
7+
class NormalizeSignals(Transform):
8+
9+
def __init__(self, xval_threshold, **kwargs):
10+
self.xval_threshold = xval_threshold
11+
super().__init__(**kwargs)
12+
13+
def apply_transform(self, subject):
14+
"""
15+
normalize xvals Image of subject
16+
Args:
17+
subject: Subject
18+
19+
Returns:
20+
subject: Subject
21+
"""
22+
images_dict = subject.get_images_dict()
23+
signals = images_dict['signals'].numpy()
24+
xvals = np.squeeze(images_dict['xvals'].numpy())
25+
signals = self.normalize_signals(signals, xvals, self.xval_threshold)
26+
subject.add_image(torchio.Image(tensor=torch.Tensor(signals)), 'signals')
27+
return subject
28+
29+
@staticmethod
30+
def normalize_signals(signals, xvals, xval_threshold):
31+
"""
32+
normalize signals
33+
Args:
34+
signals: signals array to normalize
35+
xvals: xval array
36+
xval_threshold: threshold below which bvals are considered b0
37+
38+
Returns:
39+
normalized_signals: normalized signals array
40+
41+
"""
42+
# get average b0 signals and set signals with S0 of 0 to nan
43+
mean_S0 = np.nanmean(signals[xvals <= xval_threshold, :, :, :], axis=0)
44+
signals[:, mean_S0 == 0] = np.nan
45+
46+
# normalize signals to S0 intensity
47+
normalized_signals = signals / mean_S0
48+
normalized_signals[np.isnan(normalized_signals)] = 0
49+
50+
return normalized_signals
51+
52+
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
import torchio
3+
import numpy as np
4+
5+
from torchio.transforms import Transform
6+
7+
8+
class NormalizeXvals(Transform):
9+
10+
def __init__(self, normalization_factor, **kwargs):
11+
self.normalization_factor = normalization_factor
12+
super().__init__(**kwargs)
13+
14+
def apply_transform(self, subject):
15+
"""
16+
normalize xvals Image of subject
17+
Args:
18+
subject: Subject
19+
20+
Returns:
21+
subject: Subject
22+
"""
23+
images_dict = subject.get_images_dict()
24+
xvals = np.squeeze(images_dict['xvals'].numpy())
25+
xvals = self.normalize_xvals(xvals, self.normalization_factor)
26+
subject.add_image(torchio.Image(tensor=torch.Tensor(np.reshape(xvals, (xvals.shape[0], 1, 1, 1)))), 'xvals')
27+
return subject
28+
29+
@staticmethod
30+
def normalize_xvals(xvals, normalization_factor):
31+
"""
32+
normalize signals
33+
Args:
34+
xvals: xvalue array
35+
normalization_factor: factor to multiply xvals with
36+
37+
Returns:
38+
normalized_xvals: normalized signals array
39+
"""
40+
normalized_xvals = xvals * normalization_factor
41+
return normalized_xvals
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
from torchio.transforms import Transform
5+
6+
7+
class SignalCuration(Transform):
8+
9+
def __init__(self, qmri_application, **kwargs):
10+
self.qmri_application = qmri_application
11+
super().__init__(**kwargs)
12+
13+
def apply_transform(self, subject):
14+
"""
15+
curates signals Image of subject
16+
Args:
17+
subject: Subject
18+
19+
Returns:
20+
subject: Subject
21+
"""
22+
images_dict = subject.get_images_dict()
23+
if self.qmri_application == 'IVIM' or 'ivim':
24+
signals = images_dict['signals'].numpy()
25+
xvals = np.squeeze(images_dict['xvals'].numpy())
26+
valid_mask = self.ivim_selection(signals, xvals)
27+
subject.add_image(torchio.Image(tensor=torch.Tensor(np.expand_dims(valid_mask, 0))), 'valid_mask')
28+
return subject
29+
else:
30+
raise NotImplementedError
31+
32+
@staticmethod
33+
def ivim_selection(signals, xvals):
34+
"""
35+
returns only those signals exhibiting ivim decay
36+
Args:
37+
signals: signals for corresponding xvals
38+
xvals: xvals
39+
40+
Returns:
41+
normalized_valid_signals: normalized_valid_signals that exhibit ivim-like decay
42+
masked_signals: normalized_signals where signals not exhibiting ivim-like decay are set to 0
43+
"""
44+
45+
# get average b0 signals and set signals with S0 of 0 to nan
46+
mean_S0 = np.nanmean(signals[xvals <= 0.0001, ...], axis=0)
47+
48+
# select only those voxels with average S0 larger than half of median S0 of voxels with S0 larger than 0
49+
valid_idcs_median_value = mean_S0 > (0.5 * np.nanmedian(mean_S0[mean_S0 > 0]))
50+
51+
# check if signal is ivim like
52+
signals[signals > 1.5] = 1.5
53+
valid_idcs_ivim_curve1 = np.percentile(signals[xvals * 100 < 50, ...], 95,
54+
axis=0) < 1.3
55+
valid_idcs_ivim_curve2 = np.percentile(signals[xvals * 100 > 50, ...], 95,
56+
axis=0) < 1.2
57+
valid_idcs_ivim_curve3 = np.percentile(signals[xvals * 100 > 150, ...], 95,
58+
axis=0) < 1.0
59+
mask_signals = valid_idcs_median_value & valid_idcs_ivim_curve1 & valid_idcs_ivim_curve2 & valid_idcs_ivim_curve3
60+
61+
return mask_signals
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
import torch
3+
import torchio
4+
from torchio.transforms import Transform
5+
6+
7+
class SignalMask(Transform):
8+
9+
def __init__(self, **kwargs):
10+
super().__init__(**kwargs)
11+
12+
def apply_transform(self, subject):
13+
"""
14+
curates signals Image of subject
15+
Args:
16+
subject: Subject
17+
18+
Returns:
19+
subject: Subject
20+
"""
21+
images_dict = subject.get_images_dict()
22+
signals = images_dict['signals'].numpy()
23+
signal_mask = self.signal_mask(signals)
24+
subject.add_image(torchio.Image(tensor=torch.Tensor(np.expand_dims(signal_mask, 0))), 'signal_mask')
25+
return subject
26+
27+
@staticmethod
28+
def signal_mask(signals):
29+
"""
30+
returns mask with nonzero element for signal vectors with nonzero entries
31+
Args:
32+
signals: signals
33+
34+
Returns:
35+
masked_signals: signals containing nonzero elements
36+
"""
37+
38+
return np.sum(signals, axis=0) > 0

0 commit comments

Comments
 (0)