Skip to content

Commit 55db760

Browse files
committed
Implement all basic operations
1 parent ff406c4 commit 55db760

File tree

1 file changed

+49
-6
lines changed

1 file changed

+49
-6
lines changed

yt/fields/derived_field.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import contextlib
22
import inspect
3+
import operator
34
import re
45
from collections.abc import Callable, Iterable
56
from typing import Optional, Union
67

7-
import numpy as np
88
from more_itertools import always_iterable
99

1010
import yt.units.dimensions as ytdims
@@ -524,7 +524,7 @@ def wrapped(field, data):
524524
other_name = str(other)
525525
other_units = self.ds.get_unit_from_registry(getattr(other, "units", "1"))
526526

527-
if op in (np.add, np.subtract):
527+
if op in (operator.add, operator.sub):
528528
assert my_units.same_dimensions_as(other_units)
529529
new_units = my_units
530530
else:
@@ -538,17 +538,60 @@ def wrapped(field, data):
538538
ds=self.ds,
539539
)
540540

541+
# Multiplication (left and right side)
541542
def __mul__(self, other: Union["DerivedField", float]) -> "DerivedField":
542-
return self._operator(other, op=np.multiply)
543+
return self._operator(other, op=operator.mul)
543544

545+
def __rmul__(self, other: Union["DerivedField", float]) -> "DerivedField":
546+
return self._operator(other, op=operator.mul)
547+
548+
# Division (left side)
544549
def __truediv__(self, other: Union["DerivedField", float]) -> "DerivedField":
545-
return self._operator(other, op=np.divide)
550+
return self._operator(other, op=operator.truediv)
546551

552+
# Addition (left and right side)
547553
def __add__(self, other: Union["DerivedField", float]) -> "DerivedField":
548-
return self._operator(other, op=np.add)
554+
return self._operator(other, op=operator.add)
555+
556+
def __radd__(self, other: Union["DerivedField", float]) -> "DerivedField":
557+
return self._operator(other, op=operator.add)
549558

559+
# Subtraction (left and right side)
550560
def __sub__(self, other: Union["DerivedField", float]) -> "DerivedField":
551-
return self._operator(other, op=np.subtract)
561+
return self._operator(other, op=operator.sub)
562+
563+
def __rsub__(self, other: Union["DerivedField", float]) -> "DerivedField":
564+
return self._operator(-other, op=operator.add)
565+
566+
# Unary minus
567+
def __neg__(self) -> "DerivedField":
568+
def wrapped(field, data):
569+
return -self(data)
570+
571+
return DerivedField(
572+
name=(self.name[0], f"neg_{self.name[1]}"),
573+
sampling_type=self.sampling_type,
574+
function=wrapped,
575+
units=self.units,
576+
ds=self.ds,
577+
)
578+
579+
# Division (right side, a bit more complex)
580+
def __rtruediv__(self, other: Union["DerivedField", float]) -> "DerivedField":
581+
units = self.ds.get_unit_from_registry(self.units)
582+
583+
def wrapped(field, data):
584+
return 1 / self(data)
585+
586+
inverse_self = DerivedField(
587+
name=(self.name[0], f"inverse_{self.name[1]}"),
588+
sampling_type=self.sampling_type,
589+
function=wrapped,
590+
units=units**-1,
591+
ds=self.ds,
592+
)
593+
594+
return inverse_self * other
552595

553596

554597
class FieldValidator:

0 commit comments

Comments
 (0)