2
2
End2End optical design with only 5 lines of code.
3
3
4
4
Technical Paper:
5
- Yang, Xinge and Fu , Qiang and Heidrich, Wolfgang, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint ( 2023)
5
+ Xinge Yang , Qiang Fu and Wolfgang Heidrich, "Curriculum learning for ab initio deep learned refractive optics," ArXiv preprint 2023.
6
6
7
7
This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
8
8
# The license is only for non-commercial use (commercial licenses can be obtained from authors).
15
15
import logging
16
16
import random
17
17
import string
18
- import argparse
19
18
import numpy as np
20
19
import cv2 as cv
21
20
from tqdm import tqdm
30
29
31
30
def config ():
32
31
# ==> Config
33
- with open ('configs/end2end_5lines .yml' ) as f :
32
+ with open ('configs/1_end2end_5lines .yml' ) as f :
34
33
args = yaml .load (f , Loader = yaml .FullLoader )
35
-
34
+
36
35
# ==> Result folder
37
36
characters = string .ascii_letters + string .digits
38
37
random_string = '' .join (random .choice (characters ) for i in range (4 ))
39
- result_dir = f'./results/' + datetime .now ().strftime ("%m%d-%H%M%S" ) + '-End2End' + '-' + random_string
40
- args ['result_dir' ] = result_dir
38
+ current_time = datetime .now ().strftime ("%m%d-%H%M%S" )
39
+ exp_name = current_time + '-End2End-5-lines-' + random_string
40
+ result_dir = f'./results/{ exp_name } '
41
41
os .makedirs (result_dir , exist_ok = True )
42
- print ( f'Result folder: { result_dir } ' )
42
+ args [ 'result_dir' ] = result_dir
43
43
44
- # ==> Random seed
45
- set_seed (args ['train' ]['seed' ])
44
+ if args ['seed' ] is None :
45
+ seed = random .randint (0 , 100 )
46
+ args ['seed' ] = seed
47
+ set_seed (args ['seed' ])
46
48
47
- # ==> Logger
49
+ # ==> Log
48
50
set_logger (result_dir )
49
- # Log to wandb
51
+ logging . info ( f'EXP: { args [ "EXP_NAME" ] } ' )
50
52
if not args ['DEBUG' ]:
51
- pass
53
+ raise Exception ( 'Add your wandb logging config here.' )
52
54
53
55
# ==> Device
54
56
num_gpus = torch .cuda .device_count ()
55
57
args ['num_gpus' ] = num_gpus
56
58
device = torch .device (f"cuda" if torch .cuda .is_available () else "cpu" )
57
- logging .info (f'Using { num_gpus } { torch .cuda .get_device_name (0 )} GPU(s)' )
58
59
args ['device' ] = device
60
+ logging .info (f'Using { num_gpus } { torch .cuda .get_device_name (0 )} GPU(s)' )
59
61
60
- # ==> Save config
62
+ # ==> Save config and original code
61
63
with open (f'{ result_dir } /config.yml' , 'w' ) as f :
62
64
yaml .dump (args , f )
63
65
66
+ with open (f'{ result_dir } /1_end2end_5lines.py' , 'w' ) as f :
67
+ with open ('1_end2end_5lines.py' , 'r' ) as code :
68
+ f .write (code .read ())
69
+
64
70
return args
65
71
66
72
@@ -80,7 +86,6 @@ def end2end_train(lens, net, args):
80
86
# ==> Network optimizer
81
87
batchs = len (train_loader )
82
88
epochs = args ['train' ]['epochs' ]
83
- warm_up = args ['train' ]['warm_up' ]
84
89
net_optim = torch .optim .AdamW (net .parameters (), lr = args ['network' ]['lr' ], betas = (0.9 , 0.98 ), eps = 1e-08 )
85
90
net_sche = torch .optim .lr_scheduler .CosineAnnealingLR (net_optim , T_max = epochs * batchs , eta_min = 0 , last_epoch = - 1 )
86
91
@@ -93,43 +98,15 @@ def end2end_train(lens, net, args):
93
98
lens_sche = torch .optim .lr_scheduler .CosineAnnealingLR (lens_optim , T_max = epochs * batchs , eta_min = 0 , last_epoch = - 1 )
94
99
95
100
# ==> Criterion
96
- cri_l2 = nn .MSELoss ()
101
+ cri_l2 = nn .L1Loss ()
97
102
98
- # ==> Training
103
+ # ==> Log
99
104
logging .info (f'Start End2End optical design.' )
105
+ lens .write_lens_json (f'{ result_dir } /epoch0.json' )
106
+ lens .analysis (f'{ result_dir } /epoch0' , render = False , zmx_format = True )
107
+
108
+ # ==> Training
100
109
for epoch in range (args ['train' ]['epochs' ] + 1 ):
101
-
102
- # ==> Evaluate
103
- if epoch % 1 == 0 :
104
- net .eval ()
105
- with torch .no_grad ():
106
- # => Save data and simple evaluation
107
- lens .write_lens_json (f'{ result_dir } /epoch{ epoch } .json' )
108
- lens .analysis (f'{ result_dir } /epoch{ epoch } ' , render = False , zmx_format = True )
109
-
110
- torch .save (net .state_dict (), f'{ result_dir } /net_epoch{ epoch } .pth' )
111
-
112
- # => Qualitative evaluation
113
- img1 = cv .cvtColor (cv .imread (f'./datasets/lena.png' ), cv .COLOR_BGR2RGB )
114
- img1 = cv .resize (img1 , args ['train' ]['img_res' ]).astype (np .float32 )
115
- img1 = torch .from_numpy (img1 / 255. ).permute (2 , 0 , 1 ).unsqueeze (0 ).to (device )
116
- img1 = normalize_ImageNet_stats (img1 )
117
-
118
- img1_render = lens .render (img1 )
119
- psnr_render = batch_PSNR (img1 , img1_render )
120
- ssim_render = batch_SSIM (img1 , img1_render )
121
- save_image (de_normalize (img1_render ), f'{ result_dir } /img1_render_epoch{ epoch } .png' )
122
- img1_rec = net (img1_render )
123
- psnr_rec = batch_PSNR (img1 , img1_rec )
124
- ssim_rec = batch_SSIM (img1 , img1_rec )
125
- save_image (de_normalize (img1_rec ), f'{ result_dir } /img1_rec_epoch{ epoch } .png' )
126
-
127
- logging .info (f'Epoch [{ epoch } /{ args ["train" ]["epochs" ]} ], PSNR_render: { psnr_render :.4f} , SSIM_render: { ssim_render :.4f} , PSNR_rec: { psnr_rec :.4f} , SSIM_rec: { ssim_rec :.4f} ' )
128
-
129
- # => Quantitative evaluation
130
- # validate(net, lens, epoch, args, val_loader)
131
-
132
- net .train ()
133
110
134
111
# ==> Train 1 epoch
135
112
for img_org in tqdm (train_loader ):
@@ -163,21 +140,54 @@ def end2end_train(lens, net, args):
163
140
lens_optim .step ()
164
141
165
142
if not args ['DEBUG' ]:
166
- wandb .log ({"loss_class" :L_rec .detach ().item ()})
143
+ wandb .log ({"loss_class" : L_rec .detach ().item ()})
167
144
168
145
net_sche .step ()
169
146
lens_sche .step ()
170
147
171
148
logging .info (f'Epoch{ epoch + 1 } finishs.' )
172
149
173
150
151
+ # ==> Evaluate
152
+ if epoch % 1 == 0 :
153
+ net .eval ()
154
+ with torch .no_grad ():
155
+ # => Save data and simple evaluation
156
+ lens .write_lens_json (f'{ result_dir } /epoch{ epoch } .json' )
157
+ lens .analysis (f'{ result_dir } /epoch{ epoch } ' , render = False , zmx_format = True )
158
+
159
+ torch .save (net .state_dict (), f'{ result_dir } /net_epoch{ epoch } .pth' )
160
+
161
+ # => Qualitative evaluation
162
+ img1 = cv .cvtColor (cv .imread (f'./datasets/cat.png' ), cv .COLOR_BGR2RGB )
163
+ img1 = cv .resize (img1 , args ['train' ]['img_res' ]).astype (np .float32 )
164
+ img1 = torch .from_numpy (img1 / 255. ).permute (2 , 0 , 1 ).unsqueeze (0 ).to (device )
165
+ img1 = normalize_ImageNet (img1 )
166
+
167
+ img1_render = lens .render (img1 )
168
+ psnr_render = batch_PSNR (img1 , img1_render )
169
+ ssim_render = batch_SSIM (img1 , img1_render )
170
+ save_image (denormalize_ImageNet (img1_render ), f'{ result_dir } /img1_render_epoch{ epoch } .png' )
171
+ img1_rec = net (img1_render )
172
+ psnr_rec = batch_PSNR (img1 , img1_rec )
173
+ ssim_rec = batch_SSIM (img1 , img1_rec )
174
+ save_image (denormalize_ImageNet (img1_rec ), f'{ result_dir } /img1_rec_epoch{ epoch } .png' )
175
+
176
+ logging .info (f'Epoch [{ epoch } /{ args ["train" ]["epochs" ]} ], PSNR_render: { psnr_render :.4f} , SSIM_render: { ssim_render :.4f} , PSNR_rec: { psnr_rec :.4f} , SSIM_rec: { ssim_rec :.4f} ' )
177
+
178
+ # => Quantitative evaluation
179
+ # validate(net, lens, epoch, args, val_loader)
180
+
181
+ net .train ()
182
+
183
+
174
184
if __name__ == '__main__' :
175
185
args = config ()
176
186
177
187
# ========================================
178
188
# Line 1: load a lens
179
189
# ========================================
180
- lens = Lensgroup (filename = args ['lens' ]['path' ], sensor_res = args ['train' ]['img_res' ])
190
+ lens = GeoLens (filename = args ['lens' ]['path' ], sensor_res = args ['train' ]['img_res' ])
181
191
net = ImageRestorationNet ()
182
192
net = net .to (lens .device )
183
193
if args ['network' ]['pretrained' ]:
0 commit comments