Skip to content

Commit 82775b0

Browse files
authored
main_demo added
1 parent ce4d738 commit 82775b0

File tree

1 file changed

+273
-0
lines changed

1 file changed

+273
-0
lines changed

main_demo.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
## Deep Active Lesion Segmention (DALS), Code by Ali Hatamizadeh ( http://web.cs.ucla.edu/~ahatamiz/ )
2+
3+
import os
4+
import numpy as np
5+
import tensorflow as tf
6+
from sklearn.metrics import f1_score
7+
from utils import load_image,my_func,resolve_status,contoured_image
8+
import matplotlib.pyplot as plt
9+
import argparse
10+
parser = argparse.ArgumentParser()
11+
parser.add_argument('--logdir', default='network_lung', type=str)
12+
parser.add_argument('--mu', default=0.2, type=float)
13+
parser.add_argument('--nu', default=5.0, type=float)
14+
parser.add_argument('--batch_size', default=1, type=int)
15+
parser.add_argument('--train_sum_freq', default=150, type=int)
16+
parser.add_argument('--train_iter', default=150000, type=int)
17+
parser.add_argument('--acm_iter_limit', default=300, type=int)
18+
parser.add_argument('--img_resize', default=512, type=int)
19+
parser.add_argument('--f_size', default=15, type=int)
20+
parser.add_argument('--train_status', default=1, type=int)
21+
parser.add_argument('--narrow_band_width', default=1, type=int)
22+
parser.add_argument('--save_freq', default=1000, type=int)
23+
parser.add_argument('--lr', default=1e-3, type=float)
24+
parser.add_argument('--gpu', default='0', type=str)
25+
args = parser.parse_args()
26+
restore,is_training =resolve_status(args.train_status)
27+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
28+
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
29+
30+
31+
###### Demo 1 # Brain
32+
image_add = './dataset/demo_brain/img1_input.npy'
33+
label_add = './dataset/demo_brain/img1_label.npy'
34+
init_seg_add = './dataset/demo_brain/img1_initseg.npy'
35+
36+
37+
def re_init_phi(phi, dt):
38+
D_left_shift = tf.cast(tf.manip.roll(phi, -1, axis=1), dtype='float32')
39+
D_right_shift = tf.cast(tf.manip.roll(phi, 1, axis=1), dtype='float32')
40+
D_up_shift = tf.cast(tf.manip.roll(phi, -1, axis=0), dtype='float32')
41+
D_down_shift = tf.cast(tf.manip.roll(phi, 1, axis=0), dtype='float32')
42+
bp = D_left_shift - phi
43+
cp = phi - D_down_shift
44+
dp = D_up_shift - phi
45+
ap = phi - D_right_shift
46+
an = tf.identity(ap)
47+
bn = tf.identity(bp)
48+
cn = tf.identity(cp)
49+
dn = tf.identity(dp)
50+
ap = tf.clip_by_value(ap, 0, 10 ^ 38)
51+
bp = tf.clip_by_value(bp, 0, 10 ^ 38)
52+
cp = tf.clip_by_value(cp, 0, 10 ^ 38)
53+
dp = tf.clip_by_value(dp, 0, 10 ^ 38)
54+
an = tf.clip_by_value(an, -10 ^ 38, 0)
55+
bn = tf.clip_by_value(bn, -10 ^ 38, 0)
56+
cn = tf.clip_by_value(cn, -10 ^ 38, 0)
57+
dn = tf.clip_by_value(dn, -10 ^ 38, 0)
58+
area_pos = tf.where(phi > 0)
59+
area_neg = tf.where(phi < 0)
60+
pos_y = area_pos[:, 0]
61+
pos_x = area_pos[:, 1]
62+
neg_y = area_neg[:, 0]
63+
neg_x = area_neg[:, 1]
64+
tmp1 = tf.reduce_max([tf.square(tf.gather_nd(t, area_pos)) for t in [ap, bn]], axis=0)
65+
tmp1 += tf.reduce_max([tf.square(tf.gather_nd(t, area_pos)) for t in [cp, dn]], axis=0)
66+
update1 = tf.sqrt(tf.abs(tmp1)) - 1
67+
indices1 = tf.stack([pos_y, pos_x], 1)
68+
tmp2 = tf.reduce_max([tf.square(tf.gather_nd(t, area_neg)) for t in [an, bp]], axis=0)
69+
tmp2 += tf.reduce_max([tf.square(tf.gather_nd(t, area_neg)) for t in [cn, dp]], axis=0)
70+
update2 = tf.sqrt(tf.abs(tmp2)) - 1
71+
indices2 = tf.stack([neg_y, neg_x], 1)
72+
indices_final = tf.concat([indices1, indices2], 0)
73+
update_final = tf.concat([update1, update2], 0)
74+
dD = tf.scatter_nd(indices_final, update_final, shape=[input_image_size, input_image_size])
75+
S = tf.divide(phi, tf.square(phi) + 1)
76+
phi = phi - tf.multiply(dt * S, dD)
77+
78+
return phi
79+
80+
81+
def get_curvature(phi, x, y):
82+
phi_shape = tf.shape(phi)
83+
dim_y = phi_shape[0]
84+
dim_x = phi_shape[1]
85+
x = tf.cast(x, dtype="int32")
86+
y = tf.cast(y, dtype="int32")
87+
y_plus = tf.cast(y + 1, dtype="int32")
88+
y_minus = tf.cast(y - 1, dtype="int32")
89+
x_plus = tf.cast(x + 1, dtype="int32")
90+
x_minus = tf.cast(x - 1, dtype="int32")
91+
y_plus = tf.minimum(tf.cast(y_plus, dtype="int32"), tf.cast(dim_y - 1, dtype="int32"))
92+
x_plus = tf.minimum(tf.cast(x_plus, dtype="int32"), tf.cast(dim_x - 1, dtype="int32"))
93+
y_minus = tf.maximum(y_minus, 0)
94+
x_minus = tf.maximum(x_minus, 0)
95+
d_phi_dx = tf.gather_nd(phi, tf.stack([y, x_plus], 1)) - tf.gather_nd(phi, tf.stack([y, x_minus], 1))
96+
d_phi_dx_2 = tf.square(d_phi_dx)
97+
d_phi_dy = tf.gather_nd(phi, tf.stack([y_plus, x], 1)) - tf.gather_nd(phi, tf.stack([y_minus, x], 1))
98+
d_phi_dy_2 = tf.square(d_phi_dy)
99+
d_phi_dxx = tf.gather_nd(phi, tf.stack([y, x_plus], 1)) + tf.gather_nd(phi, tf.stack([y, x_minus], 1)) - \
100+
2 * tf.gather_nd(phi, tf.stack([y, x], 1))
101+
d_phi_dyy = tf.gather_nd(phi, tf.stack([y_plus, x], 1)) + tf.gather_nd(phi, tf.stack([y_minus, x], 1)) - \
102+
2 * tf.gather_nd(phi, tf.stack([y, x], 1))
103+
d_phi_dxy = 0.25 * (- tf.gather_nd(phi, tf.stack([y_minus, x_minus], 1)) - tf.gather_nd(phi, tf.stack(
104+
[y_plus, x_plus], 1)) + tf.gather_nd(phi, tf.stack([y_minus, x_plus], 1)) + tf.gather_nd(phi, tf.stack(
105+
[y_plus, x_minus], 1)))
106+
tmp_1 = tf.multiply(d_phi_dx_2, d_phi_dyy) + tf.multiply(d_phi_dy_2, d_phi_dxx) - \
107+
2 * tf.multiply(tf.multiply(d_phi_dx, d_phi_dy), d_phi_dxy)
108+
tmp_2 = tf.add(tf.pow(d_phi_dx_2 + d_phi_dy_2, 1.5), 2.220446049250313e-16)
109+
tmp_3 = tf.pow(d_phi_dx_2 + d_phi_dy_2, 0.5)
110+
tmp_4 = tf.divide(tmp_1, tmp_2)
111+
curvature = tf.multiply(tmp_3, tmp_4)
112+
mean_grad = tf.pow(d_phi_dx_2 + d_phi_dy_2, 0.5)
113+
114+
return curvature, mean_grad
115+
116+
117+
def get_intensity(image, masked_phi, filter_patch_size=5):
118+
u_1 = tf.layers.average_pooling2d(tf.multiply(image, masked_phi), [filter_patch_size, filter_patch_size], 1,padding='SAME')
119+
u_2 = tf.layers.average_pooling2d(masked_phi, [filter_patch_size, filter_patch_size], 1, padding='SAME')
120+
u_2_prime = 1 - tf.cast((u_2 > 0), dtype='float32') + tf.cast((u_2 < 0), dtype='float32')
121+
u_2 = u_2 + u_2_prime + 2.220446049250313e-16
122+
123+
return tf.divide(u_1, u_2)
124+
125+
126+
def active_contour_layer(elems):
127+
img = elems[0]
128+
init_phi = elems[1]
129+
map_lambda1_acl = elems[2]
130+
map_lambda2_acl = elems[3]
131+
wind_coef = 3
132+
zero_tensor = tf.constant(0, shape=[], dtype="int32")
133+
def _body(i, phi_level):
134+
band_index = tf.reduce_all([phi_level <= narrow_band_width, phi_level >= -narrow_band_width], axis=0)
135+
band = tf.where(band_index)
136+
band_y = band[:, 0]
137+
band_x = band[:, 1]
138+
shape_y = tf.shape(band_y)
139+
num_band_pixel = shape_y[0]
140+
window_radii_x = tf.ones(num_band_pixel) * wind_coef
141+
window_radii_y = tf.ones(num_band_pixel) * wind_coef
142+
143+
def body_intensity(j, mean_intensities_outer, mean_intensities_inner):
144+
### This can be computationally expensive. Use with fewer number of acm iterations.
145+
xnew = tf.cast(band_x[j], dtype="float32")
146+
ynew = tf.cast(band_y[j], dtype="float32")
147+
window_radius_x = tf.cast(window_radii_x[j], dtype="float32")
148+
window_radius_y = tf.cast(window_radii_y[j], dtype="float32")
149+
local_window_x_min = tf.cast(tf.floor(xnew - window_radius_x), dtype="int32")
150+
local_window_x_max = tf.cast(tf.floor(xnew + window_radius_x), dtype="int32")
151+
local_window_y_min = tf.cast(tf.floor(ynew - window_radius_y), dtype="int32")
152+
local_window_y_max = tf.cast(tf.floor(ynew + window_radius_y), dtype="int32")
153+
local_window_x_min = tf.maximum(zero_tensor, local_window_x_min)
154+
local_window_y_min = tf.maximum(zero_tensor, local_window_y_min)
155+
local_window_x_max = tf.minimum(tf.cast(input_image_size - 1, dtype="int32"), local_window_x_max)
156+
local_window_y_max = tf.minimum(tf.cast(input_image_size - 1, dtype="int32"), local_window_y_max)
157+
local_image = img[local_window_y_min: local_window_y_max + 1,local_window_x_min: local_window_x_max + 1]
158+
local_phi = phi_prime[local_window_y_min: local_window_y_max + 1,local_window_x_min: local_window_x_max + 1]
159+
inner = tf.where(local_phi <= 0)
160+
area_inner = tf.cast(tf.shape(inner)[0], dtype='float32')
161+
outer = tf.where(local_phi > 0)
162+
area_outer = tf.cast(tf.shape(outer)[0], dtype='float32')
163+
image_loc_inner = tf.gather_nd(local_image, inner)
164+
image_loc_outer = tf.gather_nd(local_image, outer)
165+
mean_intensity_inner = tf.cast(tf.divide(tf.reduce_sum(image_loc_inner), area_inner), dtype='float32')
166+
mean_intensity_outer = tf.cast(tf.divide(tf.reduce_sum(image_loc_outer), area_outer), dtype='float32')
167+
mean_intensities_inner = tf.concat(axis=0, values=[mean_intensities_inner[:j], [mean_intensity_inner]])
168+
mean_intensities_outer = tf.concat(axis=0, values=[mean_intensities_outer[:j], [mean_intensity_outer]])
169+
170+
return (j + 1, mean_intensities_outer, mean_intensities_inner)
171+
172+
if fast_lookup:
173+
phi_4d = phi_level[tf.newaxis, :, :, tf.newaxis]
174+
image = img[tf.newaxis, :, :, tf.newaxis]
175+
band_index_2 = tf.reduce_all([phi_4d <= narrow_band_width, phi_4d >= -narrow_band_width], axis=0)
176+
band_2 = tf.where(band_index_2)
177+
u_inner = get_intensity(image, tf.cast((([phi_4d <= 0])), dtype='float32')[0], filter_patch_size=f_size)
178+
u_outer = get_intensity(image, tf.cast((([phi_4d > 0])), dtype='float32')[0], filter_patch_size=f_size)
179+
mean_intensities_inner = tf.gather_nd(u_inner, band_2)
180+
mean_intensities_outer = tf.gather_nd(u_outer, band_2)
181+
182+
else:
183+
mean_intensities_inner = tf.constant([0], dtype='float32')
184+
mean_intensities_outer = tf.constant([0], dtype='float32')
185+
j = tf.constant(0, dtype=tf.int32)
186+
_, mean_intensities_outer, mean_intensities_inner = tf.while_loop(
187+
lambda j, mean_intensities_outer, mean_intensities_inner:
188+
j < num_band_pixel, body_intensity, loop_vars=[j, mean_intensities_outer, mean_intensities_inner],
189+
shape_invariants=[j.get_shape(), tf.TensorShape([None]), tf.TensorShape([None])])
190+
191+
lambda1 = tf.gather_nd(map_lambda1_acl, [band])
192+
lambda2 = tf.gather_nd(map_lambda2_acl, [band])
193+
curvature, mean_grad = get_curvature(phi_level, band_x, band_y)
194+
kappa = tf.multiply(curvature, mean_grad)
195+
term1 = tf.multiply(tf.cast(lambda1, dtype='float32'),tf.square(tf.gather_nd(img, [band]) - mean_intensities_inner))
196+
term2 = tf.multiply(tf.cast(lambda2, dtype='float32'),tf.square(tf.gather_nd(img, [band]) - mean_intensities_outer))
197+
force = -nu + term1 - term2
198+
force /= (tf.reduce_max(tf.abs(force)))
199+
d_phi_dt = tf.cast(force, dtype="float32") + tf.cast(mu * kappa, dtype="float32")
200+
dt = .45 / (tf.reduce_max(tf.abs(d_phi_dt)) + 2.220446049250313e-16)
201+
d_phi = dt * d_phi_dt
202+
update_narrow_band = d_phi
203+
phi_prime = phi_level + tf.scatter_nd([band], tf.cast(update_narrow_band, dtype='float32'),shape=[input_image_size, input_image_size])
204+
phi_prime = re_init_phi(phi_prime, 0.5)
205+
206+
return (i + 1, phi_prime)
207+
208+
i = tf.constant(0, dtype=tf.int32)
209+
phi = init_phi
210+
_, phi = tf.while_loop(lambda i, phi: i < iter_limit, _body, loop_vars=[i, phi])
211+
phi = tf.round(tf.cast((1 - tf.nn.sigmoid(phi)), dtype=tf.float32))
212+
213+
return phi,init_phi, map_lambda1_acl, map_lambda2_acl
214+
215+
fast_lookup = True
216+
config = tf.ConfigProto(allow_soft_placement=True)
217+
input_shape = [args.batch_size, args.img_resize, args.img_resize, 1]
218+
input_shape_dt = [args.batch_size, args.img_resize, args.img_resize]
219+
iter_limit = args.acm_iter_limit
220+
narrow_band_width = args.narrow_band_width
221+
mu = args.mu
222+
nu = args.nu
223+
f_size = args.f_size
224+
input_image_size = args.img_resize
225+
x = tf.placeholder(shape=input_shape, dtype=tf.float32, name="x")
226+
y = tf.placeholder(dtype=tf.float32, name="y")
227+
out_seg = tf.placeholder(dtype=tf.float32, name="out_seg")
228+
phase = tf.placeholder(tf.bool, name='phase')
229+
global_step = tf.Variable(0, name='global_step', trainable=False)
230+
map_lambda1 = tf.exp(tf.divide(tf.subtract(2.0,out_seg),tf.add(1.0,out_seg)))
231+
map_lambda2 = tf.exp(tf.divide(tf.add(1.0, out_seg), tf.subtract(2.0, out_seg)))
232+
y_out_dl = tf.round(out_seg)
233+
x_acm = x[:, :, :, 0]
234+
rounded_seg_acl = y_out_dl[:, :, :, 0]
235+
dt_trans = tf.py_func(my_func, [rounded_seg_acl], tf.float32)
236+
dt_trans.set_shape([args.batch_size, input_image_size, input_image_size])
237+
phi_out,_, lambda1_tr, lambda2_tr = tf.map_fn(fn=active_contour_layer, elems=(x_acm, dt_trans, map_lambda1[:, :, :, 0], map_lambda2[:, :, :, 0]))
238+
rounded_seg = tf.round(out_seg)
239+
with tf.Session(config=config) as sess:
240+
241+
print("########### Inference ############")
242+
243+
print('Brain Demo in Progress ... ')
244+
image = load_image(image_add,args.batch_size,False)
245+
labels = load_image(label_add, args.batch_size,True)
246+
init_seg = np.load(init_seg_add)
247+
labels[labels != 0] = 1
248+
seg_out_acm, seg_out = sess.run([phi_out, y_out_dl],{x: image, y: labels, out_seg: init_seg, phase: False})
249+
seg_out = seg_out[0, :, :, 0]
250+
seg_out_acm = seg_out_acm[0, :, :]
251+
gt_mask = labels[0, :, :, 0]
252+
f1 = f1_score(gt_mask, seg_out, labels=None, average='micro', sample_weight=None)
253+
print('CNN Dice {0:0.4f}'.format(f1))
254+
f2 = f1_score(gt_mask, seg_out_acm, labels=None, average='micro', sample_weight=None)
255+
print('ACM Dice {0:0.4f}'.format(f2))
256+
fig = plt.figure()
257+
plt.subplot(1, 3, 1)
258+
plt.title('DALS Output, Dice:{0:0.4f}'.format(f2))
259+
seg_out_acm=contoured_image(seg_out_acm, image[0,:,:,0])
260+
plt.imshow(seg_out_acm)
261+
plt.subplot(1, 3, 2)
262+
plt.title('CNN Output, Dice:{0:0.4f}'.format(f1))
263+
seg_out = contoured_image(seg_out, image[0, :, :, 0])
264+
plt.imshow(seg_out)
265+
plt.subplot(1, 3, 3)
266+
plt.title('Radiologist Annotation')
267+
gt_mask = contoured_image(gt_mask, image[0, :, :, 0])
268+
plt.imshow(gt_mask)
269+
plt.show()
270+
271+
272+
273+

0 commit comments

Comments
 (0)