Skip to content

Commit 16b211d

Browse files
committed
dseg-ae update
1 parent 8c9fca3 commit 16b211d

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

bash_scripts/test_dseg_ae.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22

33
run () {
4-
cmd="python3.9 test_ae_regularization.py --ds ts --gpu "$GPU" --output results/DSEGAE_"$OUT_SUFFIX"_"$1"_"$2" --seg_dir results/DGCNN_seg_"$1"_"$2" --ae_dir "$AE_DIR" --speed"
4+
cmd="python3.9 test_ae_regularization.py --ds ts --gpu "$GPU" --output results/DSEGAE_"$OUT_SUFFIX"_"$1"_"$2" --seg_dir results/DGCNN_seg_"$1"_"$2" --ae_dir "$AE_DIR""
55
echo "#######################################################################################################################################################################"
66
echo $cmd
77
echo "#######################################################################################################################################################################"
@@ -12,7 +12,7 @@ GPU=2
1212
AE_DIR="results/PC_AE_regularized_augment_1024"
1313
OUT_SUFFIX="reg_aug_1024"
1414
keypoints=("cnn" "foerstner" "enhancement")
15-
features=("image" "mind" "mind_ssc" "enhancement" "nofeat")
15+
#features=("image" "mind" "mind_ssc" "enhancement" "nofeat")
1616

1717
for kp in "${keypoints[@]}"
1818
do
@@ -22,8 +22,9 @@ for kp in "${keypoints[@]}"
2222
features_cur=("${features[@]}")
2323
fi
2424

25-
for feat in "${features_cur[@]}"
26-
do
27-
run "$kp" "$feat"
28-
done
25+
run "$kp" "image"
26+
# for feat in "${features_cur[@]}"
27+
# do
28+
# run "$kp" "$feat"
29+
# done
2930
done

test_ae_regularization.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from data_processing.surface_fitting import o3d_mesh_to_labelmap
1414
from metrics import assd, pseudo_symmetric_point_to_mesh_distance
1515
from models.dgcnn import DGCNNSeg
16-
from models.folding_net import DGCNNFoldingNet
16+
from models.folding_net import DGCNNFoldingNet, FoldingDecoder
1717
from models.modelio import LoadableModel, store_config_args
1818
from train import write_results, run, write_speed_results
1919
from utils.detached_run import maybe_run_detached_cli
@@ -125,7 +125,7 @@ def test(ds: PointToMeshDS, device, out_dir, show):
125125

126126
if output_is_mesh:
127127
image_ds = ImageDataset(ds.image_folder, do_augmentation=False)
128-
label_dir = new_dir(out_dir, 'test_predictions', 'label_maps')
128+
label_dir = new_dir(out_dir, 'test_predictions', 'labelmaps')
129129

130130
# pred_dir = new_dir(out_dir, 'test_predictions')
131131
plot_dir = new_dir(out_dir, 'plots')
@@ -151,6 +151,10 @@ def test(ds: PointToMeshDS, device, out_dir, show):
151151
else:
152152
plt.close(fig)
153153

154+
if isinstance(model.ae.decoder, FoldingDecoder):
155+
# show only the points from now on
156+
color_values = color_2d_points_bremm(folding_points[:, :2])
157+
154158
chamfer_dists = torch.zeros(len(ds.ids), ds.num_classes - 1)
155159
all_mean_assd = torch.zeros_like(chamfer_dists)
156160
all_mean_sdsd = torch.zeros_like(chamfer_dists)
@@ -199,16 +203,20 @@ def test(ds: PointToMeshDS, device, out_dir, show):
199203

200204
# visualize reconstruction
201205
fig = plt.figure()
202-
if output_is_mesh:
206+
if output_is_mesh and not isinstance(model.ae.decoder, FoldingDecoder):
203207
ax1 = fig.add_subplot(111, projection='3d')
204208
# point_cloud_on_axis(ax1, input_coords.cpu(), 'b', label='input', alpha=0.3)
205209
point_cloud_on_axis(ax1, segmented_sampled_coords.cpu(), 'k', label='segmented points', alpha=0.3)
206210
trimesh_on_axis(ax1, reconstruct_obj.verts_padded().cpu().squeeze(), faces, facecolors=color_values, alpha=0.7, label='reconstruction')
207211
else:
212+
if isinstance(model.ae.decoder, FoldingDecoder):
213+
points = reconstruct_obj.verts_padded().cpu().squeeze()
214+
else:
215+
points = reconstruct_obj.cpu()
208216
ax1 = fig.add_subplot(121, projection='3d')
209217
ax2 = fig.add_subplot(122, projection='3d')
210218
point_cloud_on_axis(ax1, input_coords.cpu(), 'k', title='input')
211-
point_cloud_on_axis(ax2, reconstruct_obj.cpu(), color_values, title='reconstruction')
219+
point_cloud_on_axis(ax2, points, color_values, title='reconstruction')
212220

213221
fig.savefig(os.path.join(plot_dir,
214222
f'{"_".join(ds.ids[i])}_{"fissure" if not ds.lobes else "lobe"}{cur_obj+1}_reconstruction.png'),

0 commit comments

Comments
 (0)