|
| 1 | +## Deep Active Lesion Segmention (DALS), Code by Ali Hatamizadeh ( http://web.cs.ucla.edu/~ahatamiz/ ) |
| 2 | + |
| 3 | +import os |
| 4 | +import numpy as np |
| 5 | +import tensorflow as tf |
| 6 | +from sklearn.metrics import f1_score |
| 7 | +from utils import load_image,my_func,resolve_status,contoured_image |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import argparse |
| 10 | +parser = argparse.ArgumentParser() |
| 11 | +parser.add_argument('--logdir', default='network_lung', type=str) |
| 12 | +parser.add_argument('--mu', default=0.2, type=float) |
| 13 | +parser.add_argument('--nu', default=5.0, type=float) |
| 14 | +parser.add_argument('--batch_size', default=1, type=int) |
| 15 | +parser.add_argument('--train_sum_freq', default=150, type=int) |
| 16 | +parser.add_argument('--train_iter', default=150000, type=int) |
| 17 | +parser.add_argument('--acm_iter_limit', default=300, type=int) |
| 18 | +parser.add_argument('--img_resize', default=512, type=int) |
| 19 | +parser.add_argument('--f_size', default=15, type=int) |
| 20 | +parser.add_argument('--train_status', default=1, type=int) |
| 21 | +parser.add_argument('--narrow_band_width', default=1, type=int) |
| 22 | +parser.add_argument('--save_freq', default=1000, type=int) |
| 23 | +parser.add_argument('--lr', default=1e-3, type=float) |
| 24 | +parser.add_argument('--gpu', default='0', type=str) |
| 25 | +args = parser.parse_args() |
| 26 | +restore,is_training =resolve_status(args.train_status) |
| 27 | +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| 28 | +os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu |
| 29 | + |
| 30 | + |
| 31 | +###### Demo 1 # Brain |
| 32 | +image_add = './dataset/demo_brain/img1_input.npy' |
| 33 | +label_add = './dataset/demo_brain/img1_label.npy' |
| 34 | +init_seg_add = './dataset/demo_brain/img1_initseg.npy' |
| 35 | + |
| 36 | + |
| 37 | +def re_init_phi(phi, dt): |
| 38 | + D_left_shift = tf.cast(tf.manip.roll(phi, -1, axis=1), dtype='float32') |
| 39 | + D_right_shift = tf.cast(tf.manip.roll(phi, 1, axis=1), dtype='float32') |
| 40 | + D_up_shift = tf.cast(tf.manip.roll(phi, -1, axis=0), dtype='float32') |
| 41 | + D_down_shift = tf.cast(tf.manip.roll(phi, 1, axis=0), dtype='float32') |
| 42 | + bp = D_left_shift - phi |
| 43 | + cp = phi - D_down_shift |
| 44 | + dp = D_up_shift - phi |
| 45 | + ap = phi - D_right_shift |
| 46 | + an = tf.identity(ap) |
| 47 | + bn = tf.identity(bp) |
| 48 | + cn = tf.identity(cp) |
| 49 | + dn = tf.identity(dp) |
| 50 | + ap = tf.clip_by_value(ap, 0, 10 ^ 38) |
| 51 | + bp = tf.clip_by_value(bp, 0, 10 ^ 38) |
| 52 | + cp = tf.clip_by_value(cp, 0, 10 ^ 38) |
| 53 | + dp = tf.clip_by_value(dp, 0, 10 ^ 38) |
| 54 | + an = tf.clip_by_value(an, -10 ^ 38, 0) |
| 55 | + bn = tf.clip_by_value(bn, -10 ^ 38, 0) |
| 56 | + cn = tf.clip_by_value(cn, -10 ^ 38, 0) |
| 57 | + dn = tf.clip_by_value(dn, -10 ^ 38, 0) |
| 58 | + area_pos = tf.where(phi > 0) |
| 59 | + area_neg = tf.where(phi < 0) |
| 60 | + pos_y = area_pos[:, 0] |
| 61 | + pos_x = area_pos[:, 1] |
| 62 | + neg_y = area_neg[:, 0] |
| 63 | + neg_x = area_neg[:, 1] |
| 64 | + tmp1 = tf.reduce_max([tf.square(tf.gather_nd(t, area_pos)) for t in [ap, bn]], axis=0) |
| 65 | + tmp1 += tf.reduce_max([tf.square(tf.gather_nd(t, area_pos)) for t in [cp, dn]], axis=0) |
| 66 | + update1 = tf.sqrt(tf.abs(tmp1)) - 1 |
| 67 | + indices1 = tf.stack([pos_y, pos_x], 1) |
| 68 | + tmp2 = tf.reduce_max([tf.square(tf.gather_nd(t, area_neg)) for t in [an, bp]], axis=0) |
| 69 | + tmp2 += tf.reduce_max([tf.square(tf.gather_nd(t, area_neg)) for t in [cn, dp]], axis=0) |
| 70 | + update2 = tf.sqrt(tf.abs(tmp2)) - 1 |
| 71 | + indices2 = tf.stack([neg_y, neg_x], 1) |
| 72 | + indices_final = tf.concat([indices1, indices2], 0) |
| 73 | + update_final = tf.concat([update1, update2], 0) |
| 74 | + dD = tf.scatter_nd(indices_final, update_final, shape=[input_image_size, input_image_size]) |
| 75 | + S = tf.divide(phi, tf.square(phi) + 1) |
| 76 | + phi = phi - tf.multiply(dt * S, dD) |
| 77 | + |
| 78 | + return phi |
| 79 | + |
| 80 | + |
| 81 | +def get_curvature(phi, x, y): |
| 82 | + phi_shape = tf.shape(phi) |
| 83 | + dim_y = phi_shape[0] |
| 84 | + dim_x = phi_shape[1] |
| 85 | + x = tf.cast(x, dtype="int32") |
| 86 | + y = tf.cast(y, dtype="int32") |
| 87 | + y_plus = tf.cast(y + 1, dtype="int32") |
| 88 | + y_minus = tf.cast(y - 1, dtype="int32") |
| 89 | + x_plus = tf.cast(x + 1, dtype="int32") |
| 90 | + x_minus = tf.cast(x - 1, dtype="int32") |
| 91 | + y_plus = tf.minimum(tf.cast(y_plus, dtype="int32"), tf.cast(dim_y - 1, dtype="int32")) |
| 92 | + x_plus = tf.minimum(tf.cast(x_plus, dtype="int32"), tf.cast(dim_x - 1, dtype="int32")) |
| 93 | + y_minus = tf.maximum(y_minus, 0) |
| 94 | + x_minus = tf.maximum(x_minus, 0) |
| 95 | + d_phi_dx = tf.gather_nd(phi, tf.stack([y, x_plus], 1)) - tf.gather_nd(phi, tf.stack([y, x_minus], 1)) |
| 96 | + d_phi_dx_2 = tf.square(d_phi_dx) |
| 97 | + d_phi_dy = tf.gather_nd(phi, tf.stack([y_plus, x], 1)) - tf.gather_nd(phi, tf.stack([y_minus, x], 1)) |
| 98 | + d_phi_dy_2 = tf.square(d_phi_dy) |
| 99 | + d_phi_dxx = tf.gather_nd(phi, tf.stack([y, x_plus], 1)) + tf.gather_nd(phi, tf.stack([y, x_minus], 1)) - \ |
| 100 | + 2 * tf.gather_nd(phi, tf.stack([y, x], 1)) |
| 101 | + d_phi_dyy = tf.gather_nd(phi, tf.stack([y_plus, x], 1)) + tf.gather_nd(phi, tf.stack([y_minus, x], 1)) - \ |
| 102 | + 2 * tf.gather_nd(phi, tf.stack([y, x], 1)) |
| 103 | + d_phi_dxy = 0.25 * (- tf.gather_nd(phi, tf.stack([y_minus, x_minus], 1)) - tf.gather_nd(phi, tf.stack( |
| 104 | + [y_plus, x_plus], 1)) + tf.gather_nd(phi, tf.stack([y_minus, x_plus], 1)) + tf.gather_nd(phi, tf.stack( |
| 105 | + [y_plus, x_minus], 1))) |
| 106 | + tmp_1 = tf.multiply(d_phi_dx_2, d_phi_dyy) + tf.multiply(d_phi_dy_2, d_phi_dxx) - \ |
| 107 | + 2 * tf.multiply(tf.multiply(d_phi_dx, d_phi_dy), d_phi_dxy) |
| 108 | + tmp_2 = tf.add(tf.pow(d_phi_dx_2 + d_phi_dy_2, 1.5), 2.220446049250313e-16) |
| 109 | + tmp_3 = tf.pow(d_phi_dx_2 + d_phi_dy_2, 0.5) |
| 110 | + tmp_4 = tf.divide(tmp_1, tmp_2) |
| 111 | + curvature = tf.multiply(tmp_3, tmp_4) |
| 112 | + mean_grad = tf.pow(d_phi_dx_2 + d_phi_dy_2, 0.5) |
| 113 | + |
| 114 | + return curvature, mean_grad |
| 115 | + |
| 116 | + |
| 117 | +def get_intensity(image, masked_phi, filter_patch_size=5): |
| 118 | + u_1 = tf.layers.average_pooling2d(tf.multiply(image, masked_phi), [filter_patch_size, filter_patch_size], 1,padding='SAME') |
| 119 | + u_2 = tf.layers.average_pooling2d(masked_phi, [filter_patch_size, filter_patch_size], 1, padding='SAME') |
| 120 | + u_2_prime = 1 - tf.cast((u_2 > 0), dtype='float32') + tf.cast((u_2 < 0), dtype='float32') |
| 121 | + u_2 = u_2 + u_2_prime + 2.220446049250313e-16 |
| 122 | + |
| 123 | + return tf.divide(u_1, u_2) |
| 124 | + |
| 125 | + |
| 126 | +def active_contour_layer(elems): |
| 127 | + img = elems[0] |
| 128 | + init_phi = elems[1] |
| 129 | + map_lambda1_acl = elems[2] |
| 130 | + map_lambda2_acl = elems[3] |
| 131 | + wind_coef = 3 |
| 132 | + zero_tensor = tf.constant(0, shape=[], dtype="int32") |
| 133 | + def _body(i, phi_level): |
| 134 | + band_index = tf.reduce_all([phi_level <= narrow_band_width, phi_level >= -narrow_band_width], axis=0) |
| 135 | + band = tf.where(band_index) |
| 136 | + band_y = band[:, 0] |
| 137 | + band_x = band[:, 1] |
| 138 | + shape_y = tf.shape(band_y) |
| 139 | + num_band_pixel = shape_y[0] |
| 140 | + window_radii_x = tf.ones(num_band_pixel) * wind_coef |
| 141 | + window_radii_y = tf.ones(num_band_pixel) * wind_coef |
| 142 | + |
| 143 | + def body_intensity(j, mean_intensities_outer, mean_intensities_inner): |
| 144 | + ### This can be computationally expensive. Use with fewer number of acm iterations. |
| 145 | + xnew = tf.cast(band_x[j], dtype="float32") |
| 146 | + ynew = tf.cast(band_y[j], dtype="float32") |
| 147 | + window_radius_x = tf.cast(window_radii_x[j], dtype="float32") |
| 148 | + window_radius_y = tf.cast(window_radii_y[j], dtype="float32") |
| 149 | + local_window_x_min = tf.cast(tf.floor(xnew - window_radius_x), dtype="int32") |
| 150 | + local_window_x_max = tf.cast(tf.floor(xnew + window_radius_x), dtype="int32") |
| 151 | + local_window_y_min = tf.cast(tf.floor(ynew - window_radius_y), dtype="int32") |
| 152 | + local_window_y_max = tf.cast(tf.floor(ynew + window_radius_y), dtype="int32") |
| 153 | + local_window_x_min = tf.maximum(zero_tensor, local_window_x_min) |
| 154 | + local_window_y_min = tf.maximum(zero_tensor, local_window_y_min) |
| 155 | + local_window_x_max = tf.minimum(tf.cast(input_image_size - 1, dtype="int32"), local_window_x_max) |
| 156 | + local_window_y_max = tf.minimum(tf.cast(input_image_size - 1, dtype="int32"), local_window_y_max) |
| 157 | + local_image = img[local_window_y_min: local_window_y_max + 1,local_window_x_min: local_window_x_max + 1] |
| 158 | + local_phi = phi_prime[local_window_y_min: local_window_y_max + 1,local_window_x_min: local_window_x_max + 1] |
| 159 | + inner = tf.where(local_phi <= 0) |
| 160 | + area_inner = tf.cast(tf.shape(inner)[0], dtype='float32') |
| 161 | + outer = tf.where(local_phi > 0) |
| 162 | + area_outer = tf.cast(tf.shape(outer)[0], dtype='float32') |
| 163 | + image_loc_inner = tf.gather_nd(local_image, inner) |
| 164 | + image_loc_outer = tf.gather_nd(local_image, outer) |
| 165 | + mean_intensity_inner = tf.cast(tf.divide(tf.reduce_sum(image_loc_inner), area_inner), dtype='float32') |
| 166 | + mean_intensity_outer = tf.cast(tf.divide(tf.reduce_sum(image_loc_outer), area_outer), dtype='float32') |
| 167 | + mean_intensities_inner = tf.concat(axis=0, values=[mean_intensities_inner[:j], [mean_intensity_inner]]) |
| 168 | + mean_intensities_outer = tf.concat(axis=0, values=[mean_intensities_outer[:j], [mean_intensity_outer]]) |
| 169 | + |
| 170 | + return (j + 1, mean_intensities_outer, mean_intensities_inner) |
| 171 | + |
| 172 | + if fast_lookup: |
| 173 | + phi_4d = phi_level[tf.newaxis, :, :, tf.newaxis] |
| 174 | + image = img[tf.newaxis, :, :, tf.newaxis] |
| 175 | + band_index_2 = tf.reduce_all([phi_4d <= narrow_band_width, phi_4d >= -narrow_band_width], axis=0) |
| 176 | + band_2 = tf.where(band_index_2) |
| 177 | + u_inner = get_intensity(image, tf.cast((([phi_4d <= 0])), dtype='float32')[0], filter_patch_size=f_size) |
| 178 | + u_outer = get_intensity(image, tf.cast((([phi_4d > 0])), dtype='float32')[0], filter_patch_size=f_size) |
| 179 | + mean_intensities_inner = tf.gather_nd(u_inner, band_2) |
| 180 | + mean_intensities_outer = tf.gather_nd(u_outer, band_2) |
| 181 | + |
| 182 | + else: |
| 183 | + mean_intensities_inner = tf.constant([0], dtype='float32') |
| 184 | + mean_intensities_outer = tf.constant([0], dtype='float32') |
| 185 | + j = tf.constant(0, dtype=tf.int32) |
| 186 | + _, mean_intensities_outer, mean_intensities_inner = tf.while_loop( |
| 187 | + lambda j, mean_intensities_outer, mean_intensities_inner: |
| 188 | + j < num_band_pixel, body_intensity, loop_vars=[j, mean_intensities_outer, mean_intensities_inner], |
| 189 | + shape_invariants=[j.get_shape(), tf.TensorShape([None]), tf.TensorShape([None])]) |
| 190 | + |
| 191 | + lambda1 = tf.gather_nd(map_lambda1_acl, [band]) |
| 192 | + lambda2 = tf.gather_nd(map_lambda2_acl, [band]) |
| 193 | + curvature, mean_grad = get_curvature(phi_level, band_x, band_y) |
| 194 | + kappa = tf.multiply(curvature, mean_grad) |
| 195 | + term1 = tf.multiply(tf.cast(lambda1, dtype='float32'),tf.square(tf.gather_nd(img, [band]) - mean_intensities_inner)) |
| 196 | + term2 = tf.multiply(tf.cast(lambda2, dtype='float32'),tf.square(tf.gather_nd(img, [band]) - mean_intensities_outer)) |
| 197 | + force = -nu + term1 - term2 |
| 198 | + force /= (tf.reduce_max(tf.abs(force))) |
| 199 | + d_phi_dt = tf.cast(force, dtype="float32") + tf.cast(mu * kappa, dtype="float32") |
| 200 | + dt = .45 / (tf.reduce_max(tf.abs(d_phi_dt)) + 2.220446049250313e-16) |
| 201 | + d_phi = dt * d_phi_dt |
| 202 | + update_narrow_band = d_phi |
| 203 | + phi_prime = phi_level + tf.scatter_nd([band], tf.cast(update_narrow_band, dtype='float32'),shape=[input_image_size, input_image_size]) |
| 204 | + phi_prime = re_init_phi(phi_prime, 0.5) |
| 205 | + |
| 206 | + return (i + 1, phi_prime) |
| 207 | + |
| 208 | + i = tf.constant(0, dtype=tf.int32) |
| 209 | + phi = init_phi |
| 210 | + _, phi = tf.while_loop(lambda i, phi: i < iter_limit, _body, loop_vars=[i, phi]) |
| 211 | + phi = tf.round(tf.cast((1 - tf.nn.sigmoid(phi)), dtype=tf.float32)) |
| 212 | + |
| 213 | + return phi,init_phi, map_lambda1_acl, map_lambda2_acl |
| 214 | + |
| 215 | +fast_lookup = True |
| 216 | +config = tf.ConfigProto(allow_soft_placement=True) |
| 217 | +input_shape = [args.batch_size, args.img_resize, args.img_resize, 1] |
| 218 | +input_shape_dt = [args.batch_size, args.img_resize, args.img_resize] |
| 219 | +iter_limit = args.acm_iter_limit |
| 220 | +narrow_band_width = args.narrow_band_width |
| 221 | +mu = args.mu |
| 222 | +nu = args.nu |
| 223 | +f_size = args.f_size |
| 224 | +input_image_size = args.img_resize |
| 225 | +x = tf.placeholder(shape=input_shape, dtype=tf.float32, name="x") |
| 226 | +y = tf.placeholder(dtype=tf.float32, name="y") |
| 227 | +out_seg = tf.placeholder(dtype=tf.float32, name="out_seg") |
| 228 | +phase = tf.placeholder(tf.bool, name='phase') |
| 229 | +global_step = tf.Variable(0, name='global_step', trainable=False) |
| 230 | +map_lambda1 = tf.exp(tf.divide(tf.subtract(2.0,out_seg),tf.add(1.0,out_seg))) |
| 231 | +map_lambda2 = tf.exp(tf.divide(tf.add(1.0, out_seg), tf.subtract(2.0, out_seg))) |
| 232 | +y_out_dl = tf.round(out_seg) |
| 233 | +x_acm = x[:, :, :, 0] |
| 234 | +rounded_seg_acl = y_out_dl[:, :, :, 0] |
| 235 | +dt_trans = tf.py_func(my_func, [rounded_seg_acl], tf.float32) |
| 236 | +dt_trans.set_shape([args.batch_size, input_image_size, input_image_size]) |
| 237 | +phi_out,_, lambda1_tr, lambda2_tr = tf.map_fn(fn=active_contour_layer, elems=(x_acm, dt_trans, map_lambda1[:, :, :, 0], map_lambda2[:, :, :, 0])) |
| 238 | +rounded_seg = tf.round(out_seg) |
| 239 | +with tf.Session(config=config) as sess: |
| 240 | + |
| 241 | + print("########### Inference ############") |
| 242 | + |
| 243 | + print('Brain Demo in Progress ... ') |
| 244 | + image = load_image(image_add,args.batch_size,False) |
| 245 | + labels = load_image(label_add, args.batch_size,True) |
| 246 | + init_seg = np.load(init_seg_add) |
| 247 | + labels[labels != 0] = 1 |
| 248 | + seg_out_acm, seg_out = sess.run([phi_out, y_out_dl],{x: image, y: labels, out_seg: init_seg, phase: False}) |
| 249 | + seg_out = seg_out[0, :, :, 0] |
| 250 | + seg_out_acm = seg_out_acm[0, :, :] |
| 251 | + gt_mask = labels[0, :, :, 0] |
| 252 | + f1 = f1_score(gt_mask, seg_out, labels=None, average='micro', sample_weight=None) |
| 253 | + print('CNN Dice {0:0.4f}'.format(f1)) |
| 254 | + f2 = f1_score(gt_mask, seg_out_acm, labels=None, average='micro', sample_weight=None) |
| 255 | + print('ACM Dice {0:0.4f}'.format(f2)) |
| 256 | + fig = plt.figure() |
| 257 | + plt.subplot(1, 3, 1) |
| 258 | + plt.title('DALS Output, Dice:{0:0.4f}'.format(f2)) |
| 259 | + seg_out_acm=contoured_image(seg_out_acm, image[0,:,:,0]) |
| 260 | + plt.imshow(seg_out_acm) |
| 261 | + plt.subplot(1, 3, 2) |
| 262 | + plt.title('CNN Output, Dice:{0:0.4f}'.format(f1)) |
| 263 | + seg_out = contoured_image(seg_out, image[0, :, :, 0]) |
| 264 | + plt.imshow(seg_out) |
| 265 | + plt.subplot(1, 3, 3) |
| 266 | + plt.title('Radiologist Annotation') |
| 267 | + gt_mask = contoured_image(gt_mask, image[0, :, :, 0]) |
| 268 | + plt.imshow(gt_mask) |
| 269 | + plt.show() |
| 270 | + |
| 271 | + |
| 272 | + |
| 273 | + |
0 commit comments