@@ -54,7 +54,7 @@ def __init__(self,
54
54
zero-tensor in shape [frame_num, 10] will be created.
55
55
displacement (Union[np.ndarray, torch.Tensor, None], optional):
56
56
A tensor or ndarray for displacement,
57
- in shape [frame_num, NUM_VERTS].
57
+ in shape [frame_num, NUM_VERTS, 3 ].
58
58
Defaults to None,
59
59
zero-tensor in shape [frame_num, NUM_VERTS] will be created.
60
60
mask (Union[np.ndarray, torch.Tensor, None], optional):
@@ -77,7 +77,7 @@ def __init__(self,
77
77
logger = logger )
78
78
if displacement is None and 'displacement' not in self :
79
79
displacement = np .zeros (
80
- shape = (self .get_batch_size (), self .__class__ .NUM_VERTS ))
80
+ shape = (self .get_batch_size (), self .__class__ .NUM_VERTS , 3 ))
81
81
if displacement is not None :
82
82
self .set_displacement (displacement )
83
83
@@ -117,7 +117,7 @@ def set_displacement(
117
117
Args:
118
118
displacement (Union[np.ndarray, torch.Tensor]):
119
119
Displacement parameters in ndarray or tensor,
120
- in shape [batch_size, NUM_VERTS].
120
+ in shape [batch_size, NUM_VERTS, 3 ].
121
121
122
122
Raises:
123
123
TypeError: Type of displacement is not correct.
@@ -128,18 +128,18 @@ def set_displacement(
128
128
self .logger .error ('Type of displacement is not correct.\n ' +
129
129
f'Type: { type (displacement )} .' )
130
130
raise TypeError
131
- if len (displacement .shape ) == 1 :
132
- displacement = displacement [ np . newaxis , ...]
133
- displacement_dim = displacement .shape [ - 1 ]
134
- displacement_np = displacement . reshape ( - 1 , displacement_dim )
135
- dict .__setitem__ (self , 'displacement' , displacement_np )
131
+ if len (displacement .shape ) < 3 :
132
+ self . logger . error ( 'Shape of displacement is not correct. \n ' +
133
+ f'Shape: { type ( displacement .shape ) } .' )
134
+ raise ValueError
135
+ dict .__setitem__ (self , 'displacement' , displacement )
136
136
137
137
def get_displacement (self ) -> np .ndarray :
138
138
"""Get displacement.
139
139
140
140
Returns:
141
141
ndarray:
142
- Displacement in shape [batch_size, NUM_VERTS].
142
+ Displacement in shape [batch_size, NUM_VERTS, 3 ].
143
143
"""
144
144
displacement = self .__getitem__ ('displacement' )
145
145
return displacement
0 commit comments