
Commit 159e600: fix physionet
1 parent 04680eb

File tree: 1 file changed (+24, -25 lines)

latent_ode.py

Lines changed: 24 additions & 25 deletions
@@ -16,7 +16,7 @@
 from jax.experimental.jet import jet
 from jax.flatten_util import ravel_pytree
 
-import lib
+from lib.optimizers import exponential_decay
 from lib.ode import odeint
 from physionet_data import init_physionet_data
 
@@ -57,7 +57,6 @@
 rng = jax.random.PRNGKey(seed)
 dirname = parse_args.dirname
 count_nfe = not parse_args.no_count_nfe
-num_blocks = parse_args.num_blocks
 ode_kwargs = {
     "atol": parse_args.atol,
     "rtol": parse_args.rtol
@@ -276,31 +275,31 @@ def init_model(gen_ode_kwargs,
         "b_init": jnp.zeros
     }
 
-    gru_rnn = hk.transform(wrap_module(LatentGRU,
-                                       latent_dim=rec_dim,
-                                       n_units=gru_units,
-                                       **init_kwargs))
+    gru_rnn = hk.without_apply_rng(hk.transform(wrap_module(LatentGRU,
+                                                            latent_dim=rec_dim,
+                                                            n_units=gru_units,
+                                                            **init_kwargs)))
     gru_rnn_params = gru_rnn.init(rng, *initialization_data_["gru_rnn"])
 
     # note: the ODE-RNN version uses double
-    rec_to_gen = hk.transform(wrap_module(lambda: hk.Sequential([
+    rec_to_gen = hk.without_apply_rng(hk.transform(wrap_module(lambda: hk.Sequential([
         lambda x, y: jnp.concatenate((x, y), axis=-1),
         hk.Linear(50, **init_kwargs),
         jnp.tanh,
         hk.Linear(2 * gen_dim, **init_kwargs)
-    ])))
+    ]))))
     rec_to_gen_params = rec_to_gen.init(rng, *initialization_data_["rec_to_gen"])
 
-    gen_dynamics = hk.transform(wrap_module(GenDynamics,
-                                            latent_dim=gen_dim,
-                                            units=dynamics_units,
-                                            layers=gen_layers))
+    gen_dynamics = hk.without_apply_rng(hk.transform(wrap_module(GenDynamics,
+                                                                 latent_dim=gen_dim,
+                                                                 units=dynamics_units,
+                                                                 layers=gen_layers)))
     gen_dynamics_params = gen_dynamics.init(rng, *initialization_data_["gen_dynamics"])
     gen_dynamics_wrap = lambda x, t, params: gen_dynamics.apply(params, x, t)
 
-    gen_to_data = hk.transform(wrap_module(hk.Linear,
-                                           output_size=data_dim,
-                                           **init_kwargs))
+    gen_to_data = hk.without_apply_rng(hk.transform(wrap_module(hk.Linear,
+                                                                output_size=data_dim,
+                                                                **init_kwargs)))
     gen_to_data_params = gen_to_data.init(rng, initialization_data_["gen_to_data"])
 
     init_params = {
@@ -310,7 +309,7 @@ def init_model(gen_ode_kwargs,
         "gen_to_data": gen_to_data_params
     }
 
-    def forward(count_nfe_, reg, _method, params, data, data_timesteps, timesteps, mask, num_samples=3):
+    def forward(count_nfe_, params, data, data_timesteps, timesteps, mask, num_samples=3):
        """
        Forward pass of the model.
        y are the latent variables of the recognition model
@@ -343,7 +342,7 @@ def integrate_sample(z0_):
            dynamics = gen_dynamics_wrap
            init_fn = lambda x: x
        else:
-            dynamics = augment_dynamics(gen_dynamics_wrap, reg)
+            dynamics = augment_dynamics(gen_dynamics_wrap)
            init_fn = aug_init
        return jax.vmap(lambda z_, t_: odeint(dynamics, init_fn(z_), t_,
                                              params["gen_dynamics"], **gen_ode_kwargs),
@@ -391,7 +390,7 @@ def scan_fun(prev_state, xi):
    model = {
        "forward": partial(forward, False),
        "params": init_params,
-        "nfe": lambda *args: partial(forward, count_nfe, reg)(*args)[-1]
+        "nfe": lambda *args: partial(forward, count_nfe)(*args)[-1]
    }
 
    return model
@@ -474,10 +473,10 @@ def run():
    forward = lambda *args: model["forward"](*args)[1:]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))
 
-    lr_schedule = lib.optimizers.exponential_decay(step_size=parse_args.lr,
-                                                   decay_steps=1,
-                                                   decay_rate=0.999,
-                                                   lowest=parse_args.lr / 10)
+    lr_schedule = exponential_decay(step_size=parse_args.lr,
+                                    decay_steps=1,
+                                    decay_rate=0.999,
+                                    lowest=parse_args.lr / 10)
    opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
    opt_state = opt_init(model["params"])
 
@@ -574,7 +573,7 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
 
        print(print_str)
 
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_info.txt" % (dirname, reg, lam, num_blocks), "a")
+        outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a")
        outfile.write(print_str + "\n")
        outfile.close()
 
@@ -594,14 +593,14 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
        pickle.dump(fargs, outfile)
        outfile.close()
 
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_iter.txt" % (dirname, reg, lam, num_blocks), "a")
+        outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a")
        outfile.write("Iter: {:04d}\n".format(itr))
        outfile.close()
        meta = {
            "info": info,
            "args": parse_args
        }
-        outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_meta.pickle" % (dirname, reg, lam, num_blocks), "wb")
+        outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb")
        pickle.dump(meta, outfile)
        outfile.close()
 
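Note on the recurring change above: each hk.transform is now wrapped in hk.without_apply_rng, which removes the rng argument from the transformed module's apply function. A minimal sketch of the difference, using an illustrative toy module rather than code from this repo:

import haiku as hk
import jax
import jax.numpy as jnp

def net_fn(x):
    # hypothetical stand-in for LatentGRU / GenDynamics / the decoders
    return hk.Linear(8)(x)

with_rng = hk.transform(net_fn)
no_rng = hk.without_apply_rng(hk.transform(net_fn))

x = jnp.ones((1, 4))
params = no_rng.init(jax.random.PRNGKey(0), x)  # init takes an rng either way

y1 = with_rng.apply(params, None, x)  # plain transform: apply(params, rng, *args)
y2 = no_rng.apply(params, x)          # without_apply_rng: apply(params, *args)

This matches the call sites already in the file, e.g. gen_dynamics.apply(params, x, t) inside gen_dynamics_wrap, which passes no rng; the wrapping is appropriate here because these modules are deterministic at apply time.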
