
Commit 1530d7c

[Add] Allow multiple optimizers and add extra joint priors (#56)
1. Allow multiple optimizers for different parameters.
2. Map losses only to the parameters they should update.
3. Add gradient clipping to avoid `nan` loss in optimization.
4. Add joint angle priors for A-pose.
5. Add joint angle priors for the foot, for cases with insufficient foot keypoints.
6. Fix joint angle limits for the shoulder.
1 parent 34b8c50 commit 1530d7c

File tree

4 files changed, +232 -77 lines changed


xrmocap/model/loss/mapping.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -0,0 +1,9 @@
+LOSS_MAPPING = {
+    'keypoints3d_limb_len': ['betas'],
+    'keypoints3d_mse': ['body_pose'],
+    'keypoints2d_mse': ['body_pose'],
+    'shape_prior': ['betas'],
+    'joint_prior': ['body_pose'],
+    'smooth_joint': ['body_pose'],
+    'pose_reg': ['body_pose'],
+}
```
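For context, this new mapping is what the individual-optimizer path in smplify.py (below) uses to regroup per-handler losses by the parameter each loss should update. A minimal sketch of that regrouping, similar in spirit to `_post_process_loss` further down; the loss values here are made up:

```python
from xrmocap.model.loss.mapping import LOSS_MAPPING

# Hypothetical per-handler loss values from one evaluation step.
losses = {'keypoints3d_mse': 0.75, 'joint_prior': 0.25, 'shape_prior': 0.5}

# Regroup per optimizable parameter.
per_param = {}
for handler_key, value in losses.items():
    for param_name in LOSS_MAPPING.get(handler_key, []):
        per_param[param_name] = per_param.get(param_name, 0.0) + value

print(per_param)  # {'body_pose': 1.0, 'betas': 0.5}
```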

xrmocap/model/loss/prior_loss.py

Lines changed: 63 additions & 31 deletions
```diff
@@ -13,7 +13,8 @@
 import pickle

 from xrmocap.transform.convention.joints_convention.standard_joint_angles import (  # noqa:E501
-    STANDARD_JOINT_ANGLE_LIMITS, TRANSFORMATION_AA_TO_SJA,
+    STANDARD_JOINT_ANGLE_LIMITS, STANDARD_JOINT_ANGLE_LIMITS_LOCK_APOSE_SPINE,
+    STANDARD_JOINT_ANGLE_LIMITS_LOCK_FOOT, TRANSFORMATION_AA_TO_SJA,
     TRANSFORMATION_SJA_TO_AA,
 )
 from xrmocap.transform.limbs import search_limbs
@@ -90,7 +91,11 @@ def __init__(self,
                  loss_weight: float = 1.0,
                  use_full_body: bool = False,
                  smooth_spine: bool = False,
-                 smooth_spine_loss_weight: float = 1.0):
+                 lock_foot: bool = False,
+                 lock_apose_spine: bool = False,
+                 smooth_spine_loss_weight: float = 1.0,
+                 lock_foot_loss_weight: float = 1.0,
+                 lock_apose_spine_loss_weight: float = 1.0):
         """Prior loss for joint angles.

         Args:
@@ -110,17 +115,25 @@ def __init__(self,
                 smooth spine loss. Defaults to 1.0.
         """
         super().__init__()
-        assert reduction in ('none', 'mean', 'sum')
+        assert reduction in (None, 'none', 'mean', 'sum')
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.use_full_body = use_full_body
         self.smooth_spine = smooth_spine
+        self.lock_foot = lock_foot
+        self.lock_apose_spine = lock_apose_spine
         self.smooth_spine_loss_weight = smooth_spine_loss_weight
+        self.lock_foot_loss_weight = lock_foot_loss_weight
+        self.lock_apose_spine_loss_weight = lock_apose_spine_loss_weight

         if self.use_full_body:
             self.register_buffer('R_t', TRANSFORMATION_AA_TO_SJA)
             self.register_buffer('R_t_inv', TRANSFORMATION_SJA_TO_AA)
             self.register_buffer('sja_limits', STANDARD_JOINT_ANGLE_LIMITS)
+            self.register_buffer('sja_lock_foot',
+                                 STANDARD_JOINT_ANGLE_LIMITS_LOCK_FOOT)
+            self.register_buffer('sja_apose_spine',
+                                 STANDARD_JOINT_ANGLE_LIMITS_LOCK_APOSE_SPINE)

     def forward(self,
                 body_pose: torch.Tensor,
```
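Before the forward-pass changes below, a usage sketch of the new constructor flags. This is a hypothetical example: the class name `JointPriorLoss` is inferred from the file and docstring, not shown in this diff.

```python
import torch

from xrmocap.model.loss.prior_loss import JointPriorLoss  # assumed class name

joint_prior = JointPriorLoss(
    reduction='sum',
    use_full_body=True,      # standard angle limits on all 21 body joints
    smooth_spine=True,       # keep neighboring spine joints consistent
    lock_foot=True,          # extra limits for poorly observed foot keypoints
    lock_apose_spine=True,   # pull the spine toward an A-pose
    lock_foot_loss_weight=1.0,
    lock_apose_spine_loss_weight=1.0)

body_pose = torch.zeros(2, 23 * 3)  # SMPL body pose, axis-angle per joint
loss = joint_prior(body_pose)       # scalar, since reduction='sum'
```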
```diff
@@ -150,48 +163,67 @@ def forward(self,
             if loss_weight_override is not None \
             else self.loss_weight

-        if self.use_full_body:
-            batch_size = body_pose.shape[0]
-            body_pose_reshape = body_pose.reshape(batch_size, -1, 3)
-            assert body_pose_reshape.shape[1] in (21, 23)  # smpl-x, smpl
-            body_pose_reshape = body_pose_reshape[:, :21, :]
-
-            body_pose_sja = aa_to_sja(body_pose_reshape, self.R_t,
-                                      self.R_t_inv)
+        batch_size = body_pose.shape[0]
+        body_pose_reshape = body_pose.reshape(batch_size, -1, 3)
+        assert body_pose_reshape.shape[1] in (21, 23)  # smpl-x, smpl
+        body_pose_reshape = body_pose_reshape[:, :21, :]

-            lower_limits = self.sja_limits[:, :, 0]  # shape: (21, 3)
-            upper_limits = self.sja_limits[:, :, 1]  # shape: (21, 3)
+        body_pose_sja = aa_to_sja(body_pose_reshape, self.R_t, self.R_t_inv)

-            lower_loss = (torch.exp(F.relu(lower_limits - body_pose_sja)) -
-                          1).pow(2)
-            upper_loss = (torch.exp(F.relu(body_pose_sja - upper_limits)) -
-                          1).pow(2)
+        parts_joint_prior_losses = []
+        pred_poses = []
+        limits = []
+        weights = []

-            standard_joint_angle_prior_loss = (lower_loss + upper_loss).view(
-                body_pose.shape[0], -1)  # shape: (n, 3)
-
-            joint_prior_loss = standard_joint_angle_prior_loss
+        if self.use_full_body:
+            pred_poses.append(body_pose_sja)
+            limits.append(self.sja_limits)
+            weights.append(1.0)

         else:
             # default joint prior loss applied on elbows and knees
-            joint_prior_loss = (torch.exp(
-                body_pose[:, [55, 58, 12, 15]] *
-                torch.tensor([1., -1., -1, -1.], device=body_pose.device)) -
-                                1)**2
+            pred_poses.append(body_pose_sja[:, [3, 4, 17, 18]])
+            limits.append(self.sja_limits[[3, 4, 17, 18]])
+            weights.append(1.0)
+
+        if self.lock_foot:
+            pred_poses.append(body_pose_sja[:, [6, 7, 9, 10]])
+            limits.append(self.sja_lock_foot)
+            weights.append(self.lock_foot_loss_weight)
+
+        if self.lock_apose_spine:
+            pred_poses.append(body_pose_sja[:, [2, 5, 8, 11]])
+            limits.append(self.sja_apose_spine)
+            weights.append(self.lock_apose_spine_loss_weight)

         if self.smooth_spine:
-            spine1 = body_pose[:, [9, 10, 11]]
-            spine2 = body_pose[:, [18, 19, 20]]
-            spine3 = body_pose[:, [27, 28, 29]]
+            spine1 = body_pose_reshape[:, 2, :]
+            spine2 = body_pose_reshape[:, 5, :]
+            spine3 = body_pose_reshape[:, 8, :]
+
             smooth_spine_loss_12 = (torch.exp(F.relu(-spine1 * spine2)) -
                                     1).pow(2) * self.smooth_spine_loss_weight
             smooth_spine_loss_23 = (torch.exp(F.relu(-spine2 * spine3)) -
                                     1).pow(2) * self.smooth_spine_loss_weight

-            joint_prior_loss = torch.cat(
-                [joint_prior_loss, smooth_spine_loss_12, smooth_spine_loss_23],
-                axis=1)
+            parts_joint_prior_losses.append(smooth_spine_loss_12)
+            parts_joint_prior_losses.append(smooth_spine_loss_23)
+
+        for idx, weight in enumerate(weights):
+            lower_limits = limits[idx][:, :, 0]
+            upper_limits = limits[idx][:, :, 1]
+            pred_pose = pred_poses[idx]
+
+            lower_loss = (torch.exp(F.relu(lower_limits - pred_pose)) -
+                          1).pow(2)
+            upper_loss = (torch.exp(F.relu(pred_pose - upper_limits)) -
+                          1).pow(2)
+            loss = (lower_loss + upper_loss).view(body_pose.shape[0],
+                                                  -1)  # (n, 3)
+
+            parts_joint_prior_losses.append(weight * loss)

+        joint_prior_loss = torch.cat(parts_joint_prior_losses, axis=1)
         joint_prior_loss = loss_weight * joint_prior_loss

         if reduction == 'mean':
```
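The per-part penalty in the loop above is zero inside `[lower, upper]` and grows as `(e^violation - 1)^2` outside, so small violations stay cheap while large ones grow sharply. A self-contained numeric sketch of just that penalty:

```python
import torch
import torch.nn.functional as F

def angle_limit_penalty(pred_pose, lower_limits, upper_limits):
    """(exp(violation) - 1)^2, zero for angles inside [lower, upper]."""
    lower_loss = (torch.exp(F.relu(lower_limits - pred_pose)) - 1).pow(2)
    upper_loss = (torch.exp(F.relu(pred_pose - upper_limits)) - 1).pow(2)
    return lower_loss + upper_loss

angles = torch.tensor([-0.5, 0.0, 1.2])  # one joint, 3 rotation components
lower = torch.tensor([-0.3, -1.0, -1.0])
upper = torch.tensor([0.3, 1.0, 1.0])
print(angle_limit_penalty(angles, lower, upper))
# tensor([0.0490, 0.0000, 0.0490]) -- only out-of-range components are penalized
```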

xrmocap/model/registrant/smplify.py

Lines changed: 134 additions & 36 deletions
```diff
@@ -12,6 +12,7 @@
     SMPLifyBaseHook, build_smplify_hook,
 )
 from xrmocap.model.body_model.builder import build_body_model
+from xrmocap.model.loss.mapping import LOSS_MAPPING
 from xrmocap.transform.convention.keypoints_convention import (  # noqa:E501
     get_keypoint_idx, get_keypoint_idxs_by_part,
 )
@@ -42,6 +43,7 @@ def __init__(self,
                 hooks: List[Union[dict, SMPLifyBaseHook]] = [],
                 verbose: bool = False,
                 info_level: Literal['stage', 'step'] = 'step',
+                grad_clip: float = 1.0,
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """Re-implementation of SMPLify with extended features.

@@ -92,7 +94,9 @@ def __init__(self,
         self.device = device
         self.stage_config = stages
         self.optimizer = optimizer
+        self.grad_clip = grad_clip
         self.hooks = []
+        self.individual_optimizer = False

         # initialize body model
         if isinstance(body_model, dict):
```
```diff
@@ -352,44 +356,112 @@ def __optimize_stage__(self,
         self.call_hook('before_stage', **hook_kwargs)

         kwargs = kwargs.copy()
-        parameters = OptimizableParameters()
-        for key, value in optim_param.items():
-            fit_flag = kwargs.pop(f'fit_{key}', True)
-            parameters.add_param(key=key, param=value, fit_param=fit_flag)
-        optimizer = build_optimizer(parameters, self.optimizer)

-        pre_loss = None
+        # add individual optimizer choice
+        optimizers = {}
+        if 'individual_optimizer' not in self.optimizer:
+            parameters = OptimizableParameters()
+            for key, value in optim_param.items():
+                fit_flag = kwargs.pop(f'fit_{key}', True)
+                parameters.add_param(key=key, param=value, fit_param=fit_flag)
+            optimizers['default_optimizer'] = build_optimizer(
+                parameters, self.optimizer)
+        else:
+            # set an individual optimizer if an optimizer config
+            # is given and fit_{key} is True;
+            # update with the default optimizer or ignore otherwise
+            # | {key}_opt_config | fit_{key} | optimizer         |
+            # | ---------------- | --------- | ----------------- |
+            # | True             | True      | {key}_optimizer   |
+            # | False            | True      | default_optimizer |
+            # | True             | False     | ignore            |
+            # | False            | False     | ignore            |
+            self.individual_optimizer = True
+            _optim_param = optim_param.copy()
+            for key in list(_optim_param.keys()):
+                parameters = OptimizableParameters()
+                fit_flag = kwargs.pop(f'fit_{key}', False)
+                if f'{key}_optimizer' in self.optimizer.keys() and fit_flag:
+                    value = _optim_param.pop(key)
+                    parameters.add_param(
+                        key=key, param=value, fit_param=fit_flag)
+                    optimizers[key] = build_optimizer(
+                        parameters, self.optimizer[f'{key}_optimizer'])
+                    self.logger.info(f'Add an individual optimizer for {key}')
+                elif not fit_flag:
+                    _optim_param.pop(key)
+                else:
+                    self.logger.info(f'No optimizer defined for {key}, '
+                                     'get the default optimizer')
+
+            if len(_optim_param) > 0:
+                parameters = OptimizableParameters()
+                if 'default_optimizer' not in self.optimizer:
+                    self.logger.error(
+                        'Individual optimizer mode is selected but '
+                        'some optimizers are not defined. '
+                        'Please set the default_optimizer or set an optimizer '
+                        f'for {_optim_param.keys()}.')
+                    raise KeyError
+                else:
+                    for key in list(_optim_param.keys()):
+                        fit_flag = kwargs.pop(f'fit_{key}', True)
+                        value = _optim_param.pop(key)
+                        if fit_flag:
+                            parameters.add_param(
+                                key=key, param=value, fit_param=fit_flag)
+                    optimizers['default_optimizer'] = build_optimizer(
+                        parameters, self.optimizer['default_optimizer'])
+
+        previous_loss = None
         for iter_idx in range(n_iter):
-
-            def closure():
-                optimizer.zero_grad()
-                betas_video = self.__expand_betas__(
-                    batch_size=optim_param['body_pose'].shape[0],
-                    betas=optim_param['betas'])
-                expanded_param = {}
-                expanded_param.update(optim_param)
-                expanded_param['betas'] = betas_video
-                loss_dict = self.evaluate(
-                    input_list=input_list,
-                    optim_param=expanded_param,
-                    use_shoulder_hip_only=use_shoulder_hip_only,
-                    body_weight=body_weight,
-                    **kwargs)
-
-                loss = loss_dict['total_loss']
-                loss.backward()
-                return loss
-
-            loss = optimizer.step(closure)
-            if iter_idx > 0 and pre_loss is not None and ftol > 0:
+            for optimizer_key, optimizer in optimizers.items():
+
+                def closure():
+                    optimizer.zero_grad()
+
+                    betas_video = self.__expand_betas__(
+                        batch_size=optim_param['body_pose'].shape[0],
+                        betas=optim_param['betas'])
+                    expanded_param = {}
+                    expanded_param.update(optim_param)
+                    expanded_param['betas'] = betas_video
+                    loss_dict = self.evaluate(
+                        input_list=input_list,
+                        optim_param=expanded_param,
+                        use_shoulder_hip_only=use_shoulder_hip_only,
+                        body_weight=body_weight,
+                        **kwargs)
+
+                    if optimizer_key not in loss_dict.keys():
+                        self.logger.error(
+                            f'Individual optimizer is set for {optimizer_key} '
+                            'but there is no loss calculated for this '
+                            'optimizer. Please check LOSS_MAPPING and '
+                            'make sure respective losses are turned on.')
+                        raise KeyError
+                    loss = loss_dict[optimizer_key]
+                    total_loss = loss_dict['total_loss']
+
+                    loss.backward(retain_graph=True)
+
+                    torch.nn.utils.clip_grad_norm_(
+                        parameters=optim_param.values(),
+                        max_norm=self.grad_clip)
+
+                    return total_loss
+
+                total_loss = optimizer.step(closure)
+
+            if iter_idx > 0 and previous_loss is not None and ftol > 0:
                 loss_rel_change = self.__compute_relative_change__(
-                    pre_loss, loss.item())
+                    previous_loss, total_loss.item())
                 if loss_rel_change < ftol:
                     if self.verbose:
                         self.logger.info(
                             f'[ftol={ftol}] Early stop at {iter_idx} iter!')
                     break
-            pre_loss = loss.item()
+            previous_loss = total_loss.item()

         stage_config = dict(
             use_shoulder_hip_only=use_shoulder_hip_only,
```
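Each optimizer now runs its own closure per iteration: only the loss mapped to its parameters is backpropagated (with `retain_graph=True`, since all losses share one graph), gradients are clipped to `self.grad_clip`, and the closure returns `total_loss` so the existing ftol early stop still tracks the overall objective. To make the branch selection concrete, here is a hedged sketch of an optimizer config that takes the individual-optimizer path; only the `individual_optimizer`, `{key}_optimizer`, and `default_optimizer` keys are read by the code above, and the inner optimizer options are illustrative placeholders:

```python
# Assumed mmcv-style optimizer dicts; option values are placeholders,
# not taken from this commit.
optimizer = dict(
    individual_optimizer=True,
    # betas and body_pose each get their own optimizer
    # when fit_betas / fit_body_pose are True ...
    betas_optimizer=dict(type='LBFGS', max_iter=20, lr=1.0),
    body_pose_optimizer=dict(type='LBFGS', max_iter=20, lr=0.5),
    # ... and any remaining fitted parameters fall back to this one.
    default_optimizer=dict(type='LBFGS', max_iter=20, lr=1.0),
)
```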
```diff
@@ -611,18 +683,22 @@ def __compute_loss__(self,
                 loss_tensor = handler(**handler_input)
                 # if loss computed, record it in losses
                 if loss_tensor is not None:
+                    if loss_tensor.ndim == 3:
+                        loss_tensor = loss_tensor.sum(dim=(2, 1))
+                    elif loss_tensor.ndim == 2:
+                        loss_tensor = loss_tensor.sum(dim=-1)
                     losses[handler_key] = loss_tensor

         total_loss = 0
         for key, loss in losses.items():
-            if loss.ndim == 3:
-                total_loss = total_loss + loss.sum(dim=(2, 1))
-            elif loss.ndim == 2:
-                total_loss = total_loss + loss.sum(dim=-1)
-            else:
-                total_loss = total_loss + loss
+            total_loss = total_loss + loss
         losses['total_loss'] = total_loss

+        if self.individual_optimizer:
+            losses = self._post_process_loss(losses)
+        else:
+            losses['default_optimizer'] = total_loss
+
         # warn once if there's item still in popped kwargs
         if not self.__stage_kwargs_warned__ and \
                 len(kwargs) > 0:
```
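With the reduction moved into the handler loop, every entry in `losses` is already collapsed to at most a batch dimension before being summed and remapped. A toy check of that reduction, with illustrative shapes:

```python
import torch

def reduce_loss(loss_tensor):
    # Mirrors the reduction above: collapse keypoint/coordinate dims
    # so each recorded loss keeps at most a batch dimension.
    if loss_tensor.ndim == 3:
        return loss_tensor.sum(dim=(2, 1))
    elif loss_tensor.ndim == 2:
        return loss_tensor.sum(dim=-1)
    return loss_tensor

print(reduce_loss(torch.ones(4, 17, 3)).shape)  # torch.Size([4])
print(reduce_loss(torch.ones(4, 10)).shape)     # torch.Size([4])
```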
```diff
@@ -637,6 +713,28 @@ def __compute_loss__(self,
         return losses

+    def _post_process_loss(self, losses: dict, **kwargs) -> dict:
+        """Process losses and map them to their respective parameters.
+
+        Args:
+            losses (dict): Original losses, keyed by handler_key.
+
+        Returns:
+            dict: Processed losses, keyed by parameter name.
+                The original keys are kept as well.
+        """
+        for loss_key in list(losses.keys()):
+            process_list = LOSS_MAPPING.get(loss_key, [])
+            for optimizer_loss in process_list:
+                losses[optimizer_loss] = losses[optimizer_loss] + \
+                    losses[loss_key] if optimizer_loss in losses \
+                    else losses[loss_key]
+
+        losses['default_optimizer'] = losses['total_loss']
+
+        return losses
+
     def __match_init_batch_size__(self, init_param: torch.Tensor,
                                   default_param: torch.Tensor,
                                   batch_size: int) -> torch.Tensor:
```
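Putting the pieces together: per-handler losses are remapped to per-parameter losses via `LOSS_MAPPING`, and each parameter's optimizer backpropagates only its own loss, with clipped gradients. A stripped-down, self-contained sketch of that loop on a toy problem (all names and values are illustrative, not taken from this commit):

```python
import torch

# Toy parameters standing in for body_pose / betas.
params = {k: torch.randn(3, requires_grad=True)
          for k in ('body_pose', 'betas')}
optimizers = {k: torch.optim.LBFGS([v], max_iter=5)
              for k, v in params.items()}

def compute_losses():
    # Toy per-parameter losses; 'total_loss' is only used for early stopping.
    d = {k: (v ** 2).sum() for k, v in params.items()}
    d['total_loss'] = sum(d.values())
    return d

for key, opt in optimizers.items():
    def closure():
        opt.zero_grad()
        loss_dict = compute_losses()
        # Backprop only the loss mapped to this optimizer's parameter;
        # retain_graph because total_loss shares the same graph.
        loss_dict[key].backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(params.values(), max_norm=1.0)
        return loss_dict['total_loss']
    opt.step(closure)
```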
