Skip to content

Commit 308f0e3

Browse files
committed
[new] fix bugs and new features
1 parent b66e629 commit 308f0e3

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

1_end2end_5lines.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,31 +33,34 @@ def config():
3333
with open('configs/end2end_5lines.yml') as f:
3434
args = yaml.load(f, Loader=yaml.FullLoader)
3535

36-
# ==> Device
37-
num_gpus = torch.cuda.device_count()
38-
args['num_gpus'] = num_gpus
39-
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
40-
args['device'] = device
41-
4236
# ==> Result folder
4337
characters = string.ascii_letters + string.digits
4438
random_string = ''.join(random.choice(characters) for i in range(4))
4539
result_dir = f'./results/' + datetime.now().strftime("%m%d-%H%M%S") + '-End2End' + '-' + random_string
4640
args['result_dir'] = result_dir
4741
os.makedirs(result_dir, exist_ok=True)
4842
print(f'Result folder: {result_dir}')
49-
50-
# ==> Logger
51-
set_logger(result_dir)
52-
logging.info(args)
5343

5444
# ==> Random seed
5545
set_seed(args['train']['seed'])
5646

47+
# ==> Logger
48+
set_logger(result_dir)
49+
# Log to wandb
5750
if not args['DEBUG']:
58-
# wandb init
5951
pass
60-
52+
53+
# ==> Device
54+
num_gpus = torch.cuda.device_count()
55+
args['num_gpus'] = num_gpus
56+
device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
57+
logging.info(f'Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)')
58+
args['device'] = device
59+
60+
# ==> Save config
61+
with open(f'{result_dir}/config.yml', 'w') as f:
62+
yaml.dump(args, f)
63+
6164
return args
6265

6366

deeplens/optics.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def post_computation(self):
124124
avg_pupilz, avg_pupilx = self.entrance_pupil()
125125
self.fnum = self.foclen / avg_pupilx / 2
126126

127+
if self.r_last < 8.0:
128+
self.is_cellphone = True
129+
127130

128131
def find_aperture(self):
129132
""" Find aperture by surfaces previous and next materials.
@@ -820,7 +823,8 @@ def render_compute_image(self, img, depth, scale, ray, point_pixel=True, train=T
820823
irr_img += img[...,idx_i, idx_j+1] * w_i * (1-w_j)
821824
irr_img += img[...,idx_i+1, idx_j+1] * (1-w_i) * (1-w_j)
822825

823-
I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / (torch.sum(ray.ra, -3) + 1e-6)
826+
I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / (torch.sum(ray.ra, -3) + 1e-6) # w/ vignetting correction
827+
# I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / ray.ra.shape[-3] # w/o vignetting correction
824828

825829
# ====> Add sensor noise
826830
if noise > 0:
@@ -1469,7 +1473,7 @@ def set_target_fov_fnum(self, hfov, fnum, imgh=None):
14691473

14701474
self.foclen = self.calc_efl()
14711475
aper_r = self.foclen / fnum / 2
1472-
self.surfaces[self.aper_idx].r = aper_r
1476+
self.surfaces[self.aper_idx].r = float(aper_r)
14731477

14741478

14751479
# ---------------------------
@@ -2216,11 +2220,10 @@ def loss_ray_angle(self, target=0.7, depth=DEPTH):
22162220
def loss_reg(self):
22172221
""" An empirical regularization loss for lens design.
22182222
"""
2219-
# For spherical lens design
2220-
loss_reg = 0.1 * self.loss_infocus() + self.loss_self_intersec(dist_bound=0.5, thickness_bound=0.5)
2221-
2222-
# For cellphone lens design, use 0.01 * loss_reg
2223-
# loss_reg = 0.1 * self.loss_infocus() + self.loss_ray_angle() + (self.loss_self_intersec() + self.loss_last_surf()) #+ self.loss_surface()
2223+
if self.is_cellphone:
2224+
loss_reg = 0.1 * self.loss_infocus() + self.loss_ray_angle() + (self.loss_self_intersec() + self.loss_last_surf())
2225+
else:
2226+
loss_reg = 0.1 * self.loss_infocus() + self.loss_self_intersec(dist_bound=0.5, thickness_bound=0.5)
22242227

22252228
return loss_reg
22262229

deeplens/surfaces.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,22 @@ def surface_sample(self, N=1000):
275275
o2 = torch.stack((x2,y2,z2), 1).to(self.device)
276276
return o2
277277

278+
def surface(self, x, y):
279+
""" Calculate z coordinate of the surface at (x, y) with offset.
280+
281+
This function is used in lens setup plotting.
282+
"""
283+
x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
284+
y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
285+
return self.sag(x, y)
286+
278287
def surface_with_offset(self, x, y):
279288
""" Calculate z coordinate of the surface at (x, y) with offset.
280289
281290
This function is used in lens setup plotting.
282291
"""
283-
x = torch.tensor(x).to(self.device) if type(x) is float else x
284-
y = torch.tensor(y).to(self.device) if type(y) is float else y
292+
x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
293+
y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
285294
return self.sag(x, y) + self.d
286295

287296
def max_height(self):

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'transformers',
2424
'lpips',
2525
'einops',
26+
'timm',
2627
],
2728
license='Creative Commons Attribution-NonCommercial 4.0 International License',
2829
classifiers=[

0 commit comments

Comments
 (0)