Skip to content

Commit 6779a49

Browse files
committed
[WIP] Trying to do the triangulation in Torch
1 parent 3dadf4b commit 6779a49

File tree

7 files changed

+61
-16
lines changed

7 files changed

+61
-16
lines changed

gempy_engine/API/dual_contouring/_dual_contouring.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,30 @@ def compute_dual_contouring(dc_data_per_stack: DualContouringData, left_right_co
5454
edges_normals[:] = np.nan
5555
edges_normals[valid_edges] = dc_data_per_stack.gradients[slice_object]
5656

57-
with warnings.catch_warnings():
58-
warnings.simplefilter("ignore", category=RuntimeWarning)
59-
voxel_normal = np.nanmean(edges_normals, axis=1)
60-
voxel_normal = voxel_normal[(~np.isnan(voxel_normal).any(axis=1))] # drop nans
57+
if LEGACY:=False:
58+
with warnings.catch_warnings():
59+
warnings.simplefilter("ignore", category=RuntimeWarning)
60+
voxel_normal = np.nanmean(edges_normals, axis=1)
61+
voxel_normal = voxel_normal[(~np.isnan(voxel_normal).any(axis=1))] # drop nans
62+
pass
63+
else:
64+
# Assuming edges_normals is a PyTorch tensor
65+
nan_mask = BackendTensor.t.isnan(edges_normals)
66+
valid_count = (~nan_mask).sum(dim=1)
67+
68+
# Replace NaNs with 0 for sum calculation
69+
safe_normals = edges_normals.clone()
70+
safe_normals[nan_mask] = 0
71+
72+
# Compute the sum of non-NaN elements
73+
sum_normals = BackendTensor.t.sum(safe_normals, 1)
74+
75+
# Calculate the mean, avoiding division by zero
76+
voxel_normal = sum_normals / valid_count.clamp(min=1)
77+
78+
# Remove rows where all elements were NaN (and hence valid_count is 0)
79+
voxel_normal = voxel_normal[valid_count > 0]
80+
6181

6282
valid_voxels = dc_data_per_surface.valid_voxels
6383
indices = triangulate(

gempy_engine/API/interp_single/_interp_single_feature.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from gempy_engine.config import AvailableBackends
6+
from gempy_engine.config import AvailableBackends, COMPUTE_GRADIENTS
77
from ...core.backend_tensor import BackendTensor
88
from ._interp_scalar_field import interpolate_scalar_field
99
from ...core.data import SurfacePoints, SurfacePointsInternals, Orientations, OrientationsInternals, TensorsStructure
@@ -26,7 +26,10 @@ def interpolate_feature(interpolation_input: InterpolationInput,
2626
external_segment_funct: Optional[Callable[[np.ndarray], float]] = None,
2727
clean_buffer: bool = True) -> ScalarFieldOutput:
2828

29-
grid = copy.deepcopy(interpolation_input.grid)
29+
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and COMPUTE_GRADIENTS is False:
30+
grid = copy.deepcopy(interpolation_input.grid)
31+
else:
32+
grid = interpolation_input.grid
3033

3134
# region Interpolate scalar field
3235
xyz = solver_input.xyz_to_interpolate

gempy_engine/API/interp_single/_octree_generation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def interpolate_on_octree(interpolation_input: InterpolationInput, options: Inte
2121

2222
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and COMPUTE_GRADIENTS is False:
2323
temp_interpolation_input = copy.deepcopy(interpolation_input)
24+
else:
25+
temp_interpolation_input = interpolation_input
2426

2527
# * Interpolate - centers
2628
output_0_centers: List[InterpOutput] = interpolate_all_fields(temp_interpolation_input, options, data_shape) # interpolate - centers

gempy_engine/API/interp_single/interp_features.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import List
33

44
import gempy_engine.core.data.tensors_structure
5+
from gempy_engine.config import AvailableBackends, COMPUTE_GRADIENTS
6+
from ...core.backend_tensor import BackendTensor
57
from ...core.data.grid import Grid
68
from ...core import data
79
from ...core.data import InterpolationOptions
@@ -45,8 +47,12 @@ def interpolate_n_octree_levels(interpolation_input: InterpolationInput, options
4547

4648
def interpolate_all_fields_no_octree(interpolation_input: InterpolationInput, options: InterpolationOptions,
4749
data_descriptor: InputDataDescriptor) -> List[InterpOutput]:
48-
interpolation_input = copy.deepcopy(interpolation_input)
49-
return ms.interpolate_all_fields(interpolation_input, options, data_descriptor)
50+
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and COMPUTE_GRADIENTS is False:
51+
temp_interpolation_input = copy.deepcopy(interpolation_input)
52+
else:
53+
temp_interpolation_input = interpolation_input
54+
55+
return ms.interpolate_all_fields(temp_interpolation_input, options, data_descriptor)
5056

5157

5258
# region testing

gempy_engine/core/backend_tensor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,19 @@ def _array(array_like, dtype=None):
169169
dtype = getattr(torch, dtype)
170170

171171
return torch.tensor(array_like, dtype=dtype)
172-
172+
173+
def _concatenate(tensors, axis=0, dtype=None):
174+
# Switch if tensor is numpy array or a torch tensor
175+
match type(tensors[0]):
176+
case numpy.ndarray:
177+
return numpy.concatenate(tensors, axis=axis)
178+
case torch.Tensor:
179+
return torch.cat(tensors, dim=axis)
180+
181+
def _transpose(tensor, axes=None):
182+
return tensor.transpose(axes[0], axes[1])
183+
184+
173185
cls.tfnp.sum = _sum
174186
cls.tfnp.repeat = _repeat
175187
cls.tfnp.expand_dims = lambda tensor, axis: tensor
@@ -182,7 +194,8 @@ def _array(array_like, dtype=None):
182194
cls.tfnp.rint = lambda tensor: tensor.round().type(torch.int32)
183195
cls.tfnp.vstack = lambda tensors: torch.cat(tensors, dim=0)
184196
cls.tfnp.copy = lambda tensor: tensor.clone()
185-
197+
cls.tfnp.concatenate = _concatenate
198+
cls.tfnp.transpose = _transpose
186199

187200

188201
@classmethod

gempy_engine/core/data/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def values(self) -> np.ndarray:
4646
if self.geophysics_grid is not None:
4747
values.append(self.geophysics_grid.values)
4848

49-
values_array = np.concatenate(values, dtype=BackendTensor.dtype)
49+
values_array = BackendTensor.t.concatenate(values, dtype=BackendTensor.dtype)
5050
values_array = BackendTensor.t.array(values_array)
5151

5252
return values_array

gempy_engine/modules/dual_contouring/dual_contouring_interface.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
173173
bias_xyz[isclose] = np.nan # np zero values to nans
174174
mass_points = np.nanmean(bias_xyz, axis=1) # Mean ignoring nans
175175
else: # ? This is actually doing something
176-
bias_xyz = np.copy(edges_xyz[:, :12])
176+
bias_xyz = BackendTensor.t.copy(edges_xyz[:, :12]).detach().numpy()
177177
mask = bias_xyz == 0
178178
masked_arr = np.ma.masked_array(bias_xyz, mask)
179179
mass_points = masked_arr.mean(axis=1)
@@ -200,10 +200,11 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
200200
# Compute LSTSQS in all voxels at the same time
201201
A = edges_normals
202202
b = (A * edges_xyz).sum(axis=2)
203-
term1 = np.einsum("ijk, ilj->ikl", A, np.transpose(A, (0, 2, 1)))
204-
term2 = np.linalg.inv(term1)
205-
term3 = np.einsum("ijk,ik->ij", np.transpose(A, (0, 2, 1)), b)
206-
vertices = np.einsum("ijk, ij->ik", term2, term3)
203+
204+
term1 = BackendTensor.t.einsum("ijk, ilj->ikl", A, BackendTensor.t.transpose(A, (2, 1)))
205+
term2 = BackendTensor.t.linalg.inv(term1)
206+
term3 = BackendTensor.t.einsum("ijk,ik->ij", BackendTensor.t.transpose(A, (2, 1)), b)
207+
vertices = BackendTensor.t.einsum("ijk, ij->ik", term2, term3)
207208

208209
if debug:
209210
dc_data_per_stack.bias_center_mass = edges_xyz[:, 12:].reshape(-1, 3)

0 commit comments

Comments
 (0)