Skip to content

Commit 9519d0d

Browse files
committed
Cut regions
1 parent 1eae827 commit 9519d0d

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

yt/data_objects/selection_objects/cut_region.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
YTSelectionContainer3D,
99
)
1010
from yt.data_objects.static_output import Dataset
11-
from yt.funcs import iter_fields, validate_object, validate_sequence
11+
from yt.fields.derived_field import DerivedField
12+
from yt.funcs import iter_fields, validate_object
1213
from yt.geometry.selection_routines import points_in_cells
1314
from yt.utilities.exceptions import YTIllDefinedCutRegion
1415
from yt.utilities.on_demand_imports import _scipy
@@ -54,14 +55,14 @@ def __init__(
5455
if locals is None:
5556
locals = {}
5657
validate_object(data_source, YTSelectionContainer)
57-
validate_sequence(conditionals)
58+
conditionals = list(always_iterable(conditionals))
5859
for condition in conditionals:
59-
validate_object(condition, str)
60+
validate_object(condition, (str, DerivedField))
6061
validate_object(ds, Dataset)
6162
validate_object(field_parameters, dict)
6263
validate_object(base_object, YTSelectionContainer)
6364

64-
self.conditionals = list(always_iterable(conditionals))
65+
self.conditionals = conditionals
6566
if isinstance(data_source, YTCutRegion):
6667
# If the source is also a cut region, add its conditionals
6768
# and set the source to be its source.
@@ -82,6 +83,10 @@ def __init__(
8283
def _check_filter_fields(self):
8384
fields = []
8485
for cond in self.conditionals:
86+
if isinstance(cond, DerivedField):
87+
fields.append(cond.name)
88+
continue
89+
8590
for field in re.findall(r"\[([A-Za-z0-9_,.'\"\(\)]+)\]", cond):
8691
fd = field.replace('"', "").replace("'", "")
8792
if "," in fd:
@@ -125,8 +130,11 @@ def blocks(self):
125130
m = m.copy()
126131
with obj._field_parameter_state(self.field_parameters):
127132
for cond in self.conditionals:
128-
ss = eval(cond)
129-
m = np.logical_and(m, ss, m)
133+
if isinstance(cond, DerivedField):
134+
ss = cond(obj)
135+
else:
136+
ss = eval(cond)
137+
m &= ss
130138
if not np.any(m):
131139
continue
132140
yield obj, m
@@ -144,12 +152,15 @@ def _cond_ind(self):
144152
locals["obj"] = obj
145153
with obj._field_parameter_state(self.field_parameters):
146154
for cond in self.conditionals:
147-
res = eval(cond, locals)
155+
if isinstance(cond, DerivedField):
156+
res = cond(obj)
157+
else:
158+
res = eval(cond, locals)
148159
if ind is None:
149160
ind = res
150161
if ind.shape != res.shape:
151162
raise YTIllDefinedCutRegion(self.conditionals)
152-
np.logical_and(res, ind, ind)
163+
ind &= res
153164
return ind
154165

155166
def _part_ind_KDTree(self, ptype):

yt/fields/derived_field.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,11 +524,16 @@ def wrapped(field, data):
524524
other_name = str(other)
525525
other_units = self.ds.get_unit_from_registry(getattr(other, "units", "1"))
526526

527-
if op in (operator.add, operator.sub):
527+
if op in (operator.add, operator.sub, operator.eq):
528528
assert my_units.same_dimensions_as(other_units)
529529
new_units = my_units
530-
else:
530+
elif op in (operator.mul, operator.truediv):
531531
new_units = op(my_units, other_units)
532+
elif op in (operator.le, operator.lt, operator.ge, operator.gt, operator.ne):
533+
# Comparison yield unitless fields
534+
new_units = Unit("1")
535+
else:
536+
raise TypeError(f"Unsupported operator {op} for DerivedField")
532537

533538
return DerivedField(
534539
name=(self.name[0], f"{self.name[1]}_{op.__name__}_{other_name}"),
@@ -593,6 +598,26 @@ def wrapped(field, data):
593598

594599
return inverse_self * other
595600

601+
# Comparison operators
602+
def __leq__(self, other: Union["DerivedField", float]) -> "DerivedField":
603+
return self._operator(other, op=operator.le)
604+
605+
def __lt__(self, other: Union["DerivedField", float]) -> "DerivedField":
606+
return self._operator(other, op=operator.lt)
607+
608+
def __geq__(self, other: Union["DerivedField", float]) -> "DerivedField":
609+
return self._operator(other, op=operator.ge)
610+
611+
def __gt__(self, other: Union["DerivedField", float]) -> "DerivedField":
612+
return self._operator(other, op=operator.gt)
613+
614+
# Somehow, makes yt not work?
615+
# def __eq__(self, other: Union["DerivedField", float]) -> "DerivedField":
616+
# return self._operator(other, op=operator.eq)
617+
618+
def __ne__(self, other: Union["DerivedField", float]) -> "DerivedField":
619+
return self._operator(other, op=operator.ne)
620+
596621

597622
class FieldValidator:
598623
"""

0 commit comments

Comments
 (0)