4
4
import torch
5
5
import torch .nn as nn
6
6
import torch .nn .functional as F
7
+ from generative .losses .adversarial_loss import PatchAdversarialLoss
7
8
from pynvml .smi import nvidia_smi
8
9
from tensorboardX import SummaryWriter
9
10
from torch .cuda .amp import GradScaler , autocast
@@ -143,6 +144,8 @@ def train_epoch_aekl(
143
144
model .train ()
144
145
discriminator .train ()
145
146
147
+ adv_loss = PatchAdversarialLoss (criterion = "least_squares" , no_activation_leastsq = True )
148
+
146
149
pbar = tqdm (enumerate (loader ), total = len (loader ))
147
150
for step , x in pbar :
148
151
images = x ["image" ].to (device )
@@ -157,9 +160,11 @@ def train_epoch_aekl(
157
160
kl_loss = 0.5 * torch .sum (z_mu .pow (2 ) + z_sigma .pow (2 ) - torch .log (z_sigma .pow (2 )) - 1 , dim = [1 , 2 , 3 ])
158
161
kl_loss = torch .sum (kl_loss ) / kl_loss .shape [0 ]
159
162
160
- logits_fake = discriminator (reconstruction .contiguous ().float ())[- 1 ]
161
- real_label = torch .ones_like (logits_fake , device = logits_fake .device )
162
- generator_loss = F .mse_loss (logits_fake , real_label )
163
+ if adv_weight > 0 :
164
+ logits_fake = discriminator (reconstruction .contiguous ().float ())[- 1 ]
165
+ generator_loss = adv_loss (logits_fake , target_is_real = True , for_discriminator = False )
166
+ else :
167
+ generator_loss = torch .tensor ([0.0 ]).to (device )
163
168
164
169
loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss
165
170
@@ -184,25 +189,26 @@ def train_epoch_aekl(
184
189
scaler_g .update ()
185
190
186
191
# DISCRIMINATOR
187
- optimizer_d .zero_grad (set_to_none = True )
188
-
189
- with autocast (enabled = True ):
190
- logits_fake = discriminator (reconstruction .contiguous ().detach ())[- 1 ]
191
- fake_label = torch .zeros_like (logits_fake , device = logits_fake .device )
192
- loss_d_fake = F .mse_loss (logits_fake , fake_label )
193
- logits_real = discriminator (images .contiguous ().detach ())[- 1 ]
194
- real_label = torch .ones_like (logits_real , device = logits_real .device )
195
- loss_d_real = F .mse_loss (logits_real , real_label )
196
- discriminator_loss = (loss_d_fake + loss_d_real ) * 0.5
197
-
198
- d_loss = adv_weight * discriminator_loss
199
- d_loss = d_loss .mean ()
200
-
201
- scaler_d .scale (d_loss ).backward ()
202
- scaler_d .unscale_ (optimizer_d )
203
- torch .nn .utils .clip_grad_norm_ (discriminator .parameters (), 1 )
204
- scaler_d .step (optimizer_d )
205
- scaler_d .update ()
192
+ if adv_weight > 0 :
193
+ optimizer_d .zero_grad (set_to_none = True )
194
+
195
+ with autocast (enabled = True ):
196
+ logits_fake = discriminator (reconstruction .contiguous ().detach ())[- 1 ]
197
+ loss_d_fake = adv_loss (logits_fake , target_is_real = False , for_discriminator = True )
198
+ logits_real = discriminator (images .contiguous ().detach ())[- 1 ]
199
+ loss_d_real = adv_loss (logits_real , target_is_real = True , for_discriminator = True )
200
+ discriminator_loss = (loss_d_fake + loss_d_real ) * 0.5
201
+
202
+ d_loss = adv_weight * discriminator_loss
203
+ d_loss = d_loss .mean ()
204
+
205
+ scaler_d .scale (d_loss ).backward ()
206
+ scaler_d .unscale_ (optimizer_d )
207
+ torch .nn .utils .clip_grad_norm_ (discriminator .parameters (), 1 )
208
+ scaler_d .step (optimizer_d )
209
+ scaler_d .update ()
210
+ else :
211
+ discriminator_loss = torch .tensor ([0.0 ]).to (device )
206
212
207
213
losses ["d_loss" ] = discriminator_loss
208
214
@@ -241,6 +247,7 @@ def eval_aekl(
241
247
model .eval ()
242
248
discriminator .eval ()
243
249
250
+ adv_loss = PatchAdversarialLoss (criterion = "least_squares" , no_activation_leastsq = True )
244
251
total_losses = OrderedDict ()
245
252
for x in loader :
246
253
images = x ["image" ].to (device )
@@ -250,20 +257,24 @@ def eval_aekl(
250
257
reconstruction , z_mu , z_sigma = model (x = images )
251
258
l1_loss = F .l1_loss (reconstruction .float (), images .float ())
252
259
p_loss = perceptual_loss (reconstruction .float (), images .float ())
253
- kl_loss = 0.5 * torch .sum (z_mu .pow (2 ) + z_sigma .pow (2 ) - torch .log (z_sigma .pow (2 )) - 1 , dim = [1 , 2 , 3 ])
260
+ kl_loss = 0.5 * torch .sum (z_mu .pow (2 ) + z_sigma .pow (2 ) - torch .log (z_sigma .pow (2 )) - 1 , dim = [1 , 2 , 3 , 4 ])
254
261
kl_loss = torch .sum (kl_loss ) / kl_loss .shape [0 ]
255
- logits_fake = discriminator (reconstruction .contiguous ().float ())[- 1 ]
256
- real_label = torch .ones_like (logits_fake , device = logits_fake .device )
257
- generator_loss = F .mse_loss (logits_fake , real_label )
262
+
263
+ if adv_weight > 0 :
264
+ logits_fake = discriminator (reconstruction .contiguous ().float ())[- 1 ]
265
+ generator_loss = adv_loss (logits_fake , target_is_real = True , for_discriminator = False )
266
+ else :
267
+ generator_loss = torch .tensor ([0.0 ]).to (device )
258
268
259
269
# DISCRIMINATOR
260
- logits_fake = discriminator (reconstruction .contiguous ().detach ())[- 1 ]
261
- fake_label = torch .zeros_like (logits_fake , device = logits_fake .device )
262
- loss_d_fake = F .mse_loss (logits_fake , fake_label )
263
- logits_real = discriminator (images .contiguous ().detach ())[- 1 ]
264
- real_label = torch .ones_like (logits_real , device = logits_real .device )
265
- loss_d_real = F .mse_loss (logits_real , real_label )
266
- discriminator_loss = (loss_d_fake + loss_d_real ) * 0.5
270
+ if adv_weight > 0 :
271
+ logits_fake = discriminator (reconstruction .contiguous ().detach ())[- 1 ]
272
+ loss_d_fake = adv_loss (logits_fake , target_is_real = False , for_discriminator = True )
273
+ logits_real = discriminator (images .contiguous ().detach ())[- 1 ]
274
+ loss_d_real = adv_loss (logits_real , target_is_real = True , for_discriminator = True )
275
+ discriminator_loss = (loss_d_fake + loss_d_real ) * 0.5
276
+ else :
277
+ discriminator_loss = torch .tensor ([0.0 ]).to (device )
267
278
268
279
loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss
269
280
0 commit comments