Skip to content

Commit b6044cb

Browse files
committed
[ENH] Mesh extraction on pytorch
1 parent 6779a49 commit b6044cb

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

gempy_engine/API/dual_contouring/_dual_contouring.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import numpy as np
55

6+
from gempy_engine.config import AvailableBackends
7+
8+
from ...core.backend_tensor import BackendTensor
69
from ...core.data.dual_contouring_data import DualContouringData
710
from ...core.data.dual_contouring_mesh import DualContouringMesh
811
from ...core.utils import gempy_profiler_decorator
@@ -49,12 +52,12 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
4952
# * Fancy triangulation 👗
5053

5154
# * Average gradient for the edges
52-
from gempy_engine.core.backend_tensor import BackendTensor
5355
edges_normals = BackendTensor.t.zeros((valid_edges.shape[0], 12, 3), dtype=BackendTensor.dtype_obj)
5456
edges_normals[:] = np.nan
5557
edges_normals[valid_edges] = dc_data_per_stack.gradients[slice_object]
5658

57-
if LEGACY:=False:
59+
# if LEGACY:=True:
60+
if BackendTensor.engine_backend != AvailableBackends.PYTORCH:
5861
with warnings.catch_warnings():
5962
warnings.simplefilter("ignore", category=RuntimeWarning)
6063
voxel_normal = np.nanmean(edges_normals, axis=1)
@@ -76,7 +79,7 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
7679
voxel_normal = sum_normals / valid_count.clamp(min=1)
7780

7881
# Remove rows where all elements were NaN (and hence valid_count is 0)
79-
voxel_normal = voxel_normal[valid_count > 0]
82+
voxel_normal = voxel_normal[valid_count > 0].reshape(-1, 3)
8083

8184

8285
valid_voxels = dc_data_per_surface.valid_voxels
@@ -89,5 +92,11 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
8992
indices = np.vstack(indices)
9093

9194
# @on
92-
stack_meshes.append(DualContouringMesh(vertices, indices, dc_data_per_stack))
95+
stack_meshes.append(
96+
DualContouringMesh(
97+
BackendTensor.t.to_numpy(vertices),
98+
indices,
99+
dc_data_per_stack
100+
)
101+
)
93102
return stack_meshes

gempy_engine/modules/dual_contouring/dual_contouring_interface.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
173173
bias_xyz[isclose] = np.nan # np zero values to nans
174174
mass_points = np.nanmean(bias_xyz, axis=1) # Mean ignoring nans
175175
else: # ? This is actually doing something
176-
bias_xyz = BackendTensor.t.copy(edges_xyz[:, :12]).detach().numpy()
176+
bias_xyz = BackendTensor.t.copy(edges_xyz[:, :12])
177+
bias_xyz = BackendTensor.t.to_numpy(bias_xyz)
177178
mask = bias_xyz == 0
178179
masked_arr = np.ma.masked_array(bias_xyz, mask)
179180
mass_points = masked_arr.mean(axis=1)
@@ -200,10 +201,15 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
200201
# Compute LSTSQS in all voxels at the same time
201202
A = edges_normals
202203
b = (A * edges_xyz).sum(axis=2)
203-
204-
term1 = BackendTensor.t.einsum("ijk, ilj->ikl", A, BackendTensor.t.transpose(A, (2, 1)))
204+
205+
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
206+
transpose_shape = (2, 1)
207+
else:
208+
transpose_shape = (0, 2,1)
209+
210+
term1 = BackendTensor.t.einsum("ijk, ilj->ikl", A, BackendTensor.t.transpose(A, transpose_shape))
205211
term2 = BackendTensor.t.linalg.inv(term1)
206-
term3 = BackendTensor.t.einsum("ijk,ik->ij", BackendTensor.t.transpose(A, (2, 1)), b)
212+
term3 = BackendTensor.t.einsum("ijk,ik->ij", BackendTensor.t.transpose(A, transpose_shape), b)
207213
vertices = BackendTensor.t.einsum("ijk, ij->ik", term2, term3)
208214

209215
if debug:

gempy_engine/modules/dual_contouring/fancy_triangulation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,10 @@ def check_voxels_exist_next_to_edge(coord_col, edge_vector, _left_right_array_ac
178178

179179
valid_edges_within_extent = code__a_prod_edge * code__b_prod_edge * code__c_prod_edge # * Valid in the sense that there are valid voxels around
180180

181-
code__a_p = mapped_voxel_0[:, valid_edges_within_extent] == 0 # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
182-
code__b_p = mapped_voxel_1[:, valid_edges_within_extent] == 0 # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
183-
code__c_p = mapped_voxel_2[:, valid_edges_within_extent] == 0 # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
181+
from ...core.backend_tensor import BackendTensor
182+
code__a_p = BackendTensor.t.array(mapped_voxel_0[:, valid_edges_within_extent] == 0) # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
183+
code__b_p = BackendTensor.t.array(mapped_voxel_1[:, valid_edges_within_extent] == 0) # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
184+
code__c_p = BackendTensor.t.array(mapped_voxel_2[:, valid_edges_within_extent] == 0) # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)
184185

185186
if False:
186187
debug_code_p = code__a_p + code__b_p + code__c_p # (n_voxels, n_voxels - active_voxels_for_given_edge - invalid_edges - edges_at_extent_border)

0 commit comments

Comments
 (0)