-
Notifications
You must be signed in to change notification settings - Fork 11
Open
Description
您好,关于对抗训练fgm的代码:
if args.adv_fgm:
fgm.attack() # 在embedding上添加对抗扰动
loss_adv = model(**inputs)[0]
loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度
fgm.restore() # 恢复embedding参数
如果使用多卡的话,例如n_gpu=2
,这里的loss_adv
是不是要取一下平均,也就是loss_adv=loss_adv.mean()
。期待您的回复,谢谢!
Metadata
Metadata
Assignees
Labels
No labels