Skip to content

Commit 65bb4b4

Browse files
authored
[Fix] Fix wrong displacement shape mentioned by issue (#112)
* Fix wrong displacement shape mentioned by issue
1 parent 833afc1 commit 65bb4b4

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

tests/data_structure/body_model/test_smplxd_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_new():
3030
transl=np.zeros(shape=(2, 3)),
3131
betas=np.zeros(shape=(2, 10)),
3232
mask=np.ones(shape=(2, )),
33-
displacement=np.zeros(shape=(2, 10475)),
33+
displacement=np.zeros(shape=(2, 10475, 3)),
3434
logger='root')
3535
assert smplxd_data['betas'][0, 0] == 0
3636
assert smplxd_data.get_expression().shape == (2, 10)
@@ -100,7 +100,7 @@ def test_setitem():
100100
smplxd_data['transl'] = np.zeros(shape=[2, 3])
101101
smplxd_data['fullpose'] = np.zeros(shape=[2, 55, 3])
102102
smplxd_data['gender'] = 'neutral'
103-
smplxd_data['displacement'] = np.zeros(shape=(2, 10475))
103+
smplxd_data['displacement'] = np.zeros(shape=(2, 10475, 3))
104104
# set arbitrary key
105105
smplxd_data['frame_number'] = 1000
106106

tests/model/registrant/test_smplifyxd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_smplifyxd_keypoints3d():
9898
np_v = v.detach().cpu().numpy()
9999
assert not np.any(np.isnan(np_v)), f'{k} fails.'
100100
smplxd_data.from_param_dict(smplifyxd_output)
101-
assert len(smplxd_data.get_displacement().shape) == 2
101+
assert len(smplxd_data.get_displacement().shape) == 3
102102
result_path = os.path.join(output_dir, 'smplxd_result.npz')
103103
smplxd_data.dump(result_path)
104104
# test not use_one_betas_per_video and return values

xrmocap/data_structure/body_model/smplxd_data.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self,
5454
zero-tensor in shape [frame_num, 10] will be created.
5555
displacement (Union[np.ndarray, torch.Tensor, None], optional):
5656
A tensor or ndarray for displacement,
57-
in shape [frame_num, NUM_VERTS].
57+
in shape [frame_num, NUM_VERTS, 3].
5858
Defaults to None,
5959
zero-tensor in shape [frame_num, NUM_VERTS] will be created.
6060
mask (Union[np.ndarray, torch.Tensor, None], optional):
@@ -77,7 +77,7 @@ def __init__(self,
7777
logger=logger)
7878
if displacement is None and 'displacement' not in self:
7979
displacement = np.zeros(
80-
shape=(self.get_batch_size(), self.__class__.NUM_VERTS))
80+
shape=(self.get_batch_size(), self.__class__.NUM_VERTS, 3))
8181
if displacement is not None:
8282
self.set_displacement(displacement)
8383

@@ -117,7 +117,7 @@ def set_displacement(
117117
Args:
118118
displacement (Union[np.ndarray, torch.Tensor]):
119119
Displacement parameters in ndarray or tensor,
120-
in shape [batch_size, NUM_VERTS].
120+
in shape [batch_size, NUM_VERTS, 3].
121121
122122
Raises:
123123
TypeError: Type of displacement is not correct.
@@ -128,18 +128,18 @@ def set_displacement(
128128
self.logger.error('Type of displacement is not correct.\n' +
129129
f'Type: {type(displacement)}.')
130130
raise TypeError
131-
if len(displacement.shape) == 1:
132-
displacement = displacement[np.newaxis, ...]
133-
displacement_dim = displacement.shape[-1]
134-
displacement_np = displacement.reshape(-1, displacement_dim)
135-
dict.__setitem__(self, 'displacement', displacement_np)
131+
if len(displacement.shape) < 3:
132+
self.logger.error('Shape of displacement is not correct.\n' +
133+
f'Shape: {type(displacement.shape)}.')
134+
raise ValueError
135+
dict.__setitem__(self, 'displacement', displacement)
136136

137137
def get_displacement(self) -> np.ndarray:
138138
"""Get displacement.
139139
140140
Returns:
141141
ndarray:
142-
Displacement in shape [batch_size, NUM_VERTS].
142+
Displacement in shape [batch_size, NUM_VERTS, 3].
143143
"""
144144
displacement = self.__getitem__('displacement')
145145
return displacement

0 commit comments

Comments
 (0)