13
13
from data_processing .surface_fitting import o3d_mesh_to_labelmap
14
14
from metrics import assd , pseudo_symmetric_point_to_mesh_distance
15
15
from models .dgcnn import DGCNNSeg
16
- from models .folding_net import DGCNNFoldingNet
16
+ from models .folding_net import DGCNNFoldingNet , FoldingDecoder
17
17
from models .modelio import LoadableModel , store_config_args
18
18
from train import write_results , run , write_speed_results
19
19
from utils .detached_run import maybe_run_detached_cli
@@ -125,7 +125,7 @@ def test(ds: PointToMeshDS, device, out_dir, show):
125
125
126
126
if output_is_mesh :
127
127
image_ds = ImageDataset (ds .image_folder , do_augmentation = False )
128
- label_dir = new_dir (out_dir , 'test_predictions' , 'label_maps ' )
128
+ label_dir = new_dir (out_dir , 'test_predictions' , 'labelmaps ' )
129
129
130
130
# pred_dir = new_dir(out_dir, 'test_predictions')
131
131
plot_dir = new_dir (out_dir , 'plots' )
@@ -151,6 +151,10 @@ def test(ds: PointToMeshDS, device, out_dir, show):
151
151
else :
152
152
plt .close (fig )
153
153
154
+ if isinstance (model .ae .decoder , FoldingDecoder ):
155
+ # show only the points from now on
156
+ color_values = color_2d_points_bremm (folding_points [:, :2 ])
157
+
154
158
chamfer_dists = torch .zeros (len (ds .ids ), ds .num_classes - 1 )
155
159
all_mean_assd = torch .zeros_like (chamfer_dists )
156
160
all_mean_sdsd = torch .zeros_like (chamfer_dists )
@@ -199,16 +203,20 @@ def test(ds: PointToMeshDS, device, out_dir, show):
199
203
200
204
# visualize reconstruction
201
205
fig = plt .figure ()
202
- if output_is_mesh :
206
+ if output_is_mesh and not isinstance ( model . ae . decoder , FoldingDecoder ) :
203
207
ax1 = fig .add_subplot (111 , projection = '3d' )
204
208
# point_cloud_on_axis(ax1, input_coords.cpu(), 'b', label='input', alpha=0.3)
205
209
point_cloud_on_axis (ax1 , segmented_sampled_coords .cpu (), 'k' , label = 'segmented points' , alpha = 0.3 )
206
210
trimesh_on_axis (ax1 , reconstruct_obj .verts_padded ().cpu ().squeeze (), faces , facecolors = color_values , alpha = 0.7 , label = 'reconstruction' )
207
211
else :
212
+ if isinstance (model .ae .decoder , FoldingDecoder ):
213
+ points = reconstruct_obj .verts_padded ().cpu ().squeeze ()
214
+ else :
215
+ points = reconstruct_obj .cpu ()
208
216
ax1 = fig .add_subplot (121 , projection = '3d' )
209
217
ax2 = fig .add_subplot (122 , projection = '3d' )
210
218
point_cloud_on_axis (ax1 , input_coords .cpu (), 'k' , title = 'input' )
211
- point_cloud_on_axis (ax2 , reconstruct_obj . cpu () , color_values , title = 'reconstruction' )
219
+ point_cloud_on_axis (ax2 , points , color_values , title = 'reconstruction' )
212
220
213
221
fig .savefig (os .path .join (plot_dir ,
214
222
f'{ "_" .join (ds .ids [i ])} _{ "fissure" if not ds .lobes else "lobe" } { cur_obj + 1 } _reconstruction.png' ),
0 commit comments