    SMPLifyBaseHook, build_smplify_hook,
)
from xrmocap.model.body_model.builder import build_body_model
+ from xrmocap.model.loss.mapping import LOSS_MAPPING
from xrmocap.transform.convention.keypoints_convention import (  # noqa:E501
    get_keypoint_idx, get_keypoint_idxs_by_part,
)
@@ -42,6 +43,7 @@ def __init__(self,
                 hooks: List[Union[dict, SMPLifyBaseHook]] = [],
                 verbose: bool = False,
                 info_level: Literal['stage', 'step'] = 'step',
+                  grad_clip: float = 1.0,
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """Re-implementation of SMPLify with extended features.
@@ -92,7 +94,9 @@ def __init__(self,
        self.device = device
        self.stage_config = stages
        self.optimizer = optimizer
+         self.grad_clip = grad_clip
        self.hooks = []
+         self.individual_optimizer = False

        # initialize body model
        if isinstance(body_model, dict):
@@ -352,44 +356,112 @@ def __optimize_stage__(self,
        self.call_hook('before_stage', **hook_kwargs)

        kwargs = kwargs.copy()
-         parameters = OptimizableParameters()
-         for key, value in optim_param.items():
-             fit_flag = kwargs.pop(f'fit_{key}', True)
-             parameters.add_param(key=key, param=value, fit_param=fit_flag)
-         optimizer = build_optimizer(parameters, self.optimizer)

-         pre_loss = None
+         # add individual optimizer choice
+         optimizers = {}
+         if 'individual_optimizer' not in self.optimizer:
+             parameters = OptimizableParameters()
+             for key, value in optim_param.items():
+                 fit_flag = kwargs.pop(f'fit_{key}', True)
+                 parameters.add_param(key=key, param=value, fit_param=fit_flag)
+             optimizers['default_optimizer'] = build_optimizer(
+                 parameters, self.optimizer)
+         else:
+             # Set an individual optimizer if an optimizer config is given
+             # and fit_{key} is True; fall back to the default optimizer
+             # or skip the parameter otherwise.
+             # | {key}_opt_config | fit_{key} | optimizer          |
+             # | ---------------- | --------- | ------------------ |
+             # | True             | True      | {key}_optimizer    |
+             # | False            | True      | default_optimizer  |
+             # | True             | False     | ignore             |
+             # | False            | False     | ignore             |
+             self.individual_optimizer = True
+             _optim_param = optim_param.copy()
+             for key in list(_optim_param.keys()):
+                 parameters = OptimizableParameters()
+                 fit_flag = kwargs.pop(f'fit_{key}', False)
+                 if f'{key}_optimizer' in self.optimizer.keys() and fit_flag:
+                     value = _optim_param.pop(key)
+                     parameters.add_param(
+                         key=key, param=value, fit_param=fit_flag)
+                     optimizers[key] = build_optimizer(
+                         parameters, self.optimizer[f'{key}_optimizer'])
+                     self.logger.info(f'Add an individual optimizer for {key}')
+                 elif not fit_flag:
+                     _optim_param.pop(key)
+                 else:
+                     self.logger.info(f'No optimizer defined for {key}, '
+                                      'use the default optimizer')
+
+             if len(_optim_param) > 0:
+                 parameters = OptimizableParameters()
+                 if 'default_optimizer' not in self.optimizer:
+                     self.logger.error(
+                         'Individual optimizer mode is selected but '
+                         'some optimizers are not defined. '
+                         'Please set default_optimizer or set an optimizer '
+                         f'for {_optim_param.keys()}.')
+                     raise KeyError
+                 else:
+                     for key in list(_optim_param.keys()):
+                         fit_flag = kwargs.pop(f'fit_{key}', True)
+                         value = _optim_param.pop(key)
+                         if fit_flag:
+                             parameters.add_param(
+                                 key=key, param=value, fit_param=fit_flag)
+                     optimizers['default_optimizer'] = build_optimizer(
+                         parameters, self.optimizer['default_optimizer'])
+
+         previous_loss = None
        for iter_idx in range(n_iter):
-
-             def closure():
-                 optimizer.zero_grad()
-                 betas_video = self.__expand_betas__(
-                     batch_size=optim_param['body_pose'].shape[0],
-                     betas=optim_param['betas'])
-                 expanded_param = {}
-                 expanded_param.update(optim_param)
-                 expanded_param['betas'] = betas_video
-                 loss_dict = self.evaluate(
-                     input_list=input_list,
-                     optim_param=expanded_param,
-                     use_shoulder_hip_only=use_shoulder_hip_only,
-                     body_weight=body_weight,
-                     **kwargs)
-
-                 loss = loss_dict['total_loss']
-                 loss.backward()
-                 return loss
-
-             loss = optimizer.step(closure)
-             if iter_idx > 0 and pre_loss is not None and ftol > 0:
+             for optimizer_key, optimizer in optimizers.items():
+
+                 def closure():
+                     optimizer.zero_grad()
+
+                     betas_video = self.__expand_betas__(
+                         batch_size=optim_param['body_pose'].shape[0],
+                         betas=optim_param['betas'])
+                     expanded_param = {}
+                     expanded_param.update(optim_param)
+                     expanded_param['betas'] = betas_video
+                     loss_dict = self.evaluate(
+                         input_list=input_list,
+                         optim_param=expanded_param,
+                         use_shoulder_hip_only=use_shoulder_hip_only,
+                         body_weight=body_weight,
+                         **kwargs)
+
+                     if optimizer_key not in loss_dict.keys():
+                         self.logger.error(
+                             f'Individual optimizer is set for {optimizer_key} '
+                             'but there is no loss calculated for this '
+                             'optimizer. Please check LOSS_MAPPING and '
+                             'make sure respective losses are turned on.')
+                         raise KeyError
+                     loss = loss_dict[optimizer_key]
+                     total_loss = loss_dict['total_loss']
+
+                     loss.backward(retain_graph=True)
+
+                     torch.nn.utils.clip_grad_norm_(
+                         parameters=optim_param.values(),
+                         max_norm=self.grad_clip)
+
+                     return total_loss
+
+                 total_loss = optimizer.step(closure)
+
+             if iter_idx > 0 and previous_loss is not None and ftol > 0:
                loss_rel_change = self.__compute_relative_change__(
-                     pre_loss, loss.item())
+                     previous_loss, total_loss.item())
                if loss_rel_change < ftol:
                    if self.verbose:
                        self.logger.info(
                            f'[ftol={ftol}] Early stop at {iter_idx} iter!')
                    break
-             pre_loss = loss.item()
+             previous_loss = total_loss.item()

        stage_config = dict(
            use_shoulder_hip_only=use_shoulder_hip_only,
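For orientation, the sketch below shows the shape of an 'optimizer' config that would exercise the individual-optimizer branch added above. The 'individual_optimizer', '{key}_optimizer' and 'default_optimizer' keys mirror the lookups in this hunk; the optimizer types, learning rates and per-key entries are illustrative assumptions, not values taken from this commit.

# Hypothetical config sketch (assumed values, not part of this diff).
optimizer = dict(
    individual_optimizer=True,
    # picked when fit_betas is True and a 'betas_optimizer' entry exists
    betas_optimizer=dict(type='LBFGS', lr=1e-2, max_iter=20),
    body_pose_optimizer=dict(type='LBFGS', lr=1e-2, max_iter=20),
    # fallback for any optimizable parameter without its own entry
    default_optimizer=dict(type='LBFGS', lr=1e-2, max_iter=20))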
@@ -611,18 +683,22 @@ def __compute_loss__(self,
            loss_tensor = handler(**handler_input)
            # if loss computed, record it in losses
            if loss_tensor is not None:
+                 if loss_tensor.ndim == 3:
+                     loss_tensor = loss_tensor.sum(dim=(2, 1))
+                 elif loss_tensor.ndim == 2:
+                     loss_tensor = loss_tensor.sum(dim=-1)
                losses[handler_key] = loss_tensor

        total_loss = 0
        for key, loss in losses.items():
-             if loss.ndim == 3:
-                 total_loss = total_loss + loss.sum(dim=(2, 1))
-             elif loss.ndim == 2:
-                 total_loss = total_loss + loss.sum(dim=-1)
-             else:
-                 total_loss = total_loss + loss
+             total_loss = total_loss + loss
        losses['total_loss'] = total_loss

+         if self.individual_optimizer:
+             losses = self._post_process_loss(losses)
+         else:
+             losses['default_optimizer'] = total_loss
+
        # warn once if there's item still in popped kwargs
        if not self.__stage_kwargs_warned__ and \
                len(kwargs) > 0:
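The hunk above moves the per-frame reduction out of the total-loss accumulation and into the point where each handler's loss is recorded, so every entry in losses already holds one value per frame before it is summed or remapped to an optimizer key. A small self-contained check of that reduction, with an assumed example shape:

import torch

# e.g. a handler loss with one value per frame, keypoint and coordinate
loss_tensor = torch.rand(8, 144, 3)
per_frame = loss_tensor.sum(dim=(2, 1))  # collapses to shape (8,)
assert per_frame.shape == (8,)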
@@ -637,6 +713,28 @@ def __compute_loss__(self,
        return losses

+     def _post_process_loss(self, losses: dict, **kwargs) -> dict:
+         """Process losses and map them to their respective parameters.
+
+         Args:
+             losses (dict): Original losses, keyed by handler_key.
+
+         Returns:
+             dict: Processed losses, keyed by parameter names.
+                 The original keys are kept as well.
+         """
+
+         for loss_key in list(losses.keys()):
+             process_list = LOSS_MAPPING.get(loss_key, [])
+             for optimizer_loss in process_list:
+                 losses[optimizer_loss] = losses[optimizer_loss] + \
+                     losses[loss_key] if optimizer_loss in losses \
+                     else losses[loss_key]
+
+         losses['default_optimizer'] = losses['total_loss']
+
+         return losses
+
    def __match_init_batch_size__(self, init_param: torch.Tensor,
                                  default_param: torch.Tensor,
                                  batch_size: int) -> torch.Tensor:
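The new _post_process_loss relies on LOSS_MAPPING, imported at the top of this diff, to route each handler's loss onto parameter-named keys that the per-parameter optimizers read inside their closures. The real table lives in xrmocap/model/loss/mapping.py and is not shown in this commit, so the entries below are only an assumed illustration of its shape.

# Assumed shape of LOSS_MAPPING; the handler keys and parameter lists here
# are illustrative, not the real entries from xrmocap/model/loss/mapping.py.
LOSS_MAPPING = {
    'keypoints3d_mse': ['body_pose', 'global_orient', 'transl'],
    'shape_prior': ['betas'],
}

# With such a table, _post_process_loss would add losses['shape_prior'] into
# losses['betas'], so optimizers['betas'] can read loss_dict['betas'] in its
# closure, while 'total_loss' is still returned for the ftol check.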