|  | 
|  | 1 | +import abc | 
| 1 | 2 | import contextlib | 
| 2 | 3 | import inspect | 
| 3 | 4 | import operator | 
| 4 | 5 | import re | 
| 5 | 6 | from collections.abc import Callable, Iterable | 
| 6 |  | -from typing import Optional, Union | 
|  | 7 | +from functools import reduce | 
|  | 8 | +from typing import Optional | 
| 7 | 9 | 
 | 
| 8 | 10 | from more_itertools import always_iterable | 
| 9 | 11 | 
 | 
| @@ -59,7 +61,125 @@ def _DeprecatedFieldFunc(field, data): | 
| 59 | 61 |     return _DeprecatedFieldFunc | 
| 60 | 62 | 
 | 
| 61 | 63 | 
 | 
| 62 |  | -class DerivedField: | 
|  | 64 | +class DerivedFieldBase(abc.ABC): | 
|  | 65 | +    @abc.abstractmethod | 
|  | 66 | +    def __call__(self, field, data): | 
|  | 67 | +        pass | 
|  | 68 | + | 
|  | 69 | +    @abc.abstractmethod | 
|  | 70 | +    def __repr__(self) -> str: | 
|  | 71 | +        pass | 
|  | 72 | + | 
|  | 73 | +    # Multiplication (left and right side) | 
|  | 74 | +    def __mul__(self, other) -> "DerivedFieldCombination": | 
|  | 75 | +        return DerivedFieldCombination([self, other], op=operator.mul) | 
|  | 76 | + | 
|  | 77 | +    def __rmul__(self, other) -> "DerivedFieldCombination": | 
|  | 78 | +        return DerivedFieldCombination([self, other], op=operator.mul) | 
|  | 79 | + | 
|  | 80 | +    # Division (left side) | 
|  | 81 | +    def __truediv__(self, other) -> "DerivedFieldCombination": | 
|  | 82 | +        return DerivedFieldCombination([self, other], op=operator.truediv) | 
|  | 83 | + | 
|  | 84 | +    def __rtruediv__(self, other) -> "DerivedFieldCombination": | 
|  | 85 | +        return DerivedFieldCombination([other, self], op=operator.truediv) | 
|  | 86 | + | 
|  | 87 | +    # Addition (left and right side) | 
|  | 88 | +    def __add__(self, other) -> "DerivedFieldCombination": | 
|  | 89 | +        return DerivedFieldCombination([self, other], op=operator.add) | 
|  | 90 | + | 
|  | 91 | +    def __radd__(self, other) -> "DerivedFieldCombination": | 
|  | 92 | +        return DerivedFieldCombination([self, other], op=operator.add) | 
|  | 93 | + | 
|  | 94 | +    # Subtraction (left and right side) | 
|  | 95 | +    def __sub__(self, other) -> "DerivedFieldCombination": | 
|  | 96 | +        return DerivedFieldCombination([self, other], op=operator.sub) | 
|  | 97 | + | 
|  | 98 | +    def __rsub__(self, other) -> "DerivedFieldCombination": | 
|  | 99 | +        return DerivedFieldCombination([other, self], op=operator.sub) | 
|  | 100 | + | 
|  | 101 | +    # Unary minus | 
|  | 102 | +    def __neg__(self) -> "DerivedFieldCombination": | 
|  | 103 | +        return DerivedFieldCombination([self], op=operator.neg) | 
|  | 104 | + | 
|  | 105 | +    # Comparison operators | 
|  | 106 | +    def __leq__(self, other) -> "DerivedFieldCombination": | 
|  | 107 | +        return DerivedFieldCombination([self, other], op=operator.le) | 
|  | 108 | + | 
|  | 109 | +    def __lt__(self, other) -> "DerivedFieldCombination": | 
|  | 110 | +        return DerivedFieldCombination([self, other], op=operator.lt) | 
|  | 111 | + | 
|  | 112 | +    def __geq__(self, other) -> "DerivedFieldCombination": | 
|  | 113 | +        return DerivedFieldCombination([self, other], op=operator.ge) | 
|  | 114 | + | 
|  | 115 | +    def __gt__(self, other) -> "DerivedFieldCombination": | 
|  | 116 | +        return DerivedFieldCombination([self, other], op=operator.gt) | 
|  | 117 | + | 
|  | 118 | +    # def __eq__(self, other) -> "DerivedFieldCombination": | 
|  | 119 | +    #     return DerivedFieldCombination([self, other], op=operator.eq) | 
|  | 120 | + | 
|  | 121 | +    def __ne__(self, other) -> "DerivedFieldCombination": | 
|  | 122 | +        return DerivedFieldCombination([self, other], op=operator.ne) | 
|  | 123 | + | 
|  | 124 | + | 
|  | 125 | +class DerivedFieldCombination(DerivedFieldBase): | 
|  | 126 | +    sampling_type: str | None | 
|  | 127 | +    terms: list | 
|  | 128 | +    op: Callable | None | 
|  | 129 | + | 
|  | 130 | +    def __init__(self, terms: list, op=None): | 
|  | 131 | +        if not terms: | 
|  | 132 | +            raise ValueError("DerivedFieldCombination requires at least one term.") | 
|  | 133 | + | 
|  | 134 | +        # Make sure all terms have the same sampling type | 
|  | 135 | +        sampling_types = set() | 
|  | 136 | +        for term in terms: | 
|  | 137 | +            if isinstance(term, DerivedField): | 
|  | 138 | +                sampling_types.add(term.sampling_type) | 
|  | 139 | + | 
|  | 140 | +        if len(sampling_types) > 1: | 
|  | 141 | +            raise ValueError( | 
|  | 142 | +                "All terms in a DerivedFieldCombination must " | 
|  | 143 | +                "have the same sampling type." | 
|  | 144 | +            ) | 
|  | 145 | +        self.sampling_type = sampling_types.pop() if sampling_types else None | 
|  | 146 | +        self.terms = terms | 
|  | 147 | +        self.op = op | 
|  | 148 | + | 
|  | 149 | +    def __call__(self, field, data): | 
|  | 150 | +        """ | 
|  | 151 | +        Return the value of the field in a given data object. | 
|  | 152 | +        """ | 
|  | 153 | +        qties = [] | 
|  | 154 | +        for term in self.terms: | 
|  | 155 | +            if isinstance(term, DerivedField): | 
|  | 156 | +                qties.append(data[term.name]) | 
|  | 157 | +            elif isinstance(term, DerivedFieldCombination): | 
|  | 158 | +                qties.append(term(field, data)) | 
|  | 159 | +            else: | 
|  | 160 | +                qties.append(term) | 
|  | 161 | + | 
|  | 162 | +        if len(qties) == 1: | 
|  | 163 | +            return self.op(qties[0]) | 
|  | 164 | +        else: | 
|  | 165 | +            return reduce(self.op, qties) | 
|  | 166 | + | 
|  | 167 | +    def __repr__(self): | 
|  | 168 | +        return f"DerivedFieldCombination(terms={self.terms!r}, op={self.op!r})" | 
|  | 169 | + | 
|  | 170 | +    def getDependentFields(self): | 
|  | 171 | +        fields = [] | 
|  | 172 | +        for term in self.terms: | 
|  | 173 | +            if isinstance(term, DerivedField): | 
|  | 174 | +                fields.append(term.name) | 
|  | 175 | +            elif isinstance(term, DerivedFieldCombination): | 
|  | 176 | +                fields.extend(term.getDependentFields()) | 
|  | 177 | +            else: | 
|  | 178 | +                continue | 
|  | 179 | +        return fields | 
|  | 180 | + | 
|  | 181 | + | 
|  | 182 | +class DerivedField(DerivedFieldBase): | 
| 63 | 183 |     """ | 
| 64 | 184 |     This is the base class used to describe a cell-by-cell derived field. | 
| 65 | 185 | 
 | 
| @@ -277,11 +397,11 @@ def __call__(self, data): | 
| 277 | 397 |         """Return the value of the field in a given *data* object.""" | 
| 278 | 398 |         self.check_available(data) | 
| 279 | 399 |         original_fields = data.keys()  # Copy | 
| 280 |  | -        if self._function is NullFunc: | 
| 281 |  | -            raise RuntimeError( | 
| 282 |  | -                "Something has gone terribly wrong, _function is NullFunc " | 
| 283 |  | -                + f"for {self.name}" | 
| 284 |  | -            ) | 
|  | 400 | +        # if self._function is NullFunc: | 
|  | 401 | +        #     raise RuntimeError( | 
|  | 402 | +        #         "Something has gone terribly wrong, _function is NullFunc " | 
|  | 403 | +        #         + f"for {self.name}" | 
|  | 404 | +        #     ) | 
| 285 | 405 |         with self.unit_registry(data): | 
| 286 | 406 |             dd = self._function(self, data) | 
| 287 | 407 |         for field_name in data.keys(): | 
| @@ -499,128 +619,6 @@ def __copy__(self): | 
| 499 | 619 |             nodal_flag=self.nodal_flag, | 
| 500 | 620 |         ) | 
| 501 | 621 | 
 | 
| 502 |  | -    def _operator( | 
| 503 |  | -        self, other: Union["DerivedField", float], op: Callable | 
| 504 |  | -    ) -> "DerivedField": | 
| 505 |  | -        my_units = self.ds.get_unit_from_registry(self.units) | 
| 506 |  | -        if isinstance(other, DerivedField): | 
| 507 |  | -            if self.sampling_type != other.sampling_type: | 
| 508 |  | -                raise TypeError( | 
| 509 |  | -                    f"Cannot {op} fields with different sampling types: " | 
| 510 |  | -                    f"{self.sampling_type} and {other.sampling_type}" | 
| 511 |  | -                ) | 
| 512 |  | - | 
| 513 |  | -            def wrapped(field, data): | 
| 514 |  | -                return op(self(data), other(data)) | 
| 515 |  | - | 
| 516 |  | -            other_name = other.name[1] | 
| 517 |  | -            other_units = self.ds.get_unit_from_registry(other.units) | 
| 518 |  | - | 
| 519 |  | -        else: | 
| 520 |  | -            # Special case when passing (value, "unit") tuple | 
| 521 |  | -            if isinstance(other, tuple) and len(other) == 2: | 
| 522 |  | -                other = self.ds.quan(*other) | 
| 523 |  | - | 
| 524 |  | -            def wrapped(field, data): | 
| 525 |  | -                return op(self(data), other) | 
| 526 |  | - | 
| 527 |  | -            other_name = str(other) | 
| 528 |  | -            other_units = getattr(other, "units", self.ds.get_unit_from_registry("1")) | 
| 529 |  | - | 
| 530 |  | -        if op in (operator.add, operator.sub, operator.eq): | 
| 531 |  | -            assert my_units.same_dimensions_as(other_units) | 
| 532 |  | -            new_units = my_units | 
| 533 |  | -        elif op in (operator.mul, operator.truediv): | 
| 534 |  | -            new_units = op(my_units, other_units) | 
| 535 |  | -        elif op in (operator.le, operator.lt, operator.ge, operator.gt, operator.ne): | 
| 536 |  | -            # Comparison yield unitless fields | 
| 537 |  | -            new_units = Unit("1") | 
| 538 |  | -        else: | 
| 539 |  | -            raise TypeError(f"Unsupported operator {op} for DerivedField") | 
| 540 |  | - | 
| 541 |  | -        return DerivedField( | 
| 542 |  | -            name=(self.name[0], f"{self.name[1]}_{op.__name__}_{other_name}"), | 
| 543 |  | -            sampling_type=self.sampling_type, | 
| 544 |  | -            function=wrapped, | 
| 545 |  | -            units=new_units, | 
| 546 |  | -            ds=self.ds, | 
| 547 |  | -        ) | 
| 548 |  | - | 
| 549 |  | -    # Multiplication (left and right side) | 
| 550 |  | -    def __mul__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 551 |  | -        return self._operator(other, op=operator.mul) | 
| 552 |  | - | 
| 553 |  | -    def __rmul__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 554 |  | -        return self._operator(other, op=operator.mul) | 
| 555 |  | - | 
| 556 |  | -    # Division (left side) | 
| 557 |  | -    def __truediv__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 558 |  | -        return self._operator(other, op=operator.truediv) | 
| 559 |  | - | 
| 560 |  | -    # Addition (left and right side) | 
| 561 |  | -    def __add__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 562 |  | -        return self._operator(other, op=operator.add) | 
| 563 |  | - | 
| 564 |  | -    def __radd__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 565 |  | -        return self._operator(other, op=operator.add) | 
| 566 |  | - | 
| 567 |  | -    # Subtraction (left and right side) | 
| 568 |  | -    def __sub__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 569 |  | -        return self._operator(other, op=operator.sub) | 
| 570 |  | - | 
| 571 |  | -    def __rsub__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 572 |  | -        return self._operator(-other, op=operator.add) | 
| 573 |  | - | 
| 574 |  | -    # Unary minus | 
| 575 |  | -    def __neg__(self) -> "DerivedField": | 
| 576 |  | -        def wrapped(field, data): | 
| 577 |  | -            return -self(data) | 
| 578 |  | - | 
| 579 |  | -        return DerivedField( | 
| 580 |  | -            name=(self.name[0], f"neg_{self.name[1]}"), | 
| 581 |  | -            sampling_type=self.sampling_type, | 
| 582 |  | -            function=wrapped, | 
| 583 |  | -            units=self.units, | 
| 584 |  | -            ds=self.ds, | 
| 585 |  | -        ) | 
| 586 |  | - | 
| 587 |  | -    # Division (right side, a bit more complex) | 
| 588 |  | -    def __rtruediv__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 589 |  | -        units = self.ds.get_unit_from_registry(self.units) | 
| 590 |  | - | 
| 591 |  | -        def wrapped(field, data): | 
| 592 |  | -            return 1 / self(data) | 
| 593 |  | - | 
| 594 |  | -        inverse_self = DerivedField( | 
| 595 |  | -            name=(self.name[0], f"inverse_{self.name[1]}"), | 
| 596 |  | -            sampling_type=self.sampling_type, | 
| 597 |  | -            function=wrapped, | 
| 598 |  | -            units=units**-1, | 
| 599 |  | -            ds=self.ds, | 
| 600 |  | -        ) | 
| 601 |  | - | 
| 602 |  | -        return inverse_self * other | 
| 603 |  | - | 
| 604 |  | -    # Comparison operators | 
| 605 |  | -    def __leq__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 606 |  | -        return self._operator(other, op=operator.le) | 
| 607 |  | - | 
| 608 |  | -    def __lt__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 609 |  | -        return self._operator(other, op=operator.lt) | 
| 610 |  | - | 
| 611 |  | -    def __geq__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 612 |  | -        return self._operator(other, op=operator.ge) | 
| 613 |  | - | 
| 614 |  | -    def __gt__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 615 |  | -        return self._operator(other, op=operator.gt) | 
| 616 |  | - | 
| 617 |  | -    # Somehow, makes yt not work? | 
| 618 |  | -    # def __eq__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 619 |  | -    #     return self._operator(other, op=operator.eq) | 
| 620 |  | - | 
| 621 |  | -    def __ne__(self, other: Union["DerivedField", float]) -> "DerivedField": | 
| 622 |  | -        return self._operator(other, op=operator.ne) | 
| 623 |  | - | 
| 624 | 622 | 
 | 
| 625 | 623 | class FieldValidator: | 
| 626 | 624 |     """ | 
|  | 
0 commit comments