|
| 1 | + |
| 2 | +""" Classes and functions for fitting ivim model """ |
| 3 | +import numpy as np |
| 4 | +from scipy.optimize import lsq_linear |
| 5 | +from dipy.reconst.base import ReconstModel |
| 6 | +from dipy.reconst.multi_voxel import multi_voxel_fit |
| 7 | +from dipy.utils.optpkg import optional_package |
| 8 | +from scipy.signal import unit_impulse |
| 9 | + |
| 10 | + |
| 11 | +class IvimModelLinear(ReconstModel): |
| 12 | + |
| 13 | + def __init__(self, gtab, b_threshold=200, bounds=None, rescale_units=False): |
| 14 | + """A simple nlls fit to the bi-exponential IVIM model. No segmentations |
| 15 | + are performed. |
| 16 | +
|
| 17 | + Args: |
| 18 | + gtab (DIPY gradient table): |
| 19 | + DIPY gradient table object containing |
| 20 | + information of the diffusion gradients, b-values, etc. |
| 21 | + |
| 22 | + bounds (array-like, optional): |
| 23 | + Bounds expressed as [lower bounds, upper bounds] for S0, f, D*, and |
| 24 | + D respectively. Defaults to None. |
| 25 | + |
| 26 | + initial_guess (array-like, optional): |
| 27 | + The initial guess for the parameters. Defaults to None. |
| 28 | + |
| 29 | + rescale_units (bool, optional): |
| 30 | + Set to True if parameters are to be returned in units of µm2/ms. |
| 31 | + The conversion only works in one direction, from mm2/s to µm2/ms. |
| 32 | + Make sure the b-values in the gtab object are already in units of |
| 33 | + µm2/ms if this is used. Defaults to False. |
| 34 | + """ |
| 35 | + |
| 36 | + self.b_threshold = b_threshold |
| 37 | + self.bvals = gtab.bvals[gtab.bvals >= self.b_threshold] |
| 38 | + |
| 39 | + # Get the indices for the b-values that fulfils the condition. |
| 40 | + # Will be used to get the corresponding signals. |
| 41 | + b_threshold_idx = np.where(self.bvals >= self.b_threshold)[0][1] |
| 42 | + self.signal_indices = list(np.where(gtab.bvals >= self.b_threshold)[0]) |
| 43 | + |
| 44 | + self.set_bounds(bounds) # Sets the bounds according to the requirements of the fits |
| 45 | + self.rescale_bounds_and_initial_guess(rescale_units) # Rescales the units of D* and D to µm2/ms if set to True |
| 46 | + |
| 47 | + |
| 48 | + @multi_voxel_fit |
| 49 | + def fit(self, data): |
| 50 | + # Normalize the data and move to the logarithmic space |
| 51 | + data_max = data.max() |
| 52 | + if data_max == 0: |
| 53 | + pass |
| 54 | + else: |
| 55 | + data_log = np.log(data / data_max) |
| 56 | + |
| 57 | + # Sort out the signals from non-zero b-values < b-threshold |
| 58 | + ydata = data_log[self.signal_indices] |
| 59 | + |
| 60 | + # Define the design matrix |
| 61 | + A = np.vstack([self.bvals, np.ones(len(self.bvals))]).T |
| 62 | + |
| 63 | + # Get the bounds for D and f |
| 64 | + lsq_bounds_lower = (-self.bounds[1][2], -self.bounds[1][1]) |
| 65 | + lsq_bounds_upper = (-self.bounds[0][2], -self.bounds[0][1]) |
| 66 | + lsq_bounds = (lsq_bounds_lower, lsq_bounds_upper) |
| 67 | + |
| 68 | + # Perform the fit |
| 69 | + popt = lsq_linear(A, ydata, bounds=lsq_bounds).x |
| 70 | + D, f = -popt # f is estimated as the negative of the intercept |
| 71 | + |
| 72 | + |
| 73 | + # Set the results and rescale S0 |
| 74 | + result = np.array([data[0], f, D]) |
| 75 | + result[0] *= data_max |
| 76 | + |
| 77 | + return IvimFit(self, result) |
| 78 | + |
| 79 | + def sivim_model(self, b, S0, f, D): |
| 80 | + delta = unit_impulse(b.shape, idx=0) |
| 81 | + res = S0*(f*delta + (1-f)*np.exp(-b*D)) |
| 82 | + return res |
| 83 | + |
| 84 | + def set_bounds(self, bounds): |
| 85 | + # Use this function for fits that uses curve_fit |
| 86 | + if bounds == None: |
| 87 | + self.bounds = np.array([(0, 0, 0), (np.inf, 1, 0.004)]) |
| 88 | + else: |
| 89 | + self.bounds = np.array([(0, *bounds[0]), (np.inf, *bounds[1])]) |
| 90 | + |
| 91 | + def set_initial_guess(self, initial_guess): |
| 92 | + if initial_guess == None: |
| 93 | + self.initial_guess = (1, 0.2, 0.001) |
| 94 | + else: |
| 95 | + self.initial_guess = initial_guess |
| 96 | + |
| 97 | + def rescale_bounds_and_initial_guess(self, rescale_units): |
| 98 | + if rescale_units: |
| 99 | + # Rescale the bounds |
| 100 | + lower_bounds = (self.bounds[0][0], self.bounds[0][1], \ |
| 101 | + self.bounds[0][2]*1000) |
| 102 | + upper_bounds = (self.bounds[1][0], self.bounds[1][1], \ |
| 103 | + self.bounds[1][2]*1000) |
| 104 | + self.bounds = (lower_bounds, upper_bounds) |
| 105 | + |
| 106 | + |
| 107 | + |
| 108 | +class IvimFit(object): |
| 109 | + |
| 110 | + def __init__(self, model, model_params): |
| 111 | + """ Initialize a IvimFit class instance. |
| 112 | + Parameters |
| 113 | + ---------- |
| 114 | + model : Model class |
| 115 | + model_params : array |
| 116 | + The parameters of the model. In this case it is an |
| 117 | + array of ivim parameters. If the fitting is done |
| 118 | + for multi_voxel data, the multi_voxel decorator will |
| 119 | + run the fitting on all the voxels and model_params |
| 120 | + will be an array of the dimensions (data[:-1], 4), |
| 121 | + i.e., there will be 4 parameters for each of the voxels. |
| 122 | + """ |
| 123 | + self.model = model |
| 124 | + self.model_params = model_params |
| 125 | + |
| 126 | + def __getitem__(self, index): |
| 127 | + model_params = self.model_params |
| 128 | + N = model_params.ndim |
| 129 | + if type(index) is not tuple: |
| 130 | + index = (index,) |
| 131 | + elif len(index) >= model_params.ndim: |
| 132 | + raise IndexError("IndexError: invalid index") |
| 133 | + index = index + (slice(None),) * (N - len(index)) |
| 134 | + return type(self)(self.model, model_params[index]) |
| 135 | + |
| 136 | + @property |
| 137 | + def S0_predicted(self): |
| 138 | + return self.model_params[..., 0] |
| 139 | + |
| 140 | + @property |
| 141 | + def perfusion_fraction(self): |
| 142 | + return self.model_params[..., 1] |
| 143 | + |
| 144 | + #@property |
| 145 | + #def D_star(self): |
| 146 | + #return self.model_params[..., 2] |
| 147 | + |
| 148 | + @property |
| 149 | + def D(self): |
| 150 | + return self.model_params[..., 3] |
| 151 | + |
| 152 | + @property |
| 153 | + def shape(self): |
| 154 | + return self.model_params.shape[:-1] |
0 commit comments