Skip to content

Commit b5fb883

Browse files
committed
Support saving DLC 3D poses
1 parent 5b4056a commit b5fb883

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

movement/io/save_poses.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@ def _ds_to_dlc_style_df(
3737
pandas.DataFrame
3838
3939
"""
40-
# Concatenate the pose tracks and confidence scores into one array
41-
tracks_with_scores = np.concatenate(
42-
(
43-
ds.position.data,
44-
ds.confidence.data[:, np.newaxis, ...],
45-
),
46-
axis=1,
47-
)
40+
is_3d = "z" in columns.get_level_values("coords")
41+
if is_3d:
42+
tracks_with_scores = ds.position.data
43+
else:
44+
# Concatenate the pose tracks and confidence scores into one array
45+
tracks_with_scores = np.concatenate(
46+
(
47+
ds.position.data,
48+
ds.confidence.data[:, np.newaxis, ...],
49+
),
50+
axis=1,
51+
)
4852
# Reverse the order of the dimensions except for the time dimension
4953
transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
5054
tracks_with_scores = tracks_with_scores.transpose(transpose_order)
@@ -121,7 +125,12 @@ def to_dlc_style_df(
121125
_validate_dataset(ds, ValidPosesDataset)
122126
scorer = ["movement"]
123127
bodyparts = ds.coords["keypoints"].data.tolist()
124-
coords = ds.coords["space"].data.tolist() + ["likelihood"]
128+
base_coords = ds.coords["space"].data.tolist()
129+
coords = (
130+
base_coords
131+
if "z" in ds.coords["space"]
132+
else base_coords + ["likelihood"]
133+
)
125134
individuals = ds.coords["individuals"].data.tolist()
126135

127136
if split_individuals:

tests/test_integration/test_io.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@ def dlc_output_file(request, tmp_path):
1313
return tmp_path / request.param
1414

1515

16-
def test_load_and_save_to_dlc_style_df(valid_dlc_poses_df):
16+
@pytest.mark.parametrize(
17+
"dlc_poses_df", ["valid_dlc_poses_df", "valid_dlc_3d_poses_df"]
18+
)
19+
def test_load_and_save_to_dlc_style_df(dlc_poses_df, request):
1720
"""Test that loading pose tracks from a DLC-style DataFrame and
1821
converting back to a DataFrame returns the same data values.
1922
"""
20-
ds = load_poses.from_dlc_style_df(valid_dlc_poses_df)
23+
dlc_poses_df = request.getfixturevalue(dlc_poses_df)
24+
ds = load_poses.from_dlc_style_df(dlc_poses_df)
2125
df = save_poses.to_dlc_style_df(ds, split_individuals=False)
22-
np.testing.assert_allclose(df.values, valid_dlc_poses_df.values)
26+
np.testing.assert_allclose(df.values, dlc_poses_df.values)
2327

2428

2529
def test_save_and_load_dlc_file(dlc_output_file, valid_poses_dataset):

0 commit comments

Comments
 (0)