Skip to content

Commit 3b62d62

Browse files
committed
[release] new code structure, better auto lens design, coherent ray tracing
1 parent 489eeda commit 3b62d62

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+8401
-1929
lines changed

0_hello_deeplens.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
In this code, we will load a lens from a file. Then we will plot the lens setup and render a sample image.
55
66
Technical Paper:
7-
Yang, Xinge and Fu, Qiang and Heidrich, Wolfgang, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint (2023)
7+
[1] Xinge Yang, Qiang Fu and Wolfgang Heidrich, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint 2023.
8+
[2] Congli Wang, Ni Chen, and Wolfgang Heidrich, "dO: A differentiable engine for Deep Lens design of computational imaging systems," IEEE TCI 2023.
89
910
This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
1011
# The license is only for non-commercial use (commercial licenses can be obtained from authors).
1112
# The material is provided as-is, with no warranties whatsoever.
1213
# If you publish any code, data, or scientific work based on this, please cite our work.
1314
"""
14-
from deeplens import Lensgroup
15+
from deeplens import GeoLens
1516

1617
def main():
17-
lens = Lensgroup(filename='./lenses/ef40mm_f2.8.json')
18-
# lens = Lensgroup(filename='./lenses/cellphone68deg.json')
18+
lens = GeoLens(filename='./lenses/camera/ef35mm_f2.0.json')
19+
# lens = GeoLens(filename='./lenses/cellphone/cellphone80deg.json')
1920
lens.analysis(render=True)
2021

2122
if __name__=='__main__':

1_end2end_5lines.py

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
End2End optical design with only 5 lines of code.
33
44
Technical Paper:
5-
Yang, Xinge and Fu, Qiang and Heidrich, Wolfgang, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint (2023)
5+
Xinge Yang, Qiang Fu and Wolfgang Heidrich, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint 2023.
66
77
This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
88
# The license is only for non-commercial use (commercial licenses can be obtained from authors).
@@ -15,7 +15,6 @@
1515
import logging
1616
import random
1717
import string
18-
import argparse
1918
import numpy as np
2019
import cv2 as cv
2120
from tqdm import tqdm
@@ -30,37 +29,44 @@
3029

3130
def config():
3231
# ==> Config
33-
with open('configs/end2end_5lines.yml') as f:
32+
with open('configs/1_end2end_5lines.yml') as f:
3433
args = yaml.load(f, Loader=yaml.FullLoader)
35-
34+
3635
# ==> Result folder
3736
characters = string.ascii_letters + string.digits
3837
random_string = ''.join(random.choice(characters) for i in range(4))
39-
result_dir = f'./results/' + datetime.now().strftime("%m%d-%H%M%S") + '-End2End' + '-' + random_string
40-
args['result_dir'] = result_dir
38+
current_time = datetime.now().strftime("%m%d-%H%M%S")
39+
exp_name = current_time + '-End2End-5-lines-' + random_string
40+
result_dir = f'./results/{exp_name}'
4141
os.makedirs(result_dir, exist_ok=True)
42-
print(f'Result folder: {result_dir}')
42+
args['result_dir'] = result_dir
4343

44-
# ==> Random seed
45-
set_seed(args['train']['seed'])
44+
if args['seed'] is None:
45+
seed = random.randint(0, 100)
46+
args['seed'] = seed
47+
set_seed(args['seed'])
4648

47-
# ==> Logger
49+
# ==> Log
4850
set_logger(result_dir)
49-
# Log to wandb
51+
logging.info(f'EXP: {args["EXP_NAME"]}')
5052
if not args['DEBUG']:
51-
pass
53+
raise Exception('Add your wandb logging config here.')
5254

5355
# ==> Device
5456
num_gpus = torch.cuda.device_count()
5557
args['num_gpus'] = num_gpus
5658
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
57-
logging.info(f'Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)')
5859
args['device'] = device
60+
logging.info(f'Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)')
5961

60-
# ==> Save config
62+
# ==> Save config and original code
6163
with open(f'{result_dir}/config.yml', 'w') as f:
6264
yaml.dump(args, f)
6365

66+
with open(f'{result_dir}/1_end2end_5lines.py', 'w') as f:
67+
with open('1_end2end_5lines.py', 'r') as code:
68+
f.write(code.read())
69+
6470
return args
6571

6672

@@ -80,7 +86,6 @@ def end2end_train(lens, net, args):
8086
# ==> Network optimizer
8187
batchs = len(train_loader)
8288
epochs = args['train']['epochs']
83-
warm_up = args['train']['warm_up']
8489
net_optim = torch.optim.AdamW(net.parameters(), lr=args['network']['lr'], betas=(0.9, 0.98), eps=1e-08)
8590
net_sche = torch.optim.lr_scheduler.CosineAnnealingLR(net_optim, T_max=epochs*batchs, eta_min=0, last_epoch=-1)
8691

@@ -93,43 +98,15 @@ def end2end_train(lens, net, args):
9398
lens_sche= torch.optim.lr_scheduler.CosineAnnealingLR(lens_optim, T_max=epochs*batchs, eta_min=0, last_epoch=-1)
9499

95100
# ==> Criterion
96-
cri_l2 = nn.MSELoss()
101+
cri_l2 = nn.L1Loss()
97102

98-
# ==> Training
103+
# ==> Log
99104
logging.info(f'Start End2End optical design.')
105+
lens.write_lens_json(f'{result_dir}/epoch0.json')
106+
lens.analysis(f'{result_dir}/epoch0', render=False, zmx_format=True)
107+
108+
# ==> Training
100109
for epoch in range(args['train']['epochs'] + 1):
101-
102-
# ==> Evaluate
103-
if epoch % 1 == 0:
104-
net.eval()
105-
with torch.no_grad():
106-
# => Save data and simple evaluation
107-
lens.write_lens_json(f'{result_dir}/epoch{epoch}.json')
108-
lens.analysis(f'{result_dir}/epoch{epoch}', render=False, zmx_format=True)
109-
110-
torch.save(net.state_dict(), f'{result_dir}/net_epoch{epoch}.pth')
111-
112-
# => Qualitative evaluation
113-
img1 = cv.cvtColor(cv.imread(f'./datasets/lena.png'), cv.COLOR_BGR2RGB)
114-
img1 = cv.resize(img1, args['train']['img_res']).astype(np.float32)
115-
img1 = torch.from_numpy(img1/255.).permute(2, 0, 1).unsqueeze(0).to(device)
116-
img1 = normalize_ImageNet_stats(img1)
117-
118-
img1_render = lens.render(img1)
119-
psnr_render = batch_PSNR(img1, img1_render)
120-
ssim_render = batch_SSIM(img1, img1_render)
121-
save_image(de_normalize(img1_render), f'{result_dir}/img1_render_epoch{epoch}.png')
122-
img1_rec = net(img1_render)
123-
psnr_rec = batch_PSNR(img1, img1_rec)
124-
ssim_rec = batch_SSIM(img1, img1_rec)
125-
save_image(de_normalize(img1_rec), f'{result_dir}/img1_rec_epoch{epoch}.png')
126-
127-
logging.info(f'Epoch [{epoch}/{args["train"]["epochs"]}], PSNR_render: {psnr_render:.4f}, SSIM_render: {ssim_render:.4f}, PSNR_rec: {psnr_rec:.4f}, SSIM_rec: {ssim_rec:.4f}')
128-
129-
# => Quantitative evaluation
130-
# validate(net, lens, epoch, args, val_loader)
131-
132-
net.train()
133110

134111
# ==> Train 1 epoch
135112
for img_org in tqdm(train_loader):
@@ -163,21 +140,54 @@ def end2end_train(lens, net, args):
163140
lens_optim.step()
164141

165142
if not args['DEBUG']:
166-
wandb.log({"loss_class":L_rec.detach().item()})
143+
wandb.log({"loss_class": L_rec.detach().item()})
167144

168145
net_sche.step()
169146
lens_sche.step()
170147

171148
logging.info(f'Epoch{epoch+1} finishs.')
172149

173150

151+
# ==> Evaluate
152+
if epoch % 1 == 0:
153+
net.eval()
154+
with torch.no_grad():
155+
# => Save data and simple evaluation
156+
lens.write_lens_json(f'{result_dir}/epoch{epoch}.json')
157+
lens.analysis(f'{result_dir}/epoch{epoch}', render=False, zmx_format=True)
158+
159+
torch.save(net.state_dict(), f'{result_dir}/net_epoch{epoch}.pth')
160+
161+
# => Qualitative evaluation
162+
img1 = cv.cvtColor(cv.imread(f'./datasets/cat.png'), cv.COLOR_BGR2RGB)
163+
img1 = cv.resize(img1, args['train']['img_res']).astype(np.float32)
164+
img1 = torch.from_numpy(img1/255.).permute(2, 0, 1).unsqueeze(0).to(device)
165+
img1 = normalize_ImageNet(img1)
166+
167+
img1_render = lens.render(img1)
168+
psnr_render = batch_PSNR(img1, img1_render)
169+
ssim_render = batch_SSIM(img1, img1_render)
170+
save_image(denormalize_ImageNet(img1_render), f'{result_dir}/img1_render_epoch{epoch}.png')
171+
img1_rec = net(img1_render)
172+
psnr_rec = batch_PSNR(img1, img1_rec)
173+
ssim_rec = batch_SSIM(img1, img1_rec)
174+
save_image(denormalize_ImageNet(img1_rec), f'{result_dir}/img1_rec_epoch{epoch}.png')
175+
176+
logging.info(f'Epoch [{epoch}/{args["train"]["epochs"]}], PSNR_render: {psnr_render:.4f}, SSIM_render: {ssim_render:.4f}, PSNR_rec: {psnr_rec:.4f}, SSIM_rec: {ssim_rec:.4f}')
177+
178+
# => Quantitative evaluation
179+
# validate(net, lens, epoch, args, val_loader)
180+
181+
net.train()
182+
183+
174184
if __name__=='__main__':
175185
args = config()
176186

177187
# ========================================
178188
# Line 1: load a lens
179189
# ========================================
180-
lens = Lensgroup(filename=args['lens']['path'], sensor_res=args['train']['img_res'])
190+
lens = GeoLens(filename=args['lens']['path'], sensor_res=args['train']['img_res'])
181191
net = ImageRestorationNet()
182192
net = net.to(lens.device)
183193
if args['network']['pretrained']:

0 commit comments

Comments
 (0)