Skip to content

Commit d448c74

Browse files
committed
Update total gradient loop
1 parent bf7a944 commit d448c74

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

SimPEG/regularization/base.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class BaseRegularization(BaseObjectiveFunction):
3030
:param weights: Weight multipliers to customize the least-squares function.
3131
"""
3232

33-
_model = None
34-
3533
def __init__(
3634
self,
3735
mesh: RegularizationMesh | BaseMesh,
@@ -51,6 +49,8 @@ def __init__(
5149
f"Value of type {type(mesh)} provided."
5250
)
5351

52+
self._model = None
53+
self._parent = None
5454
self._regularization_mesh = mesh
5555
self._weights = {}
5656

@@ -136,6 +136,23 @@ def mapping(self, mapping: maps.IdentityMap):
136136
)
137137
self._mapping = mapping
138138

139+
@property
140+
def parent(self):
141+
"""
142+
The parent objective function
143+
"""
144+
return self._parent
145+
146+
@parent.setter
147+
def parent(self, parent):
148+
combo_class = ComboObjectiveFunction
149+
if not isinstance(parent, combo_class):
150+
raise TypeError(
151+
f"Invalid parent of type '{parent.__class__.__name__}'. "
152+
f"Parent must be a {combo_class.__name__}."
153+
)
154+
self._parent = parent
155+
139156
@property
140157
def units(self) -> str | None:
141158
"""Specify the model units. Special care given to 'radian' values"""
@@ -555,10 +572,7 @@ def _cell_distances(self):
555572
"""
556573
Distances between cell centers for the cell center difference.
557574
"""
558-
if self.__cell_distances is None:
559-
self.__cell_distances = 1.0 / np.max(self.cell_gradient, axis=1).data
560-
561-
return self.__cell_distances
575+
return getattr(self.regularization_mesh, f"cell_distances_{self.orientation}")
562576

563577
@property
564578
def orientation(self):
@@ -784,6 +798,10 @@ def __init__(
784798
else:
785799
objfcts = kwargs.pop("objfcts")
786800
super().__init__(objfcts=objfcts, **kwargs)
801+
802+
for fun in objfcts:
803+
fun.parent = self
804+
787805
self.mapping = mapping
788806
self.reference_model = reference_model
789807
self.reference_model_in_smooth = reference_model_in_smooth

SimPEG/regularization/sparse.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -140,30 +140,12 @@ def update_weights(self, m):
140140
"""
141141
Compute and store the irls weights.
142142
"""
143-
if self.gradient_type == "total":
144-
delta_m = self.mapping * self._delta_m(m)
145-
f_m = np.zeros_like(delta_m)
146-
for ii, comp in enumerate("xyz"):
147-
if self.regularization_mesh.dim > ii:
148-
dm = (
149-
getattr(self.regularization_mesh, f"cell_gradient_{comp}")
150-
* delta_m
151-
)
152-
153-
if self.units is not None and self.units.lower() == "radian":
154-
Ave = getattr(self.regularization_mesh, f"aveCC2F{comp}")
155-
length_scales = Ave * (
156-
self.regularization_mesh.Pac.T
157-
* self.regularization_mesh.mesh.h_gridded[:, ii]
158-
)
159-
dm = (
160-
utils.mat_utils.coterminal(dm * length_scales)
161-
/ length_scales
162-
)
163-
164-
f_m += np.abs(
165-
getattr(self.regularization_mesh, f"aveF{comp}2CC") * dm
166-
)
143+
if self.gradient_type == "total" and self.parent is not None:
144+
f_m = np.zeros(self.regularization_mesh.nC)
145+
for obj in self.parent.objfcts:
146+
if isinstance(obj, SparseSmoothness):
147+
avg = getattr(self.regularization_mesh, f"aveF{obj.orientation}2CC")
148+
f_m += np.abs(avg * obj.f_m(m))
167149

168150
f_m = getattr(self.regularization_mesh, f"aveCC2F{self.orientation}") * f_m
169151

0 commit comments

Comments
 (0)