1
+ """
2
+ Task-driven lens design for image classification.
3
+
4
+ We design a lens with from scratch with only image-classification loss. This makes sure no prior knowledge or classical lens design objective (spot size, PSF...) is used in the task-driven lens design. By doing this, we can explore "unseen" lens design space to find a lens that is optimal for a task, because we totally get rid of classical lens design!
5
+
6
+ Technical Paper:
7
+ Xinge Yang, Yunfeng Nie, Fu Qiang and Wolfgang Heidrich, "Image Quality Is Not All You Want: Task-Driven Lens Design for Image Classification" Arxiv preprint 2023.
8
+
9
+ This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
10
+ # The license is only for non-commercial use (commercial licenses can be obtained from authors).
11
+ # The material is provided as-is, with no warranties whatsoever.
12
+ # If you publish any code, data, or scientific work based on this, please cite our work.
13
+ """
14
+ import os
15
+ import yaml
16
+ import wandb
17
+ import logging
18
+ import random
19
+ import string
20
+ import timm
21
+ import cv2 as cv
22
+ from tqdm import tqdm
23
+ from datetime import datetime
24
+ from transformers import get_cosine_schedule_with_warmup
25
+
26
+ import torch
27
+ from torchvision .utils import save_image
28
+ from torch .utils .data import DataLoader
29
+ import torch .nn as nn
30
+ import torchvision .transforms as transforms
31
+ from torchvision .datasets import ImageFolder
32
+
33
+ from deeplens import GeoLens
34
+ from deeplens .utils import *
35
+ from deeplens .optics .basics import *
36
+ from deeplens .network .dataset import ImageDataset
37
+ from deeplens .optics .render_psf import render_psf
38
+
39
+ def config ():
40
+ # ==> Config
41
+ with open ('configs/4_tasklens.yml' ) as f :
42
+ args = yaml .load (f , Loader = yaml .FullLoader )
43
+
44
+ # ==> Result folder
45
+ characters = string .ascii_letters + string .digits
46
+ random_string = '' .join (random .choice (characters ) for i in range (4 ))
47
+ result_dir = f'./results/' + datetime .now ().strftime ("%m%d-%H%M%S" ) + '-TaskLens' + '-' + random_string
48
+ args ['result_dir' ] = result_dir
49
+ os .makedirs (result_dir , exist_ok = True )
50
+ print (f'Result folder: { result_dir } ' )
51
+
52
+ if args ['seed' ] is None :
53
+ seed = random .randint (0 , 100 )
54
+ args ['seed' ] = seed
55
+ set_seed (args ['seed' ])
56
+
57
+ # ==> Log
58
+ set_logger (result_dir )
59
+ if not args ['DEBUG' ]:
60
+ raise Exception ('Add your wandb logging config here.' )
61
+
62
+ # ==> Device
63
+ num_gpus = torch .cuda .device_count ()
64
+ args ['num_gpus' ] = num_gpus
65
+ device = torch .device (f"cuda" if torch .cuda .is_available () else "cpu" )
66
+ args ['device' ] = device
67
+ logging .info (f'Using { num_gpus } { torch .cuda .get_device_name (0 )} GPU(s)' )
68
+
69
+ # ==> Save config
70
+ with open (f'{ result_dir } /config.yml' , 'w' ) as f :
71
+ yaml .dump (args , f )
72
+
73
+ with open (f'{ result_dir } /4_tasklens_design.py' , 'w' ) as f :
74
+ with open ('4_tasklens_design.py' , 'r' ) as code :
75
+ f .write (code .read ())
76
+
77
+ return args
78
+
79
+
80
+ def get_dataset (args ):
81
+ dataset = args ['train' ]['dataset' ]
82
+ img_res = args ['train' ]['img_res' ]
83
+ bs = args ['train' ]['bs' ]
84
+
85
+ # ==> Transforms
86
+ train_transform = transforms .Compose ([
87
+ transforms .Resize (img_res ),
88
+ transforms .RandomHorizontalFlip (),
89
+ transforms .TrivialAugmentWide (),
90
+ transforms .ToTensor (),
91
+ transforms .Normalize ((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 ))
92
+ ])
93
+
94
+ val_transform = transforms .Compose ([
95
+ transforms .Resize (img_res ),
96
+ transforms .ToTensor (),
97
+ transforms .Normalize ((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 ))
98
+ ])
99
+
100
+ # ==> Datset
101
+ if dataset == 'imagenet' :
102
+ train_dataset = ImageFolder (root = args ['imagenet_train_dir' ], transform = train_transform )
103
+ val_dataset = ImageFolder (root = args ['imagenet_val_dir' ], transform = val_transform )
104
+ elif dataset == 'imagenet_local' :
105
+ train_dataset = ImageFolder (root = args ['imagenet_train_dir_local' ], transform = train_transform )
106
+ val_dataset = ImageFolder (root = args ['imagenet_val_dir_local' ], transform = val_transform )
107
+ else :
108
+ raise NotImplementedError
109
+
110
+ # ==> Data loader
111
+ train_loader = torch .utils .data .DataLoader (train_dataset , batch_size = bs , shuffle = True )
112
+ val_loader = torch .utils .data .DataLoader (val_dataset , batch_size = bs , shuffle = False )
113
+
114
+ return train_loader , val_loader
115
+
116
+
117
+ def get_network (args ):
118
+ if args ['network' ]['model' ] == 'resnet50' :
119
+ net = timm .create_model (
120
+ 'resnet50' ,
121
+ pretrained = True ,
122
+ num_classes = 1000
123
+ )
124
+ elif args ['network' ]['model' ] == 'swin_transformer' :
125
+ net = timm .create_model (
126
+ 'swin_base_patch4_window7_224_in22k' ,
127
+ pretrained = True ,
128
+ num_classes = 1000
129
+ )
130
+ elif args ['network' ]['model' ] == 'mobilenet' :
131
+ net = timm .create_model (
132
+ 'mobilenetv3_large_100' ,
133
+ pretrained = True ,
134
+ num_classes = 1000
135
+ )
136
+ elif args ['network' ]['model' ] == 'vit' :
137
+ net = timm .create_model (
138
+ 'vit_large_patch16_224_in21k' ,
139
+ pretrained = True ,
140
+ num_classes = 1000
141
+ )
142
+ else :
143
+ raise NotImplementedError
144
+
145
+ # Parallel
146
+ net = nn .DataParallel (net , device_ids = range (args ['num_gpus' ]))
147
+ return net
148
+
149
+
150
+ @torch .no_grad ()
151
+ def validate (lens , net , epoch , args , val_loader ):
152
+ """ Test image classification accuracy.
153
+ """
154
+ # Parameters
155
+ device = args ['device' ]
156
+ result_dir = args ['result_dir' ]
157
+ depth = args ['train' ]['depth' ]
158
+ bs = args ['train' ]['bs' ]
159
+ ks = args ['train' ]['psf_ks' ]
160
+ psf_grid = args ['train' ]['psf_grid' ]
161
+ points = lens .point_source_grid (depth = depth , grid = psf_grid * 2 - 1 , quater = True ).reshape (- 1 , 3 )
162
+
163
+ # Scores
164
+ correct = 0.0
165
+ total = 0.0
166
+
167
+ # Calculate PSFs
168
+ psf = lens .psf_rgb (points = points , ks = ks , spp = 4096 )
169
+
170
+ # Loop over the validation set in batches
171
+ for _ , (img_org , labels ) in tqdm (enumerate (val_loader )):
172
+ if img_org .shape [0 ] != bs :
173
+ continue
174
+
175
+ # Get images and labels
176
+ img_org = img_org .to (device )
177
+ labels = labels .to (device )
178
+
179
+ # Render image with PSF map
180
+ img_render = render_psf (img_org , psf )
181
+ img_render = torch .cat (img_render )
182
+ labels = labels .repeat (psf_grid ** 2 )
183
+
184
+ # Forward pass and prediction
185
+ outputs = net (img_render )
186
+ _ , predicted = torch .max (outputs .data , 1 )
187
+
188
+ # Update accuracy statistics
189
+ total += labels .size (0 )
190
+ correct += (predicted == labels ).sum ().item ()
191
+
192
+ # Print validation accuracy
193
+ acc = correct / total
194
+ if acc > args ['val_acc' ]:
195
+ args ['val_acc' ] = acc
196
+ logging .info (f'Best epoch is { epoch } , best Val acc is { acc } .' )
197
+ torch .save (net .state_dict (), f'{ result_dir } /classi_model_best.pth' )
198
+
199
+ logging .info ('Validation Accuracy: {:.2f}%' .format (100 * acc ))
200
+ if not args ['DEBUG' ]:
201
+ wandb .log ({"classi_acc" :acc })
202
+
203
+
204
+ def train (args , lens , net ):
205
+ device = args ['device' ]
206
+ result_dir = args ['result_dir' ]
207
+ bs = args ['train' ]['bs' ]
208
+ ks = args ['train' ]['psf_ks' ]
209
+ psf_grid = args ['train' ]['psf_grid' ]
210
+ spp = args ['train' ]['spp' ]
211
+ depth = args ['train' ]['depth' ]
212
+ lens_lrs = [float (i ) for i in args ['lens' ]['lr' ]]
213
+ args ['val_acc' ] = 0
214
+
215
+ # ==> Dataset
216
+ train_loader , val_loader = get_dataset (args )
217
+ batchs = len (train_loader )
218
+ epochs = args ['train' ]['epochs' ]
219
+
220
+ # ==> Optimizer and scheduler
221
+ lens_optim = lens .get_optimizer (lr = lens_lrs )
222
+ lens_sche = get_cosine_schedule_with_warmup (lens_optim , num_warmup_steps = 500 , num_training_steps = batchs * epochs )
223
+ # # Uncomment for End-to-End lens-network co-design
224
+ # net_optim = torch.optim.Adam(net.parameters(), lr=1e-4)
225
+ # net_sche = get_cosine_schedule_with_warmup(net_optim, num_warmup_steps=500, num_training_steps=batchs*epochs)
226
+
227
+ # ==> Loss
228
+ cri_classi = nn .CrossEntropyLoss ()
229
+
230
+ # ==> Training
231
+ logging .info (f'==> Start training.' )
232
+ points = lens .point_source_grid (depth = depth , grid = psf_grid , quater = True ).reshape (- 1 , 3 )
233
+ for epoch in range (args ['train' ]['epochs' ] + 1 ):
234
+
235
+ # =============================
236
+ # Evaluation
237
+ # =============================
238
+ if epoch % 1 == 0 and epoch > 0 :
239
+ net .eval ()
240
+ lens .correct_shape ()
241
+ lens .write_lens_json (f'{ result_dir } /epoch{ epoch } .json' )
242
+ lens .analysis (f'{ result_dir } /epoch{ epoch } ' , render = False )
243
+ validate (lens , net , epoch , args , val_loader )
244
+
245
+ # =============================
246
+ # Training
247
+ # =============================
248
+ net .train ()
249
+
250
+ # ==> Task-driven lens design: a well-trained network serves as lens design objective
251
+ for ii , (img_org , labels ) in tqdm (enumerate (train_loader )):
252
+
253
+ # Continue is wrong batch size
254
+ if img_org .shape [0 ] != bs :
255
+ continue
256
+
257
+ # Get images and labels
258
+ img_org = img_org .to (device )
259
+ labels = labels .to (device )
260
+
261
+ # Option 1: Render image with PSF map
262
+ psf = lens .psf_rgb (points = points , ks = ks , center = False , spp = spp ) # [N, 3, ks, ks]
263
+ img_render = []
264
+ for psf_idx in range (psf .shape [0 ]):
265
+ img_render .append (render_psf (img_org , psf [psf_idx , ...]))
266
+ img_render = torch .cat (img_render ) # [N * B, 3, sensor_res, sensor_res]
267
+ labels = labels .repeat (psf .shape [0 ])
268
+
269
+ # Option 2: Render image with ray tracing
270
+ # img_render = lens.render(img_org)
271
+
272
+ # Image classification
273
+ labels_pred = net (img_render )
274
+
275
+ # Loss
276
+ L_classi = cri_classi (labels_pred , labels )
277
+ L_reg = lens .loss_self_intersec () #+ lens.loss_ray_angle()
278
+
279
+ L = L_classi + 0.02 * L_reg
280
+
281
+ # Update
282
+ lens_optim .zero_grad ()
283
+ # net_optim.zero_grad()
284
+ L .backward ()
285
+ lens_optim .step ()
286
+ # net_optim.step()
287
+ lens_sche .step ()
288
+ # net_sche.step()
289
+
290
+ if not args ['DEBUG' ]:
291
+ wandb .log ({"loss_class" : L_classi .detach ().item ()})
292
+
293
+ # Print statistics every 1000 batches
294
+ if ii % 100 == 0 and ii > 0 :
295
+ logging .info ('Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}' .format (epoch + 1 , args ['train' ]['epochs' ], ii , len (train_loader ), L .item ()))
296
+ lens .correct_shape ()
297
+ lens .write_lens_json (f'{ result_dir } /epoch{ epoch } _batch{ ii } .json' )
298
+ lens .analysis (f'{ result_dir } /epoch{ epoch } _batch{ ii } ' , render = False )
299
+
300
+ logging .info (f'Epoch{ epoch + 1 } finishs.' )
301
+
302
+
303
+ if __name__ == '__main__' :
304
+ args = config ()
305
+
306
+ # Lens
307
+ lens = GeoLens (filename = args ['lens' ]['path' ], sensor_res = args ['lens' ]['sensor_res' ], device = args ['device' ])
308
+ lens .set_target_fov_fnum (hfov = args ['lens' ]['target_hfov' ], fnum = args ['lens' ]['target_fnum' ])
309
+ lens .write_lens_json (f'{ args ["result_dir" ]} /epoch0.json' )
310
+ lens .analysis (f'{ args ["result_dir" ]} /epoch0' , render = False , zmx_format = True )
311
+
312
+ # Network
313
+ net = get_network (args )
314
+ for param in net .parameters ():
315
+ param .requires_grad = False
316
+ net = net .to (args ['device' ])
317
+
318
+ # End-to-end lens-network co-design
319
+ train (args , lens , net )
0 commit comments