Commit 4b7c5d8

[new] lens glass, zmx io

1 parent 3b62d62 commit 4b7c5d8

File tree

9 files changed: +559 −221 lines


0_hello_deeplens.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 def main():
     lens = GeoLens(filename='./lenses/camera/ef35mm_f2.0.json')
     # lens = GeoLens(filename='./lenses/cellphone/cellphone80deg.json')
+    # lens = GeoLens(filename='./lenses/zemax_double_gaussian.zmx')
     lens.analysis(render=True)

 if __name__=='__main__':
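
The new commented-out line exercises the ZMX I/O added in this commit: a ZEMAX .zmx design can be handed to GeoLens just like a native .json lens file. A minimal sketch of that path, assuming the constructor dispatches on the file extension (the only usage confirmed by this diff is the commented line above):

from deeplens import GeoLens

# Load a ZEMAX design; the path is the one referenced in the diff.
lens = GeoLens(filename='./lenses/zemax_double_gaussian.zmx')
# Run the same analysis entry point used for JSON lenses.
lens.analysis(render=True)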

1_end2end_5lines.py

Lines changed: 2 additions & 3 deletions

@@ -12,6 +12,7 @@
 import os
 import yaml
 import wandb
+import shutil
 import logging
 import random
 import string
@@ -63,9 +64,7 @@ def config():
     with open(f'{result_dir}/config.yml', 'w') as f:
         yaml.dump(args, f)

-    with open(f'{result_dir}/1_end2end_5lines.py', 'w') as f:
-        with open('1_end2end_5lines.py', 'r') as code:
-            f.write(code.read())
+    shutil.copy('1_end2end_5lines.py', f'{result_dir}/1_end2end_5lines.py')

     return args
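
Design note: shutil.copy replaces the manual open/read/write snapshot with a single call that copies the script's contents and permission bits into the result folder; shutil.copy2 would additionally preserve timestamps, but a plain copy is enough for archiving the training script next to its results.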

2_autolens_rms.py

Lines changed: 1 addition & 3 deletions

@@ -16,8 +16,6 @@
 import yaml
 import random
 import string
-
-
 from datetime import datetime
 from tqdm import tqdm
 from transformers import get_cosine_schedule_with_warmup
@@ -35,7 +33,7 @@ def config():
     characters = string.ascii_letters + string.digits
     random_string = ''.join(random.choice(characters) for i in range(4))
     current_time = datetime.now().strftime("%m%d-%H%M%S")
-    exp_name = current_time + '-Auto-Lens-Design-' + random_string
+    exp_name = current_time + '-AutoLens-RMS-' + random_string
     result_dir = f'./results/{exp_name}'
     os.makedirs(result_dir, exist_ok=True)
     args['result_dir'] = result_dir
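
With this rename, a result folder now looks like, e.g., ./results/0612-153000-AutoLens-RMS-x7Qz (the timestamp and four-character suffix here are illustrative), which distinguishes RMS-driven runs from other AutoLens experiments.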

4_tasklens_img_classi.py

Lines changed: 319 additions & 0 deletions

@@ -0,0 +1,319 @@
+"""
+Task-driven lens design for image classification.
+
+We design a lens from scratch with only an image-classification loss. This ensures that no prior knowledge or classical lens design objective (spot size, PSF, ...) enters the task-driven design. By doing this, we can explore an "unseen" lens design space to find a lens that is optimal for the task, because we step entirely outside classical lens design!
+
+Technical Paper:
+Xinge Yang, Yunfeng Nie, Fu Qiang and Wolfgang Heidrich, "Image Quality Is Not All You Want: Task-Driven Lens Design for Image Classification," arXiv preprint, 2023.
+
+This code and data are released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC). In a nutshell:
+# The license is for non-commercial use only (commercial licenses can be obtained from the authors).
+# The material is provided as-is, with no warranties whatsoever.
+# If you publish any code, data, or scientific work based on this, please cite our work.
+"""
+import os
+import yaml
+import wandb
+import logging
+import random
+import string
+import timm
+import cv2 as cv
+from tqdm import tqdm
+from datetime import datetime
+from transformers import get_cosine_schedule_with_warmup
+
+import torch
+from torchvision.utils import save_image
+from torch.utils.data import DataLoader
+import torch.nn as nn
+import torchvision.transforms as transforms
+from torchvision.datasets import ImageFolder
+
+from deeplens import GeoLens
+from deeplens.utils import *
+from deeplens.optics.basics import *
+from deeplens.network.dataset import ImageDataset
+from deeplens.optics.render_psf import render_psf
+
+def config():
+    # ==> Config
+    with open('configs/4_tasklens.yml') as f:
+        args = yaml.load(f, Loader=yaml.FullLoader)
+
+    # ==> Result folder
+    characters = string.ascii_letters + string.digits
+    random_string = ''.join(random.choice(characters) for i in range(4))
+    result_dir = './results/' + datetime.now().strftime("%m%d-%H%M%S") + '-TaskLens' + '-' + random_string
+    args['result_dir'] = result_dir
+    os.makedirs(result_dir, exist_ok=True)
+    print(f'Result folder: {result_dir}')
+
+    if args['seed'] is None:
+        seed = random.randint(0, 100)
+        args['seed'] = seed
+    set_seed(args['seed'])
+
+    # ==> Log
+    set_logger(result_dir)
+    if not args['DEBUG']:
+        raise Exception('Add your wandb logging config here.')
+
+    # ==> Device
+    num_gpus = torch.cuda.device_count()
+    args['num_gpus'] = num_gpus
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args['device'] = device
+    logging.info(f'Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)')
+
+    # ==> Save config
+    with open(f'{result_dir}/config.yml', 'w') as f:
+        yaml.dump(args, f)
+
+    # Archive this training script alongside the results
+    with open(f'{result_dir}/4_tasklens_img_classi.py', 'w') as f:
+        with open('4_tasklens_img_classi.py', 'r') as code:
+            f.write(code.read())
+
+    return args
+
+
+def get_dataset(args):
+    dataset = args['train']['dataset']
+    img_res = args['train']['img_res']
+    bs = args['train']['bs']
+
+    # ==> Transforms
+    train_transform = transforms.Compose([
+        transforms.Resize(img_res),
+        transforms.RandomHorizontalFlip(),
+        transforms.TrivialAugmentWide(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+    ])
+
+    val_transform = transforms.Compose([
+        transforms.Resize(img_res),
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+    ])
+
+    # ==> Dataset
+    if dataset == 'imagenet':
+        train_dataset = ImageFolder(root=args['imagenet_train_dir'], transform=train_transform)
+        val_dataset = ImageFolder(root=args['imagenet_val_dir'], transform=val_transform)
+    elif dataset == 'imagenet_local':
+        train_dataset = ImageFolder(root=args['imagenet_train_dir_local'], transform=train_transform)
+        val_dataset = ImageFolder(root=args['imagenet_val_dir_local'], transform=val_transform)
+    else:
+        raise NotImplementedError
+
+    # ==> Data loader
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, shuffle=False)
+
+    return train_loader, val_loader
+
+
+def get_network(args):
+    if args['network']['model'] == 'resnet50':
+        net = timm.create_model(
+            'resnet50',
+            pretrained=True,
+            num_classes=1000
+        )
+    elif args['network']['model'] == 'swin_transformer':
+        net = timm.create_model(
+            'swin_base_patch4_window7_224_in22k',
+            pretrained=True,
+            num_classes=1000
+        )
+    elif args['network']['model'] == 'mobilenet':
+        net = timm.create_model(
+            'mobilenetv3_large_100',
+            pretrained=True,
+            num_classes=1000
+        )
+    elif args['network']['model'] == 'vit':
+        net = timm.create_model(
+            'vit_large_patch16_224_in21k',
+            pretrained=True,
+            num_classes=1000
+        )
+    else:
+        raise NotImplementedError
+
+    # Parallel
+    net = nn.DataParallel(net, device_ids=range(args['num_gpus']))
+    return net
+
+
+@torch.no_grad()
+def validate(lens, net, epoch, args, val_loader):
+    """ Test image classification accuracy.
+    """
+    # Parameters
+    device = args['device']
+    result_dir = args['result_dir']
+    depth = args['train']['depth']
+    bs = args['train']['bs']
+    ks = args['train']['psf_ks']
+    psf_grid = args['train']['psf_grid']
+    points = lens.point_source_grid(depth=depth, grid=psf_grid*2-1, quater=True).reshape(-1, 3)
+
+    # Scores
+    correct = 0.0
+    total = 0.0
+
+    # Calculate PSFs
+    psf = lens.psf_rgb(points=points, ks=ks, spp=4096)
+
+    # Loop over the validation set in batches
+    for _, (img_org, labels) in tqdm(enumerate(val_loader)):
+        if img_org.shape[0] != bs:
+            continue
+
+        # Get images and labels
+        img_org = img_org.to(device)
+        labels = labels.to(device)
+
+        # Render image with PSF map
+        img_render = render_psf(img_org, psf)
+        img_render = torch.cat(img_render)
+        labels = labels.repeat(psf_grid**2)
+
+        # Forward pass and prediction
+        outputs = net(img_render)
+        _, predicted = torch.max(outputs.data, 1)
+
+        # Update accuracy statistics
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+    # Print validation accuracy
+    acc = correct / total
+    if acc > args['val_acc']:
+        args['val_acc'] = acc
+        logging.info(f'Best epoch is {epoch}, best Val acc is {acc}.')
+        torch.save(net.state_dict(), f'{result_dir}/classi_model_best.pth')
+
+    logging.info('Validation Accuracy: {:.2f}%'.format(100 * acc))
+    if not args['DEBUG']:
+        wandb.log({"classi_acc": acc})
+
+
+def train(args, lens, net):
+    device = args['device']
+    result_dir = args['result_dir']
+    bs = args['train']['bs']
+    ks = args['train']['psf_ks']
+    psf_grid = args['train']['psf_grid']
+    spp = args['train']['spp']
+    depth = args['train']['depth']
+    lens_lrs = [float(i) for i in args['lens']['lr']]
+    args['val_acc'] = 0
+
+    # ==> Dataset
+    train_loader, val_loader = get_dataset(args)
+    batchs = len(train_loader)
+    epochs = args['train']['epochs']
+
+    # ==> Optimizer and scheduler
+    lens_optim = lens.get_optimizer(lr=lens_lrs)
+    lens_sche = get_cosine_schedule_with_warmup(lens_optim, num_warmup_steps=500, num_training_steps=batchs*epochs)
+    # # Uncomment for end-to-end lens-network co-design
+    # net_optim = torch.optim.Adam(net.parameters(), lr=1e-4)
+    # net_sche = get_cosine_schedule_with_warmup(net_optim, num_warmup_steps=500, num_training_steps=batchs*epochs)
+
+    # ==> Loss
+    cri_classi = nn.CrossEntropyLoss()
+
+    # ==> Training
+    logging.info('==> Start training.')
+    points = lens.point_source_grid(depth=depth, grid=psf_grid, quater=True).reshape(-1, 3)
+    for epoch in range(args['train']['epochs'] + 1):
+
+        # =============================
+        # Evaluation
+        # =============================
+        if epoch % 1 == 0 and epoch > 0:
+            net.eval()
+            lens.correct_shape()
+            lens.write_lens_json(f'{result_dir}/epoch{epoch}.json')
+            lens.analysis(f'{result_dir}/epoch{epoch}', render=False)
+            validate(lens, net, epoch, args, val_loader)
+
+        # =============================
+        # Training
+        # =============================
+        net.train()
+
+        # ==> Task-driven lens design: a well-trained network serves as the lens design objective
+        for ii, (img_org, labels) in tqdm(enumerate(train_loader)):
+
+            # Skip incomplete batches
+            if img_org.shape[0] != bs:
+                continue
+
+            # Get images and labels
+            img_org = img_org.to(device)
+            labels = labels.to(device)
+
+            # Option 1: Render image with PSF map
+            psf = lens.psf_rgb(points=points, ks=ks, center=False, spp=spp)  # [N, 3, ks, ks]
+            img_render = []
+            for psf_idx in range(psf.shape[0]):
+                img_render.append(render_psf(img_org, psf[psf_idx, ...]))
+            img_render = torch.cat(img_render)  # [N * B, 3, sensor_res, sensor_res]
+            labels = labels.repeat(psf.shape[0])
+
+            # Option 2: Render image with ray tracing
+            # img_render = lens.render(img_org)
+
+            # Image classification
+            labels_pred = net(img_render)
+
+            # Loss
+            L_classi = cri_classi(labels_pred, labels)
+            L_reg = lens.loss_self_intersec()  # + lens.loss_ray_angle()
+
+            L = L_classi + 0.02 * L_reg
+
+            # Update
+            lens_optim.zero_grad()
+            # net_optim.zero_grad()
+            L.backward()
+            lens_optim.step()
+            # net_optim.step()
+            lens_sche.step()
+            # net_sche.step()
+
+            if not args['DEBUG']:
+                wandb.log({"loss_class": L_classi.detach().item()})
+
+            # Print statistics every 100 batches
+            if ii % 100 == 0 and ii > 0:
+                logging.info('Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'.format(epoch+1, args['train']['epochs'], ii, len(train_loader), L.item()))
+                lens.correct_shape()
+                lens.write_lens_json(f'{result_dir}/epoch{epoch}_batch{ii}.json')
+                lens.analysis(f'{result_dir}/epoch{epoch}_batch{ii}', render=False)
+
+        logging.info(f'Epoch {epoch+1} finishes.')
+
+
+if __name__=='__main__':
+    args = config()
+
+    # Lens
+    lens = GeoLens(filename=args['lens']['path'], sensor_res=args['lens']['sensor_res'], device=args['device'])
+    lens.set_target_fov_fnum(hfov=args['lens']['target_hfov'], fnum=args['lens']['target_fnum'])
+    lens.write_lens_json(f'{args["result_dir"]}/epoch0.json')
+    lens.analysis(f'{args["result_dir"]}/epoch0', render=False, zmx_format=True)
+
+    # Network (frozen: the classifier acts as the design objective)
+    net = get_network(args)
+    for param in net.parameters():
+        param.requires_grad = False
+    net = net.to(args['device'])
+
+    # Task-driven lens design against the frozen network
+    train(args, lens, net)
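
For reference, the script above loads configs/4_tasklens.yml, which is not part of this commit. The sketch below is a hedged reconstruction of the structure that config(), get_dataset(), get_network(), and train() read; every key mirrors an access in the code, but all values are illustrative placeholders rather than the authors' settings:

seed: 0
DEBUG: true                     # true skips the wandb setup branch in config()
imagenet_train_dir: /path/to/imagenet/train
imagenet_val_dir: /path/to/imagenet/val
train:
  dataset: imagenet             # 'imagenet_local' would use the imagenet_*_dir_local keys instead
  img_res: 224
  bs: 64
  psf_ks: 51                    # PSF kernel size in pixels
  psf_grid: 3                   # PSFs are sampled on a grid over the field of view
  spp: 1024                     # rays per PSF sample during training
  depth: -10000.0               # object depth for the point-source grid
  epochs: 10
network:
  model: resnet50               # one of: resnet50, swin_transformer, mobilenet, vit
lens:
  path: ./lenses/camera/ef35mm_f2.0.json
  sensor_res: 224               # assumed square sensor resolution matching img_res
  target_hfov: 0.6
  target_fnum: 2.0
  lr: [1.0e-4, 1.0e-4, 1.0e-1, 1.0e-3]   # per-parameter-group rates for lens.get_optimizer()

Once the config and dataset paths are in place, the script runs as python 4_tasklens_img_classi.py.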
