@@ -539,3 +539,67 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect
539
539
end
540
540
return evaluation_env, logp
541
541
end
542
+
543
+ """
544
+ _tempered_evaluate!!(model::BUGSModel, flattened_values::AbstractVector; temperature=1.0)
545
+
546
+ Evaluating the model with the given model parameter values, returns updated evaluation environment
547
+ and a NamedTuple of logprior, loglikelihood and tempered logjoint (where tempered logjoint is the logjoint
548
+ whose loglikelihood component scaled by the given temperature).
549
+ """
550
+ function _tempered_evaluate!! (
551
+ model:: BUGSModel , flattened_values:: AbstractVector ; temperature= 1.0
552
+ )
553
+ var_lengths = if model. transformed
554
+ model. transformed_var_lengths
555
+ else
556
+ model. untransformed_var_lengths
557
+ end
558
+
559
+ evaluation_env = deepcopy (model. evaluation_env)
560
+ current_idx = 1
561
+ logprior, loglikelihood = 0.0 , 0.0
562
+ for (i, vn) in enumerate (model. flattened_graph_node_data. sorted_nodes)
563
+ is_stochastic = model. flattened_graph_node_data. is_stochastic_vals[i]
564
+ is_observed = model. flattened_graph_node_data. is_observed_vals[i]
565
+ node_function = model. flattened_graph_node_data. node_function_vals[i]
566
+ loop_vars = model. flattened_graph_node_data. loop_vars_vals[i]
567
+ if ! is_stochastic
568
+ value = node_function (evaluation_env, loop_vars)
569
+ evaluation_env = BangBang. setindex!! (evaluation_env, value, vn)
570
+ else
571
+ dist = node_function (evaluation_env, loop_vars)
572
+ if ! is_observed
573
+ l = var_lengths[vn]
574
+ if model. transformed
575
+ b = Bijectors. bijector (dist)
576
+ b_inv = Bijectors. inverse (b)
577
+ reconstructed_value = reconstruct (
578
+ b_inv,
579
+ dist,
580
+ view (flattened_values, current_idx: (current_idx + l - 1 )),
581
+ )
582
+ value, logjac = Bijectors. with_logabsdet_jacobian (
583
+ b_inv, reconstructed_value
584
+ )
585
+ else
586
+ value = reconstruct (
587
+ dist, view (flattened_values, current_idx: (current_idx + l - 1 ))
588
+ )
589
+ logjac = 0.0
590
+ end
591
+ current_idx += l
592
+ logprior += logpdf (dist, value) + logjac
593
+ evaluation_env = BangBang. setindex!! (evaluation_env, value, vn)
594
+ else
595
+ loglikelihood += logpdf (dist, AbstractPPL. get (evaluation_env, vn))
596
+ end
597
+ end
598
+ end
599
+ return evaluation_env,
600
+ (
601
+ logprior= logprior,
602
+ loglikelihood= loglikelihood,
603
+ tempered_logjoint= logprior + temperature * loglikelihood,
604
+ )
605
+ end
0 commit comments