Skip to content

Commit 7485263

Browse files
committed
"Updated numpy methods to PyTorch methods in various files"
In an effort to fully transition from numpy to PyTorch, various numpy methods have been replaced with their PyTorch counterparts. This change was made to improve the overall performance and compatibility of the code. The affected files include 'backend_tensor.py', 'data_preprocess_interface.py', '_vectors_preparation.py', and '_multi_scalar_field_manager.py'. The numpy methods that were replaced include 'array', 'repeat', 'zeros', 'eye', and 'vstack'.
1 parent d4c2060 commit 7485263

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

gempy_engine/API/interp_single/_multi_scalar_field_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def _grab_stack_fault_data(_all_stack_values_block, _interpolation_input_i, _sta
6060
all_scalar_fields_outputs: List[ScalarFieldOutput | None] = [None] * stack_structure.n_stacks
6161

6262
xyz_to_interpolate_size: int = root_interpolation_input.grid.len_all_grids + root_interpolation_input.surface_points.n_points
63-
all_stack_values_block: np.ndarray = np.zeros((stack_structure.n_stacks, xyz_to_interpolate_size), dtype=BackendTensor.dtype) # * Used for faults
63+
all_stack_values_block: np.ndarray = BackendTensor.t.zeros(
64+
(stack_structure.n_stacks, xyz_to_interpolate_size),
65+
dtype=BackendTensor.dtype_obj) # * Used for faults
6466

6567
for i in range(stack_structure.n_stacks):
6668
stack_structure.stack_number = i

gempy_engine/core/backend_tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ def _repeat(tensor, n_repeats, axis=None):
167167
def _array(array_like, dtype=None):
168168
if isinstance(dtype, str):
169169
dtype = getattr(torch, dtype)
170-
171-
return torch.tensor(array_like, dtype=dtype)
170+
if isinstance(array_like, torch.Tensor):
171+
return array_like
172+
else:
173+
return torch.tensor(array_like, dtype=dtype)
172174

173175
def _concatenate(tensors, axis=0, dtype=None):
174176
# Switch if tensor is numpy array or a torch tensor

gempy_engine/modules/data_preprocess/data_preprocess_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ def prepare_grid(grid: np.ndarray, surface_points: SurfacePoints) -> np.ndarray:
2727
def prepare_faults(faults_values_on_sp: np.ndarray, tensors_structure: TensorsStructure) -> Tuple[ndarray, ndarray]:
2828

2929
partitions_bool = tensors_structure.partitions_bool
30-
number_repetitions = tensors_structure.number_of_points_per_surface - 1
30+
number_repetitions = bt.t.array(tensors_structure.number_of_points_per_surface - 1)
3131

3232
ref_points = faults_values_on_sp[:, partitions_bool]
3333

34-
ref_matrix_val_repeated = np.repeat(ref_points, number_repetitions, 1)
34+
ref_matrix_val_repeated = bt.t.repeat(ref_points, number_repetitions, 1)
3535
rest_matrix_val = faults_values_on_sp[:, ~partitions_bool]
3636

3737
return ref_matrix_val_repeated, rest_matrix_val

gempy_engine/modules/kernel_constructor/_vectors_preparation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ def _assembly_fault_internals(faults_val, options, ori_size):
226226
def _assembler(matrix_val, ori_size_: int, uni_drift_size: int): # TODO: This function (probably)needs to be extracted to _kernel_constructors
227227
n_uni_eq = uni_drift_size # * Number of equations. This should be how many faults are active
228228
n_faults = matrix_val.shape[1] # TODO [ ]: We are going to have to tweak this for multiple faults
229-
z = np.zeros((ori_size_, n_faults), dtype=BackendTensor.dtype)
230-
z2 = np.zeros((n_uni_eq, n_faults), dtype=BackendTensor.dtype)
231-
z3 = np.eye(n_faults, dtype=BackendTensor.dtype)
229+
z = BackendTensor.t.zeros((ori_size_, n_faults), dtype=BackendTensor.dtype_obj)
230+
z2 = BackendTensor.t.zeros((n_uni_eq, n_faults), dtype=BackendTensor.dtype_obj)
231+
z3 = BackendTensor.t.eye(n_faults, dtype=BackendTensor.dtype_obj)
232232
# Degree 1
233-
return np.vstack((z, matrix_val, z2, z3))
233+
return BackendTensor.t.vstack((z, matrix_val, z2, z3))
234234

235235
ref_matrix_val = faults_val.fault_values_ref
236236
rest_matrix_val = faults_val.fault_values_rest

0 commit comments

Comments
 (0)