-
I am writing code that maps monthly climatologies from pressure to altitude. This needs to be done for every grid box individually, so I am using xarray.groupby. That works, but it is very slow. I also tried to speed this function up using flox:

```python
from flox.xarray import xarray_reduce
import xarray as xr


def swap_dims(ds: xr.Dataset) -> xr.Dataset:
    ds = ds.squeeze()
    ds = ds.swap_dims({"pressure": "altitude"})
    ds = ds.reset_coords("pressure")
    ds = ds.sortby("altitude")
    ds = interpolate_to_alt(ds)
    return ds


def convert_to_alt(ds: xr.Dataset) -> xr.Dataset:
    ds = ds.stack({"stacked_dim": ["latitude_bins", "longitude_bins", "time"]})
    ds = ds.chunk(chunks={"pressure": 1})
    ds = ds.groupby("stacked_dim").map(swap_dims)
    # ds = xarray_reduce(ds, by="stacked_dim", func=swap_dims)  # not working flox code
    ds = ds.unstack("stacked_dim")
    return ds
```

The stacked, chunked ds before the groupby or xarray_reduce looks like this:

```
<xarray.Dataset>
Dimensions: (pressure: 37, stacked_dim: 12458880)
Coordinates:
* pressure (pressure) int32 1 2 3 5 7 10 ... 900 925 950 975 1000
* stacked_dim (stacked_dim) object MultiIndex
* latitude_bins (stacked_dim) float32 90.0 90.0 90.0 ... -90.0 -90.0
* longitude_bins (stacked_dim) float32 0.0 0.0 0.0 ... 359.8 359.8 359.8
* time (stacked_dim) datetime64[ns] 2010-01-01 ... 2010-12-01
Data variables:
specific_humidity (pressure, stacked_dim) float32 dask.array<chunksize=(1, 12458880), meta=np.ndarray>
temperature (pressure, stacked_dim) float32 dask.array<chunksize=(1, 12458880), meta=np.ndarray>
    altitude (pressure, stacked_dim) float32 dask.array<chunksize=(1, 12458880), meta=np.ndarray>
```

When I use it like this, I get an error. Should the original xarray.groupby approach work if I just adjust something? Is this possible with flox directly, or do I have to look for a totally different multiprocessing option?

I also wrote a function using multiprocessing. It looks like this:

```python
import warnings
from multiprocessing import Pool, cpu_count
from time import perf_counter
from typing import Callable

import numpy as np
import xarray as xr
def interpolate_to_alt(
    ds: xr.Dataset, pres_var_name: str = "pressure"
) -> xr.Dataset:
    """
    Interpolates a given xarray Dataset to altitude.

    Parameters
    ----------
    ds : xr.Dataset
        The xarray Dataset to be interpolated.
    pres_var_name : str, optional
        The name of the pressure variable. Defaults to 'pressure'.

    Returns
    -------
    xr.Dataset
        The interpolated xarray Dataset.
    """
    log_vars = [
        pres_var_name,
    ]
    log_vars_ds = list(set(log_vars).intersection(list(ds.variables.keys())))
    attrs = {}
    for var_ in log_vars_ds:
        # Remember the attributes, which get lost during the log round trip
        attrs[var_] = ds[var_].attrs
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            ds[var_] = np.log(ds[var_])
    # Interpolate
    ds = ds.interp(altitude=np.arange(0, 80000, 100))
    # Antilog
    for var_ in log_vars_ds:
        ds[var_] = np.exp(ds[var_])
        ds[var_].attrs = attrs[var_]
    # Get rid of artifacts due to log/antilog
    ds[pres_var_name] = ds[pres_var_name].round(1)
    ds[pres_var_name].attrs = attrs[pres_var_name]
    return ds
def convert_to_alt(ds: xr.Dataset) -> xr.Dataset:
    """
    Converts a given xarray Dataset from pressure to altitude.

    Parameters
    ----------
    ds : xr.Dataset
        The xarray Dataset to be converted.

    Returns
    -------
    xr.Dataset
        The converted xarray Dataset.
    """
    # TODO(max): Add a .squeeze() here for better performance?
    ds = ds.swap_dims({"pressure": "altitude"}).dropna(
        dim="altitude",
        subset=["pressure"],
    )
    ds = ds.reset_coords("pressure")
    ds = ds.sortby("altitude")
    ds = interpolate_to_alt(ds)
    ds = ds.expand_dims(["latitude_bins", "longitude_bins", "time"])
    ds = ds.drop_vars("stacked_dim")
    return ds
def parallel_stack_func(
    ds: xr.Dataset,
    stacked_dims: tuple[str, str, str] = (
        "latitude_bins",
        "longitude_bins",
        "time",
    ),
    func: Callable = convert_to_alt,
    func_cpu_count: int = -1,
) -> xr.Dataset:
    """
    Executes a function in parallel on a given xr.Dataset object.

    Parameters
    ----------
    ds : xr.Dataset
        The input dataset to perform the function in parallel on.
    stacked_dims : tuple[str, str, str], optional
        The dimensions to stack the dataset along. Defaults to
        ("latitude_bins", "longitude_bins", "time").
    func : Callable, optional
        The function to be executed in parallel. Defaults to `convert_to_alt`.
    func_cpu_count : int, optional
        The number of CPU processes to use for parallel execution.
        Defaults to -1, which uses all available CPUs.

    Returns
    -------
    xr.Dataset
        The output dataset after performing the function in parallel.
    """
    if func_cpu_count == -1:
        func_cpu_count = cpu_count()
    ds = ds.stack({"stacked_dim": stacked_dims})
    ds = ds.load()
    ds_list = [ds.isel(stacked_dim=i).squeeze() for i in range(ds.stacked_dim.size)]
    with Pool(processes=func_cpu_count) as pool:
        res = pool.map(func, ds_list)
    out = xr.combine_by_coords(res)
    return out
    return out
```

It works quite fast (except for the merging), but I am not sure whether this is an elegant approach.
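One idea for the merge step might be something like this (an untested sketch: it assumes the per-column results keep their `stacked_dim` entry, i.e. that the `expand_dims`/`drop_vars` calls in `convert_to_alt` are skipped, so everything can be concatenated along one dimension and unstacked once instead of merged with `combine_by_coords`):

```python
# untested: concatenate along the stacked dimension, rebuild the
# MultiIndex from the level coordinates, then unstack once
out = xr.concat(res, dim="stacked_dim")
out = out.set_index(stacked_dim=["latitude_bins", "longitude_bins", "time"])
out = out.unstack("stacked_dim")
```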
-
Hi @mgorfer, do you see any difference when using groupby with a single index instead of a multi-index? Which version of Xarray are you using? Perhaps you are hitting the same issue as #7376, which has been fixed since v2023.06.0?
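For example, something like this could show whether the MultiIndex itself is the bottleneck (an untested sketch that swaps the MultiIndex for a flat integer index before grouping):

```python
import numpy as np

# drop the MultiIndex levels and group over a plain integer index instead
flat = ds.reset_index("stacked_dim")
flat = flat.assign_coords(stacked_dim=np.arange(flat.sizes["stacked_dim"]))
result = flat.groupby("stacked_dim").map(swap_dims)
```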
-
As mentioned in xarray-contrib/flox#260, this may be solvable using xr.apply_ufunc:

```python
import numpy as np
import xarray as xr

# load example dataset and subset (so it's faster)
air = xr.tutorial.open_dataset("air_temperature")
air = air.isel(time=slice(2))
air = air.air.astype(float)

# define the target longitude (use 330, so the result can be double-checked)
target = np.array([201, 206, 330])

# example to interpolate one lat/time combination
np.interp(target, air.lon, air.isel(time=0, lat=0).values)


# define a function to feed to xr.apply_ufunc
# (could directly use `np.interp`, but this accommodates
# the other manipulations, e.g. log)
def interp(values, coords, target):
    out = np.interp(target, coords, values)
    return out


# option 1: pass target as a numpy array (no "longitude" coordinates)
xr.apply_ufunc(
    interp,
    air,
    air.lon,
    kwargs={"target": target},
    vectorize=True,
    input_core_dims=[["lon"], ["lon"]],
    output_core_dims=[["longitude"]],
)

# option 2: pass target as a DataArray (has "longitude" coordinates)
target = xr.DataArray(target, dims="longitude", coords={"longitude": target})
xr.apply_ufunc(
    interp,
    air,
    air.lon,
    target,
    vectorize=True,
    input_core_dims=[["lon"], ["lon"], ["longitude"]],
    output_core_dims=[["longitude"]],
)
```
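Transferred to the dataset from the question, the same pattern could look roughly like this (a hedged sketch: the variable names follow the repr above, the log/antilog handling is omitted, and the data is assumed to be in memory — for dask inputs you would additionally pass `dask="parallelized"` and the output sizes):

```python
target_alt = np.arange(0, 80000, 100)


def interp_to_alt(values, alt, target):
    # np.interp expects ascending coordinates, so sort by altitude first
    order = np.argsort(alt)
    return np.interp(target, alt[order], values[order])


temperature_on_alt = xr.apply_ufunc(
    interp_to_alt,
    ds.temperature,
    ds.altitude,
    kwargs={"target": target_alt},
    vectorize=True,
    input_core_dims=[["pressure"], ["pressure"]],
    output_core_dims=[["alt"]],
)
```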
-
Use xgcm here: https://xgcm.readthedocs.io/en/latest/transform.html. It does the apply_ufunc thing along with numba kernels for fast interpolation.
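For reference, a rough sketch of what that could look like for the dataset in the question, following the linked tutorial (the axis setup and variable names are assumptions; check the call signature against the xgcm docs):

```python
import numpy as np
from xgcm import Grid

# treat the pressure levels as the vertical ("Z") axis
grid = Grid(ds, coords={"Z": {"center": "pressure"}}, periodic=False)

# interpolate temperature from pressure levels onto a regular altitude
# grid, using the altitude variable as the target data
target_alt = np.arange(0, 80000, 100)
temperature_on_alt = grid.transform(
    ds.temperature,
    "Z",
    target_alt,
    target_data=ds.altitude,
    method="linear",
)
```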