|  | 
|  | 1 | +"""Computing spatial relationships between points, such as distances.""" | 
|  | 2 | + | 
|  | 3 | +import itertools | 
|  | 4 | +from typing import Literal | 
|  | 5 | + | 
|  | 6 | +import xarray as xr | 
|  | 7 | +from scipy.spatial.distance import cdist | 
|  | 8 | + | 
|  | 9 | +from movement.utils.logging import logger | 
|  | 10 | +from movement.validators.arrays import validate_dims_coords | 
|  | 11 | + | 
|  | 12 | + | 
def _cdist(
    a: xr.DataArray,
    b: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray:
    """Compute distances between two position arrays across a given dimension.

    This function is a wrapper around :func:`scipy.spatial.distance.cdist`
    and computes the pairwise distances between the two input position arrays
    across the dimension specified by ``dim``.
    The dimension can be either ``individuals`` or ``keypoints``.
    The distances are computed using the specified ``metric``.

    Parameters
    ----------
    a : xarray.DataArray
        The first input data containing position information of a
        single individual or keypoint, with ``time``, ``space``
        (in Cartesian coordinates), and ``individuals`` or ``keypoints``
        (as specified by ``dim``) as required dimensions.
    b : xarray.DataArray
        The second input data containing position information of a
        single individual or keypoint, with ``time``, ``space``
        (in Cartesian coordinates), and ``individuals`` or ``keypoints``
        (as specified by ``dim``) as required dimensions.
    dim : Literal["individuals", "keypoints"]
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
        ``'euclidean'``, etc.
        Defaults to ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray named ``distance`` containing the computed
        distances between each pair of inputs. The two new dimensions are
        named after the ``dim`` label of ``a`` and of ``b`` respectively,
        and each is indexed by the labels of the *other* dimension
        (``keypoints`` when ``dim='individuals'`` and vice versa) taken
        from the corresponding input. Size-one dimensions are squeezed
        out and their coordinates dropped.

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and
    ``ind2`` (i.e. interindividual distance for all keypoints)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(individuals="ind1")
    >>> pos2 = ds.position.sel(individuals="ind2")
    >>> ind_dists = _cdist(pos1, pos2, dim="individuals")

    Compute the Euclidean distance (default) between ``key1`` and
    ``key2`` (i.e. interkeypoint distance for all individuals)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(keypoints="key1")
    >>> pos2 = ds.position.sel(keypoints="key2")
    >>> key_dists = _cdist(pos1, pos2, dim="keypoints")

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.
    compute_pairwise_distances : Compute pairwise distances between
        ``individuals`` or ``keypoints``

    """
    # The dimension from which ``dim`` labels are obtained
    labels_dim = "individuals" if dim == "keypoints" else "keypoints"
    # Scalar ``dim`` label of each input; these become the names of the
    # two output dimensions.
    elem1 = getattr(a, dim).item()
    elem2 = getattr(b, dim).item()
    # Guarantee both inputs carry ``labels_dim`` as a real dimension so that
    # it can be used as a core dimension below.
    a = _validate_labels_dimension(a, labels_dim)
    b = _validate_labels_dimension(b, labels_dim)
    result = xr.apply_ufunc(
        cdist,
        a,
        b,
        kwargs={"metric": metric, **kwargs},
        input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]],
        output_core_dims=[[elem1, elem2]],
        vectorize=True,
    )
    # Label each output axis from the input it was computed from:
    # rows correspond to ``a``, columns to ``b``.
    result = result.assign_coords(
        {
            elem1: getattr(a, labels_dim).values,
            elem2: getattr(b, labels_dim).values,
        }
    )
    result.name = "distance"
    # Drop any squeezed coordinates
    return result.squeeze(drop=True)
|  | 108 | + | 
|  | 109 | + | 
def compute_pairwise_distances(
    data: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    pairs: dict[str, str | list[str]] | Literal["all"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray | dict[str, xr.DataArray]:
    """Compute pairwise distances between ``individuals`` or ``keypoints``.

    This function computes the distances between
    pairs of ``individuals`` (i.e. interindividual distances) or
    pairs of ``keypoints`` (i.e. interkeypoint distances),
    as determined by ``dim``.
    The distances are computed for the given ``pairs``
    using the specified ``metric``.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``,
        ``space`` (in Cartesian coordinates), and
        ``individuals`` or ``keypoints`` (as specified by ``dim``)
        as required dimensions.
    dim : Literal["individuals", "keypoints"]
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    pairs : dict[str, str | list[str]] or 'all'
        Specifies the pairs of elements (either individuals or keypoints)
        for which to compute distances, depending on the value of ``dim``.

        - If ``dim='individuals'``, ``pairs`` should be a dictionary where
          each key is an individual name, and each value is also an individual
          name or a list of such names to compute distances with.
        - If ``dim='keypoints'``, ``pairs`` should be a dictionary where each
          key is a keypoint name, and each value is also a keypoint name or a
          list of such names to compute distances with.
        - Alternatively, use the special keyword ``'all'`` to compute distances
          for all possible pairs of individuals or keypoints
          (depending on ``dim``).
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
        ``'euclidean'``, etc.
        Defaults to ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray or dict[str, xarray.DataArray]
        The computed pairwise distances. If a single pair is specified in
        ``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
        are specified, returns a dictionary where each key is a string
        representing the pair (e.g., ``'dist_ind1_ind2'`` or
        ``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
        containing the computed distances for that pair.

    Raises
    ------
    ValueError
        If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
        if ``pairs`` is not a dictionary or ``'all'``; or
        if there are no pairs in ``data`` to compute distances for.

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and ``ind2``
    (i.e. interindividual distance), for all possible pairs of keypoints.

    >>> position = xr.DataArray(
    ...     np.arange(36).reshape(2, 3, 3, 2),
    ...     coords={
    ...         "time": np.arange(2),
    ...         "individuals": ["ind1", "ind2", "ind3"],
    ...         "keypoints": ["key1", "key2", "key3"],
    ...         "space": ["x", "y"],
    ...     },
    ...     dims=["time", "individuals", "keypoints", "space"],
    ... )
    >>> dist_ind1_ind2 = compute_pairwise_distances(
    ...     position, "individuals", {"ind1": "ind2"}
    ... )
    >>> dist_ind1_ind2
    <xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
    8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
    Coordinates:
    * time     (time) int64 16B 0 1
    * ind1     (ind1) <U4 48B 'key1' 'key2' 'key3'
    * ind2     (ind2) <U4 48B 'key1' 'key2' 'key3'

    The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
    distances between ``ind1`` and ``ind2`` for all keypoints
    at each time point.

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")

    Compute the Euclidean distance (default) between ``key1`` and ``key2``
    (i.e. interkeypoint distance), for all possible pairs of individuals.

    >>> dist_key1_key2 = compute_pairwise_distances(
    ...     position, "keypoints", {"key1": "key2"}
    ... )
    >>> dist_key1_key2
    <xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
    2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
    Coordinates:
    * time     (time) int64 16B 0 1
    * key1     (key1) <U4 48B 'ind1' 'ind2' 'ind3'
    * key2     (key2) <U4 48B 'ind1' 'ind2' 'ind3'

    The resulting ``dist_key1_key2`` is a DataArray containing the computed
    distances between ``key1`` and ``key2`` for all individuals
    at each time point.

    To obtain the distances between ``key1`` and ``key2`` within ``ind1``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind1")

    To obtain the distances between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2``:

    >>> dist_key1_key2.sel(key1="ind1", key2="ind2")

    Compute the city block or Manhattan distance for multiple pairs of
    keypoints using ``position``:

    >>> key_dists = compute_pairwise_distances(
    ...     position,
    ...     "keypoints",
    ...     {"key1": "key2", "key3": ["key1", "key2"]},
    ...     metric="cityblock",
    ... )
    >>> key_dists.keys()
    dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])

    As multiple pairs of keypoints are specified,
    the resulting ``key_dists`` is a dictionary containing the DataArrays
    of computed distances for each pair of keypoints.

    Compute the city block or Manhattan distance for all possible pairs of
    individuals using ``position``:

    >>> ind_dists = compute_pairwise_distances(
    ...     position,
    ...     "individuals",
    ...     "all",
    ...     metric="cityblock",
    ... )
    >>> ind_dists.keys()
    dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.

    """
    if dim not in ["individuals", "keypoints"]:
        raise logger.error(
            ValueError(
                "'dim' must be either 'individuals' or 'keypoints', "
                f"but got {dim}."
            )
        )
    # Guard the comparison with ``isinstance`` so that array-like inputs
    # are never compared element-wise against the string.
    use_all_pairs = isinstance(pairs, str) and pairs == "all"
    # Reject everything that is neither 'all' nor a mapping, so that lists,
    # ``None`` etc. produce the documented ValueError rather than an
    # AttributeError from ``pairs.items()`` further down.
    if not use_all_pairs and (
        isinstance(pairs, str) or not hasattr(pairs, "items")
    ):
        raise logger.error(
            ValueError(
                f"'pairs' must be a dictionary or 'all', but got {pairs}."
            )
        )
    validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
    # Find all possible pair combinations if 'all' is specified
    if use_all_pairs:
        paired_elements = list(
            itertools.combinations(getattr(data, dim).values, 2)
        )
    else:
        paired_elements = [
            (elem1, elem2)
            for elem1, elem2_list in pairs.items()
            for elem2 in (
                # Ensure elem2_list is a list
                [elem2_list] if isinstance(elem2_list, str) else elem2_list
            )
        ]
    # Covers an empty ``pairs`` mapping and 'all' with fewer than two labels
    if not paired_elements:
        raise logger.error(
            ValueError("Could not find any pairs to compute distances for.")
        )
    pairwise_distances = {
        f"dist_{elem1}_{elem2}": _cdist(
            data.sel({dim: elem1}),
            data.sel({dim: elem2}),
            dim=dim,
            metric=metric,
            **kwargs,
        )
        for elem1, elem2 in paired_elements
    }
    # Return DataArray if result only has one key
    if len(pairwise_distances) == 1:
        return next(iter(pairwise_distances.values()))
    return pairwise_distances
|  | 316 | + | 
|  | 317 | + | 
def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
    """Ensure ``data`` carries ``dim`` as a dimension usable for labelling.

    The returned array is suitable for use with
    :func:`scipy.spatial.distance.cdist`, where ``dim`` supplies the
    labels (coordinates) of the computed distances. If ``dim`` is absent
    from the coordinates, a placeholder coordinate is attached; if ``dim``
    is only a scalar coordinate, it is promoted to a dimension.

    Parameters
    ----------
    data : xarray.DataArray
        The input data to validate.
    dim : str
        The name of the labels dimension that must be present.

    Returns
    -------
    xarray.DataArray
        The input data, with ``dim`` added and/or expanded where needed.

    """
    existing_coord = data.coords.get(dim)
    if existing_coord is None:
        # No such coordinate at all: attach a scalar placeholder label
        data = data.assign_coords({dim: "temp_dim"})
    is_scalar_coord = data.coords[dim].ndim == 0
    if is_scalar_coord:
        # Promote the scalar coordinate to a length-one dimension and
        # fix the dimension order.
        expanded = data.expand_dims(dim)
        data = expanded.transpose("time", "space", dim)
    return data
0 commit comments