Skip to content

Commit f51144c

Browse files
authored
Primitive implementation for serialization (#258)
1 parent 5e5ac1d commit f51144c

File tree

5 files changed

+69
-5
lines changed

5 files changed

+69
-5
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.7.5"
3+
version = "0.8.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -22,6 +22,7 @@ MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
2222
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2323
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2424
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
25+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
2526
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2627
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2728
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -67,6 +68,7 @@ MacroTools = "0.5"
6768
MetaGraphsNext = "0.6, 0.7"
6869
OrderedCollections = "1"
6970
PDMats = "0.10, 0.11"
71+
Serialization = "1.9.0"
7072
SpecialFunctions = "2"
7173
StaticArrays = "1.9"
7274
Statistics = "1.9"

src/JuliaBUGS.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using LogDensityProblems, LogDensityProblemsAD
1212
using MacroTools
1313
using OrderedCollections: OrderedDict
1414
using Random
15+
using Serialization: Serialization
1516
using StaticArrays
1617

1718
import Base: ==, hash, Symbol, size
@@ -172,7 +173,7 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N
172173
values(eval_env),
173174
),
174175
)
175-
return BUGSModel(g, nonmissing_eval_env, initial_params)
176+
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
176177
end
177178

178179
"""

src/model.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ end
4848
The `BUGSModel` object is used for inference and represents the output of compilation. It implements the
4949
[`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface.
5050
"""
51-
struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <:
52-
AbstractBUGSModel
51+
struct BUGSModel{
52+
base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV,data_T
53+
} <: AbstractBUGSModel
5354
" Indicates whether the model parameters are in the transformed space. "
5455
transformed::Bool
5556

@@ -74,6 +75,10 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,
7475

7576
"If not `Nothing`, the model is a conditioned model; otherwise, it's the model returned by `compile`."
7677
base_model::base_model_T
78+
79+
# for serialization, save the original model definition and data
80+
model_def::Expr
81+
data::data_T
7782
end
7883

7984
function Base.show(io::IO, model::BUGSModel)
@@ -137,7 +142,9 @@ variables(model::BUGSModel) = collect(labels(model.g))
137142
function BUGSModel(
138143
g::BUGSGraph,
139144
evaluation_env::NamedTuple,
140-
initial_params::NamedTuple=NamedTuple();
145+
model_def::Expr,
146+
data::NamedTuple,
147+
initial_params::NamedTuple=NamedTuple(),
141148
is_transformed::Bool=true,
142149
)
143150
flattened_graph_node_data = FlattenedGraphNodeData(g)
@@ -199,6 +206,8 @@ function BUGSModel(
199206
flattened_graph_node_data,
200207
g,
201208
nothing,
209+
model_def,
210+
data,
202211
)
203212
end
204213

@@ -220,9 +229,31 @@ function BUGSModel(
220229
FlattenedGraphNodeData(g, sorted_nodes),
221230
g,
222231
isnothing(model.base_model) ? model : model.base_model,
232+
model.model_def,
233+
model.data,
223234
)
224235
end
225236

237+
function Serialization.serialize(s::Serialization.AbstractSerializer, model::BUGSModel)
238+
Serialization.writetag(s.io, Serialization.OBJECT_TAG)
239+
Serialization.serialize(s, typeof(model))
240+
Serialization.serialize(s, model.transformed)
241+
Serialization.serialize(s, model.model_def)
242+
Serialization.serialize(s, model.data)
243+
Serialization.serialize(s, model.evaluation_env)
244+
return nothing
245+
end
246+
247+
function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{<:BUGSModel})
248+
model_def = Serialization.deserialize(s)
249+
data = Serialization.deserialize(s)
250+
evaluation_env = Serialization.deserialize(s)
251+
transformed = Serialization.deserialize(s)
252+
# use evaluation_env as initialization to restore the values
253+
model = compile(model_def, data, evaluation_env)
254+
return settrans(model, transformed)
255+
end
256+
226257
"""
227258
initialize!(model::BUGSModel, initial_params::NamedTuple)
228259

test/model.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,32 @@
1+
@testset "serialization" begin
2+
(; model_def, data) = JuliaBUGS.BUGSExamples.rats
3+
model = compile(model_def, data)
4+
serialize("m.jls", model)
5+
deserialized = deserialize("m.jls")
6+
@testset "test values are correctly restored" begin
7+
for vn in MetaGraphsNext.labels(model.g)
8+
@test isequal(
9+
get(model.evaluation_env, vn), get(deserialized.evaluation_env, vn)
10+
)
11+
end
12+
13+
@test model.transformed == deserialized.transformed
14+
@test model.untransformed_param_length == deserialized.untransformed_param_length
15+
@test model.transformed_param_length == deserialized.transformed_param_length
16+
@test all(
17+
model.untransformed_var_lengths[k] == deserialized.untransformed_var_lengths[k]
18+
for k in keys(model.untransformed_var_lengths)
19+
)
20+
@test all(
21+
model.transformed_var_lengths[k] == deserialized.transformed_var_lengths[k] for
22+
k in keys(model.transformed_var_lengths)
23+
)
24+
@test Set(model.parameters) == Set(deserialized.parameters)
25+
# skip testing g
26+
@test model.model_def === deserialized.model_def
27+
end
28+
end
29+
130
@testset "controlling sampling behavior for conditioned variables" begin
231
model_def = @bugs begin
332
x ~ Normal(0, 1)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using MacroTools
2121
using MCMCChains
2222
using Random
2323
using ReverseDiff
24+
using Serialization
2425

2526
AbstractMCMC.setprogress!(false)
2627

0 commit comments

Comments
 (0)