Commit 15e109b

Add sampling

1 parent 220ce68
3 files changed: +158 additions, -33 deletions

README.md

Lines changed: 9 additions & 0 deletions

````diff
@@ -19,3 +19,12 @@ TODO LIST:
 - [ ] Add synthetic sentences based on other source of information
 - [ ] Maybe use LLM to augment the reports
 - [ ] Add warmup time for the diffusion model
+
+
+## C1
+### Uploading dataset
+To create the dataset for C1, run the following command:
+
+```
+ngc dataset upload -y --desc "MIMIC dataset with dimension 512x512." --source /nfs/home/wds20/datasets/MIMIC-CXR-JPG_v2.0.0 --threads 12 scotheart
+```
````
Lines changed: 105 additions & 0 deletions (new file)

```python
import matplotlib.pyplot as plt
import mlflow.pytorch
import torch
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.networks.schedulers import DDIMScheduler
from monai.config import print_config
from monai.utils import set_determinism
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

seed = 42
set_determinism(seed=seed)
print_config()

# output_dir = Path("/media/walter/Storage/Projects/generative_cardiac/outputs/figures/same_seed")
# output_dir.mkdir(exist_ok=True, parents=True)

# Load the stage 1 autoencoder from its MLflow artifact and transfer the
# weights into a freshly instantiated AutoencoderKL.
stage1_old = mlflow.pytorch.load_model(
    "/media/walter/Storage/Projects/generative_mimic/mlruns/398344666374521908/6f280de5aa634aab96e6c31eed22a62b/artifacts/final_model"
)
stage1 = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=[64, 128, 128, 128],
    latent_channels=3,
    num_res_blocks=2,
    attention_levels=[False, False, False, False],
    with_encoder_nonlocal_attn=True,
    with_decoder_nonlocal_attn=True,
)
stage1.load_state_dict(stage1_old.state_dict())
stage1.eval()
del stage1_old

# Load the latent diffusion model the same way.
diffusion_old = mlflow.pytorch.load_model(
    "/media/walter/Storage/Projects/generative_mimic/mlruns/411881789846457862/6f1d5a773cf5421aadd7ff787bfe7643/artifacts/final_model"
)
diffusion = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
)
diffusion.load_state_dict(diffusion_old.state_dict())
diffusion.eval()
del diffusion_old

device = torch.device("cuda")
diffusion = diffusion.to(device)
stage1 = stage1.to(device)

# DDIM sampling with v-prediction, using 200 inference steps.
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0015,
    beta_end=0.0205,
    beta_schedule="scaled_linear",
    prediction_type="v_prediction",
    clip_sample=False,
)
scheduler.set_timesteps(200)

# Text conditioning: encode an empty prompt (unconditioned) and the report
# prompt with the Stable Diffusion 2.1 CLIP text encoder.
text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")

prompt = ["", "small right-sided pleural effusion"]
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

prompt_embeds = text_encoder(text_input_ids.squeeze(1))
prompt_embeds = prompt_embeds[0].to(device)

guidance_scale = 7.0
noise = torch.randn((1, 3, 64, 64)).to(device)

# Denoising loop with classifier-free guidance: the latent is duplicated so
# the unconditioned and text-conditioned predictions come out of a single
# forward pass, then the two are combined with the guidance scale.
with torch.no_grad():
    progress_bar = tqdm(scheduler.timesteps)
    for t in progress_bar:
        noise_input = torch.cat([noise] * 2)
        model_output = diffusion(
            noise_input, timesteps=torch.Tensor((t,)).to(noise.device).long(), context=prompt_embeds
        )
        noise_pred_uncond, noise_pred_text = model_output.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        noise, _ = scheduler.step(noise_pred, t, noise)

# Decode the final latent back to image space; dividing by 0.3 undoes the
# latent scaling factor applied during training.
with torch.no_grad():
    sample = stage1.decode_stage_2_outputs(noise / 0.3)

plt.imshow(sample.cpu()[0, 0, :, :], cmap="gray", vmin=0, vmax=1)
plt.show()
```
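The commented-out `output_dir` lines above suggest figures were written to disk in earlier experiments. A minimal sketch of saving the decoded sample instead of displaying it; the output path here is a hypothetical placeholder, not the script's original location:

```python
from pathlib import Path

import matplotlib.pyplot as plt

# Hypothetical output directory; the original (commented-out) script pointed
# at a machine-specific `outputs/figures/same_seed` path instead.
output_dir = Path("outputs/figures")
output_dir.mkdir(exist_ok=True, parents=True)

# `sample` is the decoded tensor produced at the end of the script above.
plt.imshow(sample.cpu()[0, 0, :, :], cmap="gray", vmin=0, vmax=1)
plt.axis("off")
plt.savefig(output_dir / "sample.png", bbox_inches="tight")
```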

src/python/training/training_functions.py

Lines changed: 44 additions & 33 deletions

```diff
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from generative.losses.adversarial_loss import PatchAdversarialLoss
 from pynvml.smi import nvidia_smi
 from tensorboardX import SummaryWriter
 from torch.cuda.amp import GradScaler, autocast
@@ -143,6 +144,8 @@ def train_epoch_aekl(
     model.train()
     discriminator.train()
 
+    adv_loss = PatchAdversarialLoss(criterion="least_squares", no_activation_leastsq=True)
+
     pbar = tqdm(enumerate(loader), total=len(loader))
     for step, x in pbar:
         images = x["image"].to(device)
@@ -157,9 +160,11 @@ def train_epoch_aekl(
             kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
             kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
 
-            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
-            real_label = torch.ones_like(logits_fake, device=logits_fake.device)
-            generator_loss = F.mse_loss(logits_fake, real_label)
+            if adv_weight > 0:
+                logits_fake = discriminator(reconstruction.contiguous().float())[-1]
+                generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
+            else:
+                generator_loss = torch.tensor([0.0]).to(device)
 
             loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss
 
@@ -184,25 +189,26 @@ def train_epoch_aekl(
         scaler_g.update()
 
         # DISCRIMINATOR
-        optimizer_d.zero_grad(set_to_none=True)
-
-        with autocast(enabled=True):
-            logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
-            fake_label = torch.zeros_like(logits_fake, device=logits_fake.device)
-            loss_d_fake = F.mse_loss(logits_fake, fake_label)
-            logits_real = discriminator(images.contiguous().detach())[-1]
-            real_label = torch.ones_like(logits_real, device=logits_real.device)
-            loss_d_real = F.mse_loss(logits_real, real_label)
-            discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
-
-            d_loss = adv_weight * discriminator_loss
-            d_loss = d_loss.mean()
-
-        scaler_d.scale(d_loss).backward()
-        scaler_d.unscale_(optimizer_d)
-        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
-        scaler_d.step(optimizer_d)
-        scaler_d.update()
+        if adv_weight > 0:
+            optimizer_d.zero_grad(set_to_none=True)
+
+            with autocast(enabled=True):
+                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
+                loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
+                logits_real = discriminator(images.contiguous().detach())[-1]
+                loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
+                discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
+
+                d_loss = adv_weight * discriminator_loss
+                d_loss = d_loss.mean()
+
+            scaler_d.scale(d_loss).backward()
+            scaler_d.unscale_(optimizer_d)
+            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
+            scaler_d.step(optimizer_d)
+            scaler_d.update()
+        else:
+            discriminator_loss = torch.tensor([0.0]).to(device)
 
         losses["d_loss"] = discriminator_loss
 
@@ -241,6 +247,7 @@ def eval_aekl(
     model.eval()
     discriminator.eval()
 
+    adv_loss = PatchAdversarialLoss(criterion="least_squares", no_activation_leastsq=True)
     total_losses = OrderedDict()
     for x in loader:
         images = x["image"].to(device)
@@ -250,20 +257,24 @@ def eval_aekl(
             reconstruction, z_mu, z_sigma = model(x=images)
             l1_loss = F.l1_loss(reconstruction.float(), images.float())
            p_loss = perceptual_loss(reconstruction.float(), images.float())
-            kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
+            kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
             kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
-            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
-            real_label = torch.ones_like(logits_fake, device=logits_fake.device)
-            generator_loss = F.mse_loss(logits_fake, real_label)
+
+            if adv_weight > 0:
+                logits_fake = discriminator(reconstruction.contiguous().float())[-1]
+                generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
+            else:
+                generator_loss = torch.tensor([0.0]).to(device)
 
             # DISCRIMINATOR
-            logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
-            fake_label = torch.zeros_like(logits_fake, device=logits_fake.device)
-            loss_d_fake = F.mse_loss(logits_fake, fake_label)
-            logits_real = discriminator(images.contiguous().detach())[-1]
-            real_label = torch.ones_like(logits_real, device=logits_real.device)
-            loss_d_real = F.mse_loss(logits_real, real_label)
-            discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
+            if adv_weight > 0:
+                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
+                loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
+                logits_real = discriminator(images.contiguous().detach())[-1]
+                loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
+                discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
+            else:
+                discriminator_loss = torch.tensor([0.0]).to(device)
 
             loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss
 
```
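The diff above swaps the hand-rolled MSE-against-labels terms for MONAI Generative's `PatchAdversarialLoss`. With `criterion="least_squares"` and `no_activation_leastsq=True`, that loss reduces to a mean-squared error against constant real/fake labels on the raw discriminator logits, so the refactor should be behavior-preserving. A minimal sketch checking this on a random tensor; the logits shape is illustrative, not the discriminator's actual output size:

```python
import torch
import torch.nn.functional as F
from generative.losses.adversarial_loss import PatchAdversarialLoss

adv_loss = PatchAdversarialLoss(criterion="least_squares", no_activation_leastsq=True)

# Stand-in for a patch discriminator's output logits.
logits = torch.randn(2, 1, 30, 30)

# Generator term: push fake logits toward the "real" label (ones).
generator_loss = adv_loss(logits, target_is_real=True, for_discriminator=False)
manual_generator_loss = F.mse_loss(logits, torch.ones_like(logits))

# Discriminator term on fakes: push fake logits toward the "fake" label (zeros).
loss_d_fake = adv_loss(logits, target_is_real=False, for_discriminator=True)
manual_loss_d_fake = F.mse_loss(logits, torch.zeros_like(logits))

print(torch.allclose(generator_loss, manual_generator_loss))  # expected: True
print(torch.allclose(loss_d_fake, manual_loss_d_fake))  # expected: True
```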
