Skip to content

Commit fea41df

Browse files
niksirbiShigrafS
andauthored
Split kinematics.py and corresponding tests into multiple files (#583)
* split kinematics.py and corresponding tests into multiple files Co-authored-by: Shigraf Salik <salikshigraf@gmail.com> * more general docstring for kinematics/__init__.py --------- Co-authored-by: Shigraf Salik <salikshigraf@gmail.com>
1 parent ee381dd commit fea41df

File tree

9 files changed

+1995
-1928
lines changed

9 files changed

+1995
-1928
lines changed

movement/kinematics.py

Lines changed: 0 additions & 978 deletions
This file was deleted.

movement/kinematics/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Compute variables derived from ``position`` data."""
2+
3+
from movement.kinematics.distances import compute_pairwise_distances
4+
from movement.kinematics.kinematics import (
5+
compute_acceleration,
6+
compute_displacement,
7+
compute_path_length,
8+
compute_speed,
9+
compute_time_derivative,
10+
compute_velocity,
11+
)
12+
from movement.kinematics.orientation import (
13+
compute_forward_vector,
14+
compute_forward_vector_angle,
15+
compute_head_direction_vector,
16+
)
17+
18+
__all__ = [
19+
"compute_displacement",
20+
"compute_velocity",
21+
"compute_acceleration",
22+
"compute_speed",
23+
"compute_path_length",
24+
"compute_time_derivative",
25+
"compute_pairwise_distances",
26+
"compute_forward_vector",
27+
"compute_head_direction_vector",
28+
"compute_forward_vector_angle",
29+
]

movement/kinematics/distances.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
"""Computing spatial relationships between points, such as distances."""
2+
3+
import itertools
4+
from typing import Literal
5+
6+
import xarray as xr
7+
from scipy.spatial.distance import cdist
8+
9+
from movement.utils.logging import logger
10+
from movement.validators.arrays import validate_dims_coords
11+
12+
13+
def _cdist(
14+
a: xr.DataArray,
15+
b: xr.DataArray,
16+
dim: Literal["individuals", "keypoints"],
17+
metric: str | None = "euclidean",
18+
**kwargs,
19+
) -> xr.DataArray:
20+
"""Compute distances between two position arrays across a given dimension.
21+
22+
This function is a wrapper around :func:`scipy.spatial.distance.cdist`
23+
and computes the pairwise distances between the two input position arrays
24+
across the dimension specified by ``dim``.
25+
The dimension can be either ``individuals`` or ``keypoints``.
26+
The distances are computed using the specified ``metric``.
27+
28+
Parameters
29+
----------
30+
a : xarray.DataArray
31+
The first input data containing position information of a
32+
single individual or keypoint, with ``time``, ``space``
33+
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
34+
(as specified by ``dim``) as required dimensions.
35+
b : xarray.DataArray
36+
The second input data containing position information of a
37+
single individual or keypoint, with ``time``, ``space``
38+
(in Cartesian coordinates), and ``individuals`` or ``keypoints``
39+
(as specified by ``dim``) as required dimensions.
40+
dim : str
41+
The dimension to compute the distances for. Must be either
42+
``'individuals'`` or ``'keypoints'``.
43+
metric : str, optional
44+
The distance metric to use. Must be one of the options supported
45+
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
46+
``'euclidean'``, etc.
47+
Defaults to ``'euclidean'``.
48+
**kwargs : dict
49+
Additional keyword arguments to pass to
50+
:func:`scipy.spatial.distance.cdist`.
51+
52+
53+
Returns
54+
-------
55+
xarray.DataArray
56+
An xarray DataArray containing the computed distances between
57+
each pair of inputs.
58+
59+
Examples
60+
--------
61+
Compute the Euclidean distance (default) between ``ind1`` and
62+
``ind2`` (i.e. interindividual distance for all keypoints)
63+
using the ``position`` data variable in the Dataset ``ds``:
64+
65+
>>> pos1 = ds.position.sel(individuals="ind1")
66+
>>> pos2 = ds.position.sel(individuals="ind2")
67+
>>> ind_dists = _cdist(pos1, pos2, dim="individuals")
68+
69+
Compute the Euclidean distance (default) between ``key1`` and
70+
``key2`` (i.e. interkeypoint distance for all individuals)
71+
using the ``position`` data variable in the Dataset ``ds``:
72+
73+
>>> pos1 = ds.position.sel(keypoints="key1")
74+
>>> pos2 = ds.position.sel(keypoints="key2")
75+
>>> key_dists = _cdist(pos1, pos2, dim="keypoints")
76+
77+
See Also
78+
--------
79+
scipy.spatial.distance.cdist : The underlying function used.
80+
compute_pairwise_distances : Compute pairwise distances between
81+
``individuals`` or ``keypoints``
82+
83+
"""
84+
# The dimension from which ``dim`` labels are obtained
85+
labels_dim = "individuals" if dim == "keypoints" else "keypoints"
86+
elem1 = getattr(a, dim).item()
87+
elem2 = getattr(b, dim).item()
88+
a = _validate_labels_dimension(a, labels_dim)
89+
b = _validate_labels_dimension(b, labels_dim)
90+
result = xr.apply_ufunc(
91+
cdist,
92+
a,
93+
b,
94+
kwargs={"metric": metric, **kwargs},
95+
input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]],
96+
output_core_dims=[[elem1, elem2]],
97+
vectorize=True,
98+
)
99+
result = result.assign_coords(
100+
{
101+
elem1: getattr(a, labels_dim).values,
102+
elem2: getattr(a, labels_dim).values,
103+
}
104+
)
105+
result.name = "distance"
106+
# Drop any squeezed coordinates
107+
return result.squeeze(drop=True)
108+
109+
110+
def compute_pairwise_distances(
111+
data: xr.DataArray,
112+
dim: Literal["individuals", "keypoints"],
113+
pairs: dict[str, str | list[str]] | Literal["all"],
114+
metric: str | None = "euclidean",
115+
**kwargs,
116+
) -> xr.DataArray | dict[str, xr.DataArray]:
117+
"""Compute pairwise distances between ``individuals`` or ``keypoints``.
118+
119+
This function computes the distances between
120+
pairs of ``individuals`` (i.e. interindividual distances) or
121+
pairs of ``keypoints`` (i.e. interkeypoint distances),
122+
as determined by ``dim``.
123+
The distances are computed for the given ``pairs``
124+
using the specified ``metric``.
125+
126+
Parameters
127+
----------
128+
data : xarray.DataArray
129+
The input data containing position information, with ``time``,
130+
``space`` (in Cartesian coordinates), and
131+
``individuals`` or ``keypoints`` (as specified by ``dim``)
132+
as required dimensions.
133+
dim : Literal["individuals", "keypoints"]
134+
The dimension to compute the distances for. Must be either
135+
``'individuals'`` or ``'keypoints'``.
136+
pairs : dict[str, str | list[str]] or 'all'
137+
Specifies the pairs of elements (either individuals or keypoints)
138+
for which to compute distances, depending on the value of ``dim``.
139+
140+
- If ``dim='individuals'``, ``pairs`` should be a dictionary where
141+
each key is an individual name, and each value is also an individual
142+
name or a list of such names to compute distances with.
143+
- If ``dim='keypoints'``, ``pairs`` should be a dictionary where each
144+
key is a keypoint name, and each value is also keypoint name or a
145+
list of such names to compute distances with.
146+
- Alternatively, use the special keyword ``'all'`` to compute distances
147+
for all possible pairs of individuals or keypoints
148+
(depending on ``dim``).
149+
metric : str, optional
150+
The distance metric to use. Must be one of the options supported
151+
by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``,
152+
``'euclidean'``, etc.
153+
Defaults to ``'euclidean'``.
154+
**kwargs : dict
155+
Additional keyword arguments to pass to
156+
:func:`scipy.spatial.distance.cdist`.
157+
158+
Returns
159+
-------
160+
xarray.DataArray or dict[str, xarray.DataArray]
161+
The computed pairwise distances. If a single pair is specified in
162+
``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs
163+
are specified, returns a dictionary where each key is a string
164+
representing the pair (e.g., ``'dist_ind1_ind2'`` or
165+
``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray`
166+
containing the computed distances for that pair.
167+
168+
Raises
169+
------
170+
ValueError
171+
If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``;
172+
if ``pairs`` is not a dictionary or ``'all'``; or
173+
if there are no pairs in ``data`` to compute distances for.
174+
175+
Examples
176+
--------
177+
Compute the Euclidean distance (default) between ``ind1`` and ``ind2``
178+
(i.e. interindividual distance), for all possible pairs of keypoints.
179+
180+
>>> position = xr.DataArray(
181+
... np.arange(36).reshape(2, 3, 3, 2),
182+
... coords={
183+
... "time": np.arange(2),
184+
... "individuals": ["ind1", "ind2", "ind3"],
185+
... "keypoints": ["key1", "key2", "key3"],
186+
... "space": ["x", "y"],
187+
... },
188+
... dims=["time", "individuals", "keypoints", "space"],
189+
... )
190+
>>> dist_ind1_ind2 = compute_pairwise_distances(
191+
... position, "individuals", {"ind1": "ind2"}
192+
... )
193+
>>> dist_ind1_ind2
194+
<xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B
195+
8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485
196+
Coordinates:
197+
* time (time) int64 16B 0 1
198+
* ind1 (ind1) <U4 48B 'key1' 'key2' 'key3'
199+
* ind2 (ind2) <U4 48B 'key1' 'key2' 'key3'
200+
201+
The resulting ``dist_ind1_ind2`` is a DataArray containing the computed
202+
distances between ``ind1`` and ``ind2`` for all keypoints
203+
at each time point.
204+
205+
To obtain the distances between ``key1`` of ``ind1`` and
206+
``key2`` of ``ind2``:
207+
208+
>>> dist_ind1_ind2.sel(ind1="key1", ind2="key2")
209+
210+
Compute the Euclidean distance (default) between ``key1`` and ``key2``
211+
(i.e. interkeypoint distance), for all possible pairs of individuals.
212+
213+
>>> dist_key1_key2 = compute_pairwise_distances(
214+
... position, "keypoints", {"key1": "key2"}
215+
... )
216+
>>> dist_key1_key2
217+
<xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B
218+
2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828
219+
Coordinates:
220+
* time (time) int64 16B 0 1
221+
* key1 (key1) <U4 48B 'ind1' 'ind2' 'ind3'
222+
* key2 (key2) <U4 48B 'ind1' 'ind2' 'ind3'
223+
224+
The resulting ``dist_key1_key2`` is a DataArray containing the computed
225+
distances between ``key1`` and ``key2`` for all individuals
226+
at each time point.
227+
228+
To obtain the distances between ``key1`` and ``key2`` within ``ind1``:
229+
230+
>>> dist_key1_key2.sel(key1="ind1", key2="ind1")
231+
232+
To obtain the distances between ``key1`` of ``ind1`` and
233+
``key2`` of ``ind2``:
234+
235+
>>> dist_key1_key2.sel(key1="ind1", key2="ind2")
236+
237+
Compute the city block or Manhattan distance for multiple pairs of
238+
keypoints using ``position``:
239+
240+
>>> key_dists = compute_pairwise_distances(
241+
... position,
242+
... "keypoints",
243+
... {"key1": "key2", "key3": ["key1", "key2"]},
244+
... metric="cityblock",
245+
... )
246+
>>> key_dists.keys()
247+
dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2'])
248+
249+
As multiple pairs of keypoints are specified,
250+
the resulting ``key_dists`` is a dictionary containing the DataArrays
251+
of computed distances for each pair of keypoints.
252+
253+
Compute the city block or Manhattan distance for all possible pairs of
254+
individuals using ``position``:
255+
256+
>>> ind_dists = compute_pairwise_distances(
257+
... position,
258+
... "individuals",
259+
... "all",
260+
... metric="cityblock",
261+
... )
262+
>>> ind_dists.keys()
263+
dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3'])
264+
265+
See Also
266+
--------
267+
scipy.spatial.distance.cdist : The underlying function used.
268+
269+
"""
270+
if dim not in ["individuals", "keypoints"]:
271+
raise logger.error(
272+
ValueError(
273+
"'dim' must be either 'individuals' or 'keypoints', "
274+
f"but got {dim}."
275+
)
276+
)
277+
if isinstance(pairs, str) and pairs != "all":
278+
raise logger.error(
279+
ValueError(
280+
f"'pairs' must be a dictionary or 'all', but got {pairs}."
281+
)
282+
)
283+
validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []})
284+
# Find all possible pair combinations if 'all' is specified
285+
if pairs == "all":
286+
paired_elements = list(
287+
itertools.combinations(getattr(data, dim).values, 2)
288+
)
289+
else:
290+
paired_elements = [
291+
(elem1, elem2)
292+
for elem1, elem2_list in pairs.items()
293+
for elem2 in (
294+
# Ensure elem2_list is a list
295+
[elem2_list] if isinstance(elem2_list, str) else elem2_list
296+
)
297+
]
298+
if not paired_elements:
299+
raise logger.error(
300+
ValueError("Could not find any pairs to compute distances for.")
301+
)
302+
pairwise_distances = {
303+
f"dist_{elem1}_{elem2}": _cdist(
304+
data.sel({dim: elem1}),
305+
data.sel({dim: elem2}),
306+
dim=dim,
307+
metric=metric,
308+
**kwargs,
309+
)
310+
for elem1, elem2 in paired_elements
311+
}
312+
# Return DataArray if result only has one key
313+
if len(pairwise_distances) == 1:
314+
return next(iter(pairwise_distances.values()))
315+
return pairwise_distances
316+
317+
318+
def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray:
319+
"""Validate the input data contains the ``dim`` for labelling dimensions.
320+
321+
This function ensures the input data contains the ``dim``
322+
used as labels (coordinates) when applying
323+
:func:`scipy.spatial.distance.cdist` to
324+
the input data, by adding a temporary dimension if necessary.
325+
326+
Parameters
327+
----------
328+
data : xarray.DataArray
329+
The input data to validate.
330+
dim : str
331+
The dimension to validate.
332+
333+
Returns
334+
-------
335+
xarray.DataArray
336+
The input data with the labels dimension validated.
337+
338+
"""
339+
if data.coords.get(dim) is None:
340+
data = data.assign_coords({dim: "temp_dim"})
341+
if data.coords[dim].ndim == 0:
342+
data = data.expand_dims(dim).transpose("time", "space", dim)
343+
return data

0 commit comments

Comments
 (0)