Skip to content

Commit 408001c

Browse files
committed
Update normal version
1 parent 15e109b commit 408001c

File tree

5 files changed

+14
-38
lines changed

5 files changed

+14
-38
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ TODO LIST:
1919
- [ ] Add synthetic sentences based on other source of information
2020
- [ ] Maybe use LLM to augment the reports
2121
- [ ] Add warmup time for the diffusion model
22+
- [ ] Include images from ChestX-ray14 https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345
2223

2324

2425
## C1

configs/stage1/aekl_v0.yaml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ stage1:
88
spatial_dims: 2
99
in_channels: 1
1010
out_channels: 1
11-
num_channels: [64, 128, 128, 128]
11+
num_channels: [64, 128, 128, 256]
1212
latent_channels: 3
1313
num_res_blocks: 2
1414
attention_levels: [False, False, False, False]
@@ -22,11 +22,6 @@ discriminator:
2222
num_layers_d: 3
2323
in_channels: 1
2424
out_channels: 1
25-
kernel_size: 4
26-
activation: "LEAKYRELU"
27-
norm: "BATCH"
28-
bias: False
29-
padding: 1
3025

3126
perceptual_network:
3227
params:

src/python/testing/generate_sample_local.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,13 @@
103103

104104
plt.imshow(sample.cpu()[0, 0, :, :], cmap="gray", vmin=0, vmax=1)
105105
plt.show()
106+
107+
108+
torch.save(
109+
diffusion.state_dict(),
110+
"/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/diffusion_model.pth",
111+
)
112+
torch.save(
113+
stage1.state_dict(),
114+
"/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/autoencoder.pth",
115+
)

src/python/training/training_functions_old_disc.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,13 @@
44
import torch
55
import torch.nn as nn
66
import torch.nn.functional as F
7-
from pynvml.smi import nvidia_smi
87
from tensorboardX import SummaryWriter
98
from torch.cuda.amp import GradScaler, autocast
109
from tqdm import tqdm
10+
from training_functions import get_lr, print_gpu_memory_report
1111
from util import log_reconstructions
1212

1313

14-
def get_lr(optimizer):
15-
for param_group in optimizer.param_groups:
16-
return param_group["lr"]
17-
18-
19-
def print_gpu_memory_report():
20-
if torch.cuda.is_available():
21-
nvsmi = nvidia_smi.getInstance()
22-
data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"]
23-
print("Memory report")
24-
for i, data_by_rank in enumerate(data):
25-
mem_report = data_by_rank["fb_memory_usage"]
26-
print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}")
27-
28-
2914
# ----------------------------------------------------------------------------------------------------------------------
3015
# AUTOENCODER KL
3116
# ----------------------------------------------------------------------------------------------------------------------

src/python/training/training_functions_original_disc.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,13 @@
44
import torch
55
import torch.nn as nn
66
import torch.nn.functional as F
7-
from pynvml.smi import nvidia_smi
87
from tensorboardX import SummaryWriter
98
from torch.cuda.amp import GradScaler, autocast
109
from tqdm import tqdm
10+
from training_functions import get_lr, print_gpu_memory_report
1111
from util import log_reconstructions
1212

1313

14-
def get_lr(optimizer):
15-
for param_group in optimizer.param_groups:
16-
return param_group["lr"]
17-
18-
19-
def print_gpu_memory_report():
20-
if torch.cuda.is_available():
21-
nvsmi = nvidia_smi.getInstance()
22-
data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"]
23-
print("Memory report")
24-
for i, data_by_rank in enumerate(data):
25-
mem_report = data_by_rank["fb_memory_usage"]
26-
print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}")
27-
28-
2914
def hinge_d_loss(logits_real, logits_fake):
3015
loss_real = torch.mean(F.relu(1.0 - logits_real))
3116
loss_fake = torch.mean(F.relu(1.0 + logits_fake))

0 commit comments

Comments
 (0)