Skip to content

Commit 5b4056a

Browse files
committed
Add test for loading DLC 3D pose data
1 parent ce58164 commit 5b4056a

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

tests/fixtures/datasets.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,24 @@ def valid_dlc_poses_df():
312312
return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5"))
313313

314314

315+
@pytest.fixture
316+
def valid_dlc_3d_poses_df(valid_dlc_poses_df):
317+
"""Mock and return a valid DLC-style 3D poses DataFrame.
318+
319+
The only difference between 2D and 3D DLC DataFrames is that
320+
the coordinate level in the columns MultiIndex includes 'z' instead of
321+
'likelihood'.
322+
"""
323+
cols = [
324+
(scorer, bodypart, "z" if coord == "likelihood" else coord)
325+
for scorer, bodypart, coord in valid_dlc_poses_df.columns.to_list()
326+
]
327+
valid_dlc_poses_df.columns = pd.MultiIndex.from_tuples(
328+
cols, names=valid_dlc_poses_df.columns.names
329+
)
330+
return valid_dlc_poses_df
331+
332+
315333
# -------------------- Invalid bboxes datasets --------------------
316334
@pytest.fixture
317335
def missing_var_bboxes_dataset(valid_bboxes_dataset):

tests/test_unit/test_io/test_load_poses.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,20 @@ def test_load_from_dlc_file(file_name, helpers):
9999

100100

101101
@pytest.mark.parametrize(
102-
"source_software", ["DeepLabCut", "LightningPose", None]
102+
"poses_df_fixture, source_software",
103+
[
104+
("valid_dlc_poses_df", "DeepLabCut"),
105+
("valid_dlc_poses_df", "LightningPose"),
106+
("valid_dlc_poses_df", None),
107+
("valid_dlc_3d_poses_df", "DeepLabCut"),
108+
],
103109
)
104-
def test_load_from_dlc_style_df(valid_dlc_poses_df, source_software, helpers):
105-
"""Test that loading pose tracks from a valid DLC-style DataFrame
106-
returns a proper Dataset.
107-
"""
108-
ds = load_poses.from_dlc_style_df(
109-
valid_dlc_poses_df, source_software=source_software
110-
)
110+
def test_load_from_dlc_style_df(
111+
poses_df_fixture, source_software, helpers, request
112+
):
113+
"""Test loading pose tracks from DLC-style DataFrames (2D and 3D)."""
114+
df = request.getfixturevalue(poses_df_fixture)
115+
ds = load_poses.from_dlc_style_df(df, source_software=source_software)
111116
expected_values = {
112117
**expected_values_poses,
113118
"source_software": source_software,

0 commit comments

Comments
 (0)