-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
# -----------------------------------
# Define DIP model
# -----------------------------------
# Network that generates a 2-D velocity model on the grid model_shape = [nz, nx].
model_shape = [nz, nx]
# NOTE(review): vmin/vmax are vp_true's range divided by 1000 (presumably a
# m/s -> km/s conversion). Confirm that DIP_model's output is expressed in the
# same units as vp_init, since the pretraining misfit compares them directly.
DIP_model = DIP_CNN(
    model_shape,
    in_channels=[16, 32, 16],
    vmin=vp_true.min() / 1000,
    vmax=vp_true.max() / 1000,
    device=device,
)
# Move the parameters to `device` (the constructor also receives `device`).
DIP_model.to(device)
# -----------------------------------
# Pretrain DIP model
# -----------------------------------
# Fit the output of DIP_model (called with no input) to vp_init, or restore
# previously pretrained weights from disk.
pretrain = True
load_pretrained = False

# Weight of an optional model penalty added to the pretraining loss.
# 0.0 disables it, leaving the plain L2 misfit used originally.
penalty_weight = 0.0


def model_penalty(v):
    """Return an anisotropic total-variation penalty of the model ``v``.

    Sums the absolute differences between neighbouring cells along the
    last two axes. Replace this with any other differentiable penalty of
    interest (e.g. a smoothness or a prior-model term).

    NOTE(review): assumes the spatial axes are the last two dimensions of
    ``v`` (model_shape = [nz, nx]) -- confirm against DIP_CNN's output.
    """
    dz = v[..., 1:, :] - v[..., :-1, :]
    dx = v[..., :, 1:] - v[..., :, :-1]
    return dz.abs().sum() + dx.abs().sum()


# Single definition of the checkpoint location, shared by load and save.
pretrained_path = os.path.join(
    project_path,
    f"inversion-{layer_num}layer-16-32-16",
    "DIP_model_pretrained.pt",
)

if pretrain:
    if load_pretrained:
        # map_location lets a checkpoint saved on one device (e.g. a GPU)
        # be restored on the current `device`.
        DIP_model.load_state_dict(torch.load(pretrained_path, map_location=device))
    else:
        lr = 0.005
        iteration = 10000
        step_size = 1000
        gamma = 0.5
        optimizer = torch.optim.Adam(DIP_model.parameters(), lr=lr)
        # Multiply the learning rate by `gamma` every `step_size` scheduler steps.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=step_size, gamma=gamma
        )
        # Rebinds vp_init to a tensor on `device`; code after this block
        # sees the tensor, not the original array.
        vp_init = numpy2tensor(vp_init, dtype=dtype).to(device)
        pbar = tqdm(range(iteration + 1))
        for i in pbar:
            vp_nn = DIP_model()
            # L2 norm of the difference between the network output and vp_init.
            misfit = torch.sqrt(torch.sum((vp_nn - vp_init) ** 2))
            loss = misfit
            if penalty_weight:
                loss = misfit + penalty_weight * model_penalty(vp_nn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.set_description(
                f"Pretrain Iter:{i}, Misfit:{misfit.item():.6g}, Loss:{loss.item():.6g}"
            )
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(pretrained_path), exist_ok=True)
        torch.save(DIP_model.state_dict(), pretrained_path)
I'm trying to adapt the CNN module in ADFWI (as used in the code above) to work as a PINN (physics-informed neural network). I need guidance on how to do this effectively. I also need to know how to modify the loss function so that it includes the model penalty (regularization term) of interest.
Metadata
Metadata
Assignees
Labels
No labels