Skip to content

Commit 3286452

Browse files
committed
[new] hybrid lens and other important updates
1 parent 56ff3eb commit 3286452

File tree

14 files changed

+1388
-498
lines changed

14 files changed

+1388
-498
lines changed

6_hybridlens_design.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
Jointly optimize refractive-diffractive lens with a differentiable ray-wave model. This code can be easily extended to end-to-end refractive-diffractive lens and network design.
3+
4+
Technical Paper:
5+
Xinge Yang, Matheus Souza, Kunyi Wang, Praneeth Chakravarthula, Qiang Fu and Wolfgang Heidrich, "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model," Siggraph Asia 2024.
6+
7+
This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
8+
# The license is only for non-commercial use (commercial licenses can be obtained from authors).
9+
# The material is provided as-is, with no warranties whatsoever.
10+
# If you publish any code, data, or scientific work based on this, please cite our work.
11+
"""
12+
13+
import logging
14+
import os
15+
import random
16+
import string
17+
from datetime import datetime
18+
19+
import torch
20+
import yaml
21+
from torchvision.utils import save_image
22+
from tqdm import tqdm
23+
24+
from deeplens.hybridlens import HybridLens
25+
from deeplens.optics.loss import PSFLoss
26+
from deeplens.utils import set_logger, set_seed
27+
28+
29+
def config():
30+
# ==> Config
31+
args = {"seed": 0, "DEBUG": True}
32+
33+
# ==> Result folder
34+
characters = string.ascii_letters + string.digits
35+
random_string = "".join(random.choice(characters) for i in range(4))
36+
result_dir = (
37+
"./results/"
38+
+ datetime.now().strftime("%m%d-%H%M%S")
39+
+ "-HybridLens"
40+
+ "-"
41+
+ random_string
42+
)
43+
args["result_dir"] = result_dir
44+
os.makedirs(result_dir, exist_ok=True)
45+
print(f"Result folder: {result_dir}")
46+
47+
if args["seed"] is None:
48+
seed = random.randint(0, 100)
49+
args["seed"] = seed
50+
set_seed(args["seed"])
51+
52+
# ==> Log
53+
set_logger(result_dir)
54+
if not args["DEBUG"]:
55+
raise Exception("Add your wandb logging config here.")
56+
57+
# ==> Device
58+
num_gpus = torch.cuda.device_count()
59+
args["num_gpus"] = num_gpus
60+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61+
args["device"] = device
62+
logging.info(f"Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)")
63+
64+
# ==> Save config
65+
with open(f"{result_dir}/config.yml", "w") as f:
66+
yaml.dump(args, f)
67+
68+
with open(f"{result_dir}/6_hybridlens_design.py", "w") as f:
69+
with open("6_hybridlens_design.py", "r") as code:
70+
f.write(code.read())
71+
72+
return args
73+
74+
75+
def main(args):
76+
# Create a hybrid refractive-diffractive lens
77+
lens = HybridLens(filename="./lenses/hybridlens/a489_doe.json")
78+
lens.double()
79+
80+
# PSF optimization loop to focus blue light
81+
optimizer = lens.get_optimizer(doe_lr=0.1, lens_lr=[1e-4, 1e-4, 1e-1, 1e-5])
82+
loss_fn = PSFLoss()
83+
for i in tqdm(range(100 + 1)):
84+
psf = lens.psf(point=[0.0, 0.0, -10000.0], ks=101, wvln=0.489)
85+
86+
optimizer.zero_grad()
87+
loss = loss_fn(psf)
88+
loss.backward()
89+
optimizer.step()
90+
91+
if i % 25 == 0:
92+
lens.write_lens_json(f"{args['result_dir']}/lens_iter{i}.json")
93+
lens.analysis(save_name=f"{args['result_dir']}/lens_iter{i}.png")
94+
save_image(
95+
psf.detach().clone(),
96+
f"{args['result_dir']}/psf_iter{i}.png",
97+
normalize=True,
98+
)
99+
100+
101+
if __name__ == "__main__":
102+
args = config()
103+
main(args)

0 commit comments

Comments
 (0)