Skip to content

Commit c094f3a

Browse files
authored
Implement varname prefix / unprefix (#119)
* Implement prefix / unprefix * Document * Clean up code a bit * Add another test * Remove unused method * Bump version
1 parent 700a70b commit c094f3a

File tree

5 files changed

+218
-2
lines changed

5 files changed

+218
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.10.1"
6+
version = "0.11.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ vsym
1414
@vsym
1515
```
1616

17+
## VarName prefixing and unprefixing
18+
19+
```@docs
20+
prefix
21+
unprefix
22+
```
23+
1724
## VarName serialisation
1825

1926
```@docs

src/AbstractPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ export VarName,
1414
index_to_dict,
1515
dict_to_index,
1616
varname_to_string,
17-
string_to_varname
17+
string_to_varname,
18+
prefix,
19+
unprefix
1820

1921
# Abstract model functions
2022
export AbstractProbabilisticProgram,

src/varname.jl

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,8 @@ function vsym(expr::Expr)
766766
end
767767
end
768768

769+
### Serialisation to JSON / string
770+
769771
# String constants for each index type that we support serialisation /
770772
# deserialisation of
771773
const _BASE_INTEGER_TYPE = "Base.Integer"
@@ -936,3 +938,193 @@ Convert a string representation of a `VarName` back to a `VarName`. The string
936938
should have been generated by `varname_to_string`.
937939
"""
938940
string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str))
941+
942+
### Prefixing and unprefixing
943+
944+
"""
945+
_strip_identity(optic)
946+
947+
Remove identity lenses from composed optics.
948+
"""
949+
_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity
950+
function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
951+
return _strip_identity(o.outer)
952+
end
953+
function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner}
954+
return _strip_identity(o.inner)
955+
end
956+
_strip_identity(o::Base.ComposedFunction) = o
957+
_strip_identity(o::Accessors.PropertyLens) = o
958+
_strip_identity(o::Accessors.IndexLens) = o
959+
_strip_identity(o::typeof(identity)) = o
960+
961+
"""
962+
_inner(optic)
963+
964+
Get the innermost (non-identity) layer of an optic.
965+
966+
```jldoctest; setup=:(using Accessors)
967+
julia> AbstractPPL._inner(Accessors.@o _.a.b.c)
968+
(@o _.a)
969+
970+
julia> AbstractPPL._inner(Accessors.@o _[1][2][3])
971+
(@o _[1])
972+
973+
julia> AbstractPPL._inner(Accessors.@o _)
974+
identity (generic function with 1 method)
975+
```
976+
"""
977+
_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
978+
_inner(o::Accessors.PropertyLens) = o
979+
_inner(o::Accessors.IndexLens) = o
980+
_inner(o::typeof(identity)) = o
981+
982+
"""
983+
_outer(optic)
984+
985+
Get the outer layer of an optic.
986+
987+
```jldoctest; setup=:(using Accessors)
988+
julia> AbstractPPL._outer(Accessors.@o _.a.b.c)
989+
(@o _.b.c)
990+
991+
julia> AbstractPPL._outer(Accessors.@o _[1][2][3])
992+
(@o _[2][3])
993+
994+
julia> AbstractPPL._outer(Accessors.@o _.a)
995+
identity (generic function with 1 method)
996+
997+
julia> AbstractPPL._outer(Accessors.@o _[1])
998+
identity (generic function with 1 method)
999+
1000+
julia> AbstractPPL._outer(Accessors.@o _)
1001+
identity (generic function with 1 method)
1002+
```
1003+
"""
1004+
_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer
1005+
_outer(::Accessors.PropertyLens) = identity
1006+
_outer(::Accessors.IndexLens) = identity
1007+
_outer(::typeof(identity)) = identity
1008+
1009+
"""
1010+
optic_to_vn(optic)
1011+
1012+
Convert an Accessors optic to a VarName. This is best explained through
1013+
examples.
1014+
1015+
```jldoctest; setup=:(using Accessors)
1016+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a)
1017+
a
1018+
1019+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b)
1020+
a.b
1021+
1022+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1])
1023+
a[1]
1024+
```
1025+
1026+
The outermost layer of the optic (technically, what Accessors.jl calls the
1027+
'innermost') must be a `PropertyLens`, or else it will fail. This is because a
1028+
VarName needs to have a symbol.
1029+
1030+
```jldoctest; setup=:(using Accessors)
1031+
julia> AbstractPPL.optic_to_vn(Accessors.@o _[1])
1032+
ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName
1033+
[...]
1034+
```
1035+
"""
1036+
function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
1037+
return VarName{sym}()
1038+
end
1039+
function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
1040+
return optic_to_vn(o.outer)
1041+
end
1042+
function optic_to_vn(
1043+
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
1044+
) where {Outer,sym}
1045+
return VarName{sym}(o.outer)
1046+
end
1047+
function optic_to_vn(@nospecialize(o))
1048+
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
1049+
throw(ArgumentError(msg))
1050+
end
1051+
1052+
unprefix_optic(o, ::typeof(identity)) = o # Base case
1053+
function unprefix_optic(optic, optic_prefix)
1054+
# strip one layer of the optic and check for equality
1055+
inner = _inner(_strip_identity(optic))
1056+
inner_prefix = _inner(_strip_identity(optic_prefix))
1057+
if inner != inner_prefix
1058+
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
1059+
throw(ArgumentError(msg))
1060+
end
1061+
# recurse
1062+
return unprefix_optic(
1063+
_outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix))
1064+
)
1065+
end
1066+
1067+
"""
1068+
unprefix(vn::VarName, prefix::VarName)
1069+
1070+
Remove a prefix from a VarName.
1071+
1072+
```jldoctest
1073+
julia> AbstractPPL.unprefix(@varname(y.x), @varname(y))
1074+
x
1075+
1076+
julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y))
1077+
x.a
1078+
1079+
julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1]))
1080+
x
1081+
1082+
julia> AbstractPPL.unprefix(@varname(y), @varname(n))
1083+
ERROR: ArgumentError: could not remove prefix n from VarName y
1084+
[...]
1085+
```
1086+
"""
1087+
function unprefix(
1088+
vn::VarName{sym_vn}, prefix::VarName{sym_prefix}
1089+
) where {sym_vn,sym_prefix}
1090+
if sym_vn != sym_prefix
1091+
msg = "could not remove prefix $(prefix) from VarName $(vn)"
1092+
throw(ArgumentError(msg))
1093+
end
1094+
optic_vn = getoptic(vn)
1095+
optic_prefix = getoptic(prefix)
1096+
return optic_to_vn(unprefix_optic(optic_vn, optic_prefix))
1097+
end
1098+
1099+
"""
1100+
prefix(vn::VarName, prefix::VarName)
1101+
1102+
Add a prefix to a VarName.
1103+
1104+
```jldoctest
1105+
julia> AbstractPPL.prefix(@varname(x), @varname(y))
1106+
y.x
1107+
1108+
julia> AbstractPPL.prefix(@varname(x.a), @varname(y))
1109+
y.x.a
1110+
1111+
julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1]))
1112+
y[1].x.a
1113+
```
1114+
"""
1115+
function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix}
1116+
optic_vn = getoptic(vn)
1117+
optic_prefix = getoptic(prefix)
1118+
# Special case `identity` to avoid having ComposedFunctions with identity
1119+
if optic_vn == identity
1120+
new_inner_optic_vn = PropertyLens{sym_vn}()
1121+
else
1122+
new_inner_optic_vn = optic_vn PropertyLens{sym_vn}()
1123+
end
1124+
if optic_prefix == identity
1125+
new_optic_vn = new_inner_optic_vn
1126+
else
1127+
new_optic_vn = new_inner_optic_vn optic_prefix
1128+
end
1129+
return VarName{sym_prefix}(new_optic_vn)
1130+
end

test/varname.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,19 @@ end
233233
# Serialisation should now work
234234
@test string_to_varname(varname_to_string(vn)) == vn
235235
end
236+
237+
@testset "prefix and unprefix" begin
238+
@test prefix(@varname(y), @varname(x)) == @varname(x.y)
239+
@test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y)
240+
@test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y)
241+
@test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1])
242+
@test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a)
243+
244+
@test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1])
245+
@test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y)
246+
@test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y)
247+
@test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a)
248+
@test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n))
249+
@test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1]))
250+
end
236251
end

0 commit comments

Comments
 (0)