@@ -16,7 +16,7 @@
 from jax.experimental.jet import jet
 from jax.flatten_util import ravel_pytree
 
-import lib
+from lib.optimizers import exponential_decay
 from lib.ode import odeint
 from physionet_data import init_physionet_data
 
@@ -57,7 +57,6 @@
 rng = jax.random.PRNGKey(seed)
 dirname = parse_args.dirname
 count_nfe = not parse_args.no_count_nfe
-num_blocks = parse_args.num_blocks
 ode_kwargs = {
     "atol": parse_args.atol,
     "rtol": parse_args.rtol
@@ -276,31 +275,31 @@ def init_model(gen_ode_kwargs,
         "b_init": jnp.zeros
     }
 
-    gru_rnn = hk.transform(wrap_module(LatentGRU,
-                                       latent_dim=rec_dim,
-                                       n_units=gru_units,
-                                       **init_kwargs))
+    gru_rnn = hk.without_apply_rng(hk.transform(wrap_module(LatentGRU,
+                                                            latent_dim=rec_dim,
+                                                            n_units=gru_units,
+                                                            **init_kwargs)))
     gru_rnn_params = gru_rnn.init(rng, *initialization_data_["gru_rnn"])
 
     # note: the ODE-RNN version uses double
-    rec_to_gen = hk.transform(wrap_module(lambda: hk.Sequential([
+    rec_to_gen = hk.without_apply_rng(hk.transform(wrap_module(lambda: hk.Sequential([
         lambda x, y: jnp.concatenate((x, y), axis=-1),
         hk.Linear(50, **init_kwargs),
         jnp.tanh,
         hk.Linear(2 * gen_dim, **init_kwargs)
-    ])))
+    ]))))
     rec_to_gen_params = rec_to_gen.init(rng, *initialization_data_["rec_to_gen"])
 
-    gen_dynamics = hk.transform(wrap_module(GenDynamics,
-                                            latent_dim=gen_dim,
-                                            units=dynamics_units,
-                                            layers=gen_layers))
+    gen_dynamics = hk.without_apply_rng(hk.transform(wrap_module(GenDynamics,
+                                                                 latent_dim=gen_dim,
+                                                                 units=dynamics_units,
+                                                                 layers=gen_layers)))
     gen_dynamics_params = gen_dynamics.init(rng, *initialization_data_["gen_dynamics"])
     gen_dynamics_wrap = lambda x, t, params: gen_dynamics.apply(params, x, t)
 
-    gen_to_data = hk.transform(wrap_module(hk.Linear,
-                                           output_size=data_dim,
-                                           **init_kwargs))
+    gen_to_data = hk.without_apply_rng(hk.transform(wrap_module(hk.Linear,
+                                                                output_size=data_dim,
+                                                                **init_kwargs)))
     gen_to_data_params = gen_to_data.init(rng, initialization_data_["gen_to_data"])
 
     init_params = {
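Note on the hunk above: hk.transform(f) returns a pair of pure functions (init, apply), and apply normally expects an RNG key as its second argument. Wrapping each transform in hk.without_apply_rng drops that argument, presumably because these modules are deterministic at apply time, which is why call sites such as gen_dynamics.apply(params, x, t) pass no key. A minimal sketch of the pattern (the toy module below is illustrative only, not from this repo):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    # A deterministic module: apply uses no randomness.
    def toy_fn(x):
        return hk.Linear(4)(x)

    model = hk.without_apply_rng(hk.transform(toy_fn))
    params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))  # init still takes an RNG key
    out = model.apply(params, jnp.ones((2, 3)))                   # apply no longer does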
@@ -310,7 +309,7 @@ def init_model(gen_ode_kwargs,
         "gen_to_data": gen_to_data_params
     }
 
-    def forward(count_nfe_, reg, _method, params, data, data_timesteps, timesteps, mask, num_samples=3):
+    def forward(count_nfe_, params, data, data_timesteps, timesteps, mask, num_samples=3):
         """
         Forward pass of the model.
         y are the latent variables of the recognition model
@@ -343,7 +342,7 @@ def integrate_sample(z0_):
                 dynamics = gen_dynamics_wrap
                 init_fn = lambda x: x
             else:
-                dynamics = augment_dynamics(gen_dynamics_wrap, reg)
+                dynamics = augment_dynamics(gen_dynamics_wrap)
                 init_fn = aug_init
             return jax.vmap(lambda z_, t_: odeint(dynamics, init_fn(z_), t_,
                                                   params["gen_dynamics"], **gen_ode_kwargs),
@@ -391,7 +390,7 @@ def scan_fun(prev_state, xi):
     model = {
         "forward": partial(forward, False),
         "params": init_params,
-        "nfe": lambda *args: partial(forward, count_nfe, reg)(*args)[-1]
+        "nfe": lambda *args: partial(forward, count_nfe)(*args)[-1]
     }
 
     return model
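For reference, "forward": partial(forward, False) simply pre-binds the leading count_nfe_ flag, so callers invoke model["forward"](params, data, ...) without it, while the "nfe" entry re-binds with count_nfe and keeps only the last return value. A stripped-down sketch of the same pattern (the toy forward below is illustrative only):

    from functools import partial

    def forward(count_nfe_, params, data):
        # Toy stand-in: returns a "loss" and an NFE count that is only tracked when requested.
        return params * data, (7 if count_nfe_ else 0)

    model = {
        "forward": partial(forward, False),                      # count_nfe_ fixed to False
        "nfe": lambda *args: partial(forward, True)(*args)[-1],  # keep only the NFE count
    }
    print(model["forward"](2.0, 3.0))  # (6.0, 0)
    print(model["nfe"](2.0, 3.0))      # 7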
@@ -474,10 +473,10 @@ def run():
     forward = lambda *args: model["forward"](*args)[1:]
     grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))
 
-    lr_schedule = lib.optimizers.exponential_decay(step_size=parse_args.lr,
-                                                   decay_steps=1,
-                                                   decay_rate=0.999,
-                                                   lowest=parse_args.lr / 10)
+    lr_schedule = exponential_decay(step_size=parse_args.lr,
+                                    decay_steps=1,
+                                    decay_rate=0.999,
+                                    lowest=parse_args.lr / 10)
     opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
     opt_state = opt_init(model["params"])
 
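The learning-rate schedule is the project's own lib.optimizers helper, now imported directly at the top of the file. Its implementation is not shown in this diff; a plausible reading of the arguments, offered only as a sketch, is an exponential decay clipped from below at lowest:

    # Assumed behaviour of lib.optimizers.exponential_decay -- a sketch, not the repo's source.
    def exponential_decay(step_size, decay_steps, decay_rate, lowest):
        def schedule(itr):
            return max(step_size * decay_rate ** (itr / decay_steps), lowest)
        return schedule

    lr_schedule = exponential_decay(step_size=1e-2, decay_steps=1,
                                    decay_rate=0.999, lowest=1e-3)
    print(lr_schedule(0))     # 0.01
    print(lr_schedule(5000))  # clipped at the lowest value, 1e-3

Such a schedule can be passed directly as step_size to jax.example_libraries.optimizers.adamax, as the hunk above does.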
@@ -574,7 +573,7 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
 
         print(print_str)
 
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_info.txt" % (dirname, reg, lam, num_blocks), "a")
+        outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a")
         outfile.write(print_str + "\n")
         outfile.close()
 
@@ -594,14 +593,14 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
         pickle.dump(fargs, outfile)
         outfile.close()
 
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_iter.txt" % (dirname, reg, lam, num_blocks), "a")
+        outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a")
         outfile.write("Iter: {:04d}\n".format(itr))
         outfile.close()
         meta = {
             "info": info,
             "args": parse_args
         }
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_meta.pickle" % (dirname, reg, lam, num_blocks), "wb")
+        outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb")
         pickle.dump(meta, outfile)
         outfile.close()
 