Skip to content

Commit f2b539d

Browse files
authored
Ensure derived data variables have correct name (#523)
* Set name attribute for derived variables * Add names for forward_vector and signed angles * Add tests to check if name is correctly set * Add remaining tests * Rename head_direction to head_direction_vector
1 parent 529403f commit f2b539d

File tree

5 files changed

+55
-9
lines changed

5 files changed

+55
-9
lines changed

movement/kinematics.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
6060
validate_dims_coords(data, {"time": [], "space": []})
6161
result = data.diff(dim="time")
6262
result = result.reindex(data.coords, fill_value=0)
63+
result.name = "displacement"
6364
return result
6465

6566

@@ -99,7 +100,9 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
99100
# validate only presence of Cartesian space dimension
100101
# (presence of time dimension will be checked in compute_time_derivative)
101102
validate_dims_coords(data, {"space": []})
102-
return compute_time_derivative(data, order=1)
103+
result = compute_time_derivative(data, order=1)
104+
result.name = "velocity"
105+
return result
103106

104107

105108
def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
@@ -139,7 +142,9 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
139142
# validate only presence of Cartesian space dimension
140143
# (presence of time dimension will be checked in compute_time_derivative)
141144
validate_dims_coords(data, {"space": []})
142-
return compute_time_derivative(data, order=2)
145+
result = compute_time_derivative(data, order=2)
146+
result.name = "acceleration"
147+
return result
143148

144149

145150
def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
@@ -202,7 +207,9 @@ def compute_speed(data: xr.DataArray) -> xr.DataArray:
202207
except ``space`` is removed.
203208
204209
"""
205-
return compute_norm(compute_velocity(data))
210+
result = compute_norm(compute_velocity(data))
211+
result.name = "speed"
212+
return result
206213

207214

208215
def compute_forward_vector(
@@ -312,7 +319,9 @@ def compute_forward_vector(
312319
space="z"
313320
) # keep only the first 2 spatal dimensions of the result
314321
# Return unit vector
315-
return convert_to_unit(forward_vector)
322+
result = convert_to_unit(forward_vector)
323+
result.name = "forward_vector"
324+
return result
316325

317326

318327
def compute_head_direction_vector(
@@ -353,9 +362,11 @@ def compute_head_direction_vector(
353362
``keypoints`` dimension.
354363
355364
"""
356-
return compute_forward_vector(
365+
result = compute_forward_vector(
357366
data, left_keypoint, right_keypoint, camera_view=camera_view
358367
)
368+
result.name = "head_direction_vector"
369+
return result
359370

360371

361372
def compute_forward_vector_angle(
@@ -439,6 +450,7 @@ def compute_forward_vector_angle(
439450
if in_degrees:
440451
heading_array = np.rad2deg(heading_array)
441452

453+
heading_array.name = "forward_vector_angle"
442454
return heading_array
443455

444456

@@ -534,6 +546,7 @@ def _cdist(
534546
elem2: getattr(a, labels_dim).values,
535547
}
536548
)
549+
result.name = "distance"
537550
# Drop any squeezed coordinates
538551
return result.squeeze(drop=True)
539552

@@ -873,13 +886,13 @@ def compute_path_length(
873886
_warn_about_nan_proportion(data, nan_warn_threshold)
874887

875888
if nan_policy == "ffill":
876-
return compute_norm(
889+
result = compute_norm(
877890
compute_displacement(data.ffill(dim="time")).isel(
878891
time=slice(1, None)
879892
) # skip first displacement (always 0)
880893
).sum(dim="time", min_count=1) # return NaN if no valid segment
881894
elif nan_policy == "scale":
882-
return _compute_scaled_path_length(data)
895+
result = _compute_scaled_path_length(data)
883896
else:
884897
raise logger.error(
885898
ValueError(
@@ -888,6 +901,9 @@ def compute_path_length(
888901
)
889902
)
890903

904+
result.name = "path_length"
905+
return result
906+
891907

892908
def _warn_about_nan_proportion(
893909
data: xr.DataArray, nan_warn_threshold: float

movement/utils/vector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def compute_signed_angle_2d(
274274
# arctan2 returns values in [-pi, pi].
275275
# We need to map -pi angles to pi, to stay in the (-pi, pi] range
276276
angles.values[angles <= -np.pi] = np.pi
277+
angles.name = "signed_angle"
277278
return angles
278279

279280

tests/test_integration/test_kinematics_vector_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_cart2pol_transform_on_kinematics(
6161
kinematic_array_cart = getattr(kin, f"compute_{kinematic_variable}")(
6262
ds.position
6363
)
64+
assert kinematic_array_cart.name == kinematic_variable
6465
kinematic_array_pol = vector.cart2pol(kinematic_array_cart)
6566

6667
# Build expected data array

tests/test_unit/test_kinematics.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_kinematics(self, valid_dataset, kinematic_variable, request):
4444
kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
4545
position
4646
)
47+
assert kinematic_array.name == kinematic_variable
4748
# Figure out which dimensions to expect in kinematic_array
4849
# and in the final xarray.DataArray
4950
expected_dims = ["time", "individuals"]
@@ -112,6 +113,7 @@ def test_kinematics_with_dataset_with_nans(
112113
kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
113114
position
114115
)
116+
assert kinematic_array.name == kinematic_variable
115117
# compute n nans in kinematic array per individual
116118
n_nans_kinematics_per_indiv = [
117119
helpers.count_nans(kinematic_array.isel(individuals=i))
@@ -203,6 +205,7 @@ def test_path_length_across_time_ranges(
203205
path_length = kinematics.compute_path_length(
204206
position, start=start, stop=stop
205207
)
208+
assert path_length.name == "path_length"
206209

207210
# Expected number of segments (displacements) in selected time range
208211
num_segments = 9 # full time range: 10 frames - 1
@@ -266,6 +269,7 @@ def test_path_length_with_nan(
266269
path_length = kinematics.compute_path_length(
267270
position, nan_policy=nan_policy
268271
)
272+
assert path_length.name == "path_length"
269273
# Get path_length for individual "id_0" as a numpy array
270274
path_length_id_0 = path_length.sel(individuals="id_0").values
271275
# Check them against the expected values
@@ -312,9 +316,10 @@ def test_path_length_nan_warn_threshold(
312316
"""
313317
position = valid_poses_dataset_with_nan.position
314318
with expected_exception:
315-
kinematics.compute_path_length(
319+
result = kinematics.compute_path_length(
316320
position, nan_warn_threshold=nan_warn_threshold
317321
)
322+
assert result.name == "path_length"
318323

319324

320325
@pytest.fixture
@@ -408,6 +413,10 @@ def test_compute_forward_vector(valid_data_array_for_forward_vector):
408413
"right_ear",
409414
camera_view="bottom_up",
410415
)
416+
assert forward_vector.name == "forward_vector"
417+
assert forward_vector_flipped.name == "forward_vector"
418+
assert head_vector.name == "head_direction_vector"
419+
411420
known_vectors = np.array([[[0, -1]], [[1, 0]], [[0, 1]], [[-1, 0]]])
412421

413422
for output_array in [forward_vector, forward_vector_flipped, head_vector]:
@@ -480,6 +489,7 @@ def test_nan_behavior_forward_vector(
480489
forward_vector = kinematics.compute_forward_vector(
481490
valid_data_array_for_forward_vector_with_nan, "left_ear", "right_ear"
482491
)
492+
assert forward_vector.name == "forward_vector"
483493
# Check coord preservation
484494
for preserved_coord in ["time", "space", "individuals"]:
485495
assert np.all(
@@ -552,6 +562,7 @@ def test_cdist_with_known_values(dim, expected_data, valid_poses_dataset):
552562
a = input_dataarray.sel({dim: pairs[0]})
553563
b = input_dataarray.sel({dim: pairs[1]})
554564
result = kinematics._cdist(a, b, dim)
565+
assert result.name == "distance"
555566
xr.testing.assert_equal(
556567
result,
557568
expected,
@@ -624,7 +635,9 @@ def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request):
624635
valid_dataset = request.getfixturevalue(valid_dataset)
625636
position = valid_dataset.position
626637
a, b = selection_fn(position)
627-
assert isinstance(kinematics._cdist(a, b, "individuals"), xr.DataArray)
638+
result = kinematics._cdist(a, b, "individuals")
639+
assert result.name == "distance"
640+
assert isinstance(result, xr.DataArray)
628641

629642

630643
@pytest.mark.parametrize(
@@ -662,12 +675,16 @@ def test_compute_pairwise_distances_with_valid_pairs(
662675
valid_poses_dataset.position, dim, pairs
663676
)
664677
if isinstance(result, dict):
678+
for _, value in result.items():
679+
assert isinstance(value, xr.DataArray)
680+
assert value.name == "distance"
665681
expected_data_vars = [
666682
f"dist_{pair[0]}_{pair[1]}" for pair in expected_data_vars
667683
]
668684
assert set(result.keys()) == set(expected_data_vars)
669685
else: # expect single DataArray
670686
assert isinstance(result, xr.DataArray)
687+
assert result.name == "distance"
671688

672689

673690
@pytest.mark.parametrize(
@@ -803,6 +820,8 @@ def test_antisymmetry_properties(
803820
right_keypoint=right_keypoint,
804821
reference_vector=reference_vector,
805822
)
823+
assert without_orientations_swapped.name == "forward_vector_angle"
824+
assert with_orientations_swapped.name == "forward_vector_angle"
806825

807826
expected_orientations = without_orientations_swapped.copy(deep=True)
808827
if swap_left_right:
@@ -841,6 +860,8 @@ def test_in_degrees_toggle(
841860
reference_vector=reference_vector,
842861
in_degrees=True,
843862
)
863+
assert in_radians.name == "forward_vector_angle"
864+
assert in_degrees.name == "forward_vector_angle"
844865

845866
xr.testing.assert_allclose(in_degrees, np.rad2deg(in_radians))
846867

@@ -901,6 +922,9 @@ def test_transformation_invariance(
901922
reference_vector=reference_vector,
902923
)
903924

925+
assert untranslated_output.name == "forward_vector_angle"
926+
assert translated_output.name == "forward_vector_angle"
927+
904928
xr.testing.assert_allclose(untranslated_output, translated_output)
905929

906930
def test_casts_from_tuple(

tests/test_unit/test_vector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ def test_compute_signed_angle_2d(
322322
left_vector, right_vector, v_as_left_operand=True
323323
)
324324

325+
assert computed_angles.name == "signed_angle"
326+
assert computed_angles_reversed.name == "signed_angle"
327+
325328
xr.testing.assert_allclose(computed_angles, expected_angles)
326329
xr.testing.assert_allclose(
327330
computed_angles_reversed, expected_angles_reversed
@@ -390,4 +393,5 @@ def test_multidimensional_input(
390393
)
391394

392395
computed_angles = vector.compute_signed_angle_2d(v_left, v_right)
396+
assert computed_angles.name == "signed_angle"
393397
xr.testing.assert_allclose(computed_angles, expected_angles)

0 commit comments

Comments
 (0)