Skip to content

Commit cb1d649

Browse files
committed
Add register_fields method to register or override on-disk field properties.
1 parent 760c7a7 commit cb1d649

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

yt/data_objects/static_output.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class Dataset(abc.ABC):
180180
_ionization_label_format = "roman_numeral"
181181
_determined_fields: dict[str, list[FieldKey]] | None = None
182182
fields_detected = False
183+
_registered_fields = None
183184

184185
# these are set in self._parse_parameter_file()
185186
domain_left_edge = MutableAttribute(True)
@@ -1716,6 +1717,44 @@ def quan(self):
17161717
self._quan = functools.partial(YTQuantity, registry=self.unit_registry)
17171718
return self._quan
17181719

1720+
def register_field(
1721+
self, name, units=None, aliases=None, display_name=None
1722+
):
1723+
"""
1724+
Register properties for an on-disk field.
1725+
1726+
Register or override units, aliases, or the display name for an
1727+
on-disk field.
1728+
1729+
Note, this must be called immediately after yt.load (i.e., before
1730+
the list of fields is generated).
1731+
1732+
Parameters
1733+
----------
1734+
1735+
name : str
1736+
name of the field
1737+
units : str
1738+
units of the field
1739+
aliases : list
1740+
a list of alias names
1741+
display_name: str
1742+
the name used in plots
1743+
1744+
"""
1745+
if self._registered_fields is None:
1746+
self._registered_fields = {}
1747+
1748+
entry = {}
1749+
if units is not None:
1750+
entry["units"] = units
1751+
if aliases is not None:
1752+
entry["alias"] = alias
1753+
if display_name is not None:
1754+
entry["display_name"] = display_name
1755+
1756+
self._registered_fields[name] = entry
1757+
17191758
def add_field(
17201759
self, name, function, sampling_type, *, force_override=False, **kwargs
17211760
):

yt/fields/field_info_container.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,12 @@ def setup_particle_fields(self, ptype, ftype="gas", num_neighbors=64):
150150
raise RuntimeError
151151
if field[0] not in self.ds.particle_types:
152152
continue
153+
153154
units = self.ds.field_units.get(field, None)
155+
rfields = getattr(self.ds, "_registered_fields", {})
156+
if entry := rfields.get(field):
157+
units = entry.get("units", units)
158+
154159
if units is None:
155160
try:
156161
units = ytcfg.get("fields", *field, "units")
@@ -256,6 +261,7 @@ def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
256261
raise RuntimeError
257262
if field[0] in self.ds.particle_types:
258263
continue
264+
259265
args = known_other_fields.get(field[1], None)
260266
if args is not None:
261267
units, aliases, display_name = args
@@ -273,6 +279,14 @@ def setup_fluid_aliases(self, ftype: FieldType = "gas") -> None:
273279
# field *name* is in there, then the field *tuple*.
274280
units = self.ds.field_units.get(field[1], units)
275281
units = self.ds.field_units.get(field, units)
282+
283+
# allow user to override with call to ds.register_fields
284+
rfields = getattr(self.ds, "_registered_fields", {})
285+
if entry := rfields.get(field):
286+
units = entry.get("units", units)
287+
aliases = entry.get("aliases", aliases)
288+
display_name = entry.get("display_name", display_name)
289+
276290
self.add_output_field(
277291
field, sampling_type="cell", units=units, display_name=display_name
278292
)

0 commit comments

Comments
 (0)