Skip to content

Commit 7be9556

Browse files
authored
Normalise optics when constructing VarName; remove extra constructors (#123)
* Normalise optics when constructing VarName; remove deprecated methods * Add another test * Improve changelog * Add clarifying comment
1 parent a39791b commit 7be9556

File tree

8 files changed

+152
-93
lines changed

8 files changed

+152
-93
lines changed

HISTORY.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
## 0.12.0
2+
3+
### VarName constructors
4+
5+
Removed the constructors `VarName(vn, optic)` (this wasn't deprecated, but was dangerous as it would silently discard the existing optic in `vn`), and `VarName(vn, ::Tuple)` (which was deprecated).
6+
7+
Usage of `VarName(vn, optic)` can be directly replaced with `VarName{getsym(vn)}(optic)`.
8+
9+
### Optic normalisation
10+
11+
In the inner constructor of a VarName, its optic is now normalised to ensure that the associativity of ComposedFunction is always the same, and that compositions with identity are removed.
12+
This helps to prevent subtle bugs where VarNames with semantically equal optics are not considered equal.
13+
14+
## 0.11.0
15+
16+
Added the `prefix(vn::VarName, vn_prefix::VarName)` and `unprefix(vn::VarName, vn_prefix::VarName)` functions.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.11.0"
6+
version = "0.12.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AbstractPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,5 @@ include("varname.jl")
2929
include("abstractmodeltrace.jl")
3030
include("abstractprobprog.jl")
3131
include("evaluate.jl")
32-
include("deprecations.jl")
3332

3433
end # module

src/deprecations.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/varname.jl

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using Accessors
2-
using Accessors: ComposedOptic, PropertyLens, IndexLens, DynamicIndexLens
2+
using Accessors: PropertyLens, IndexLens, DynamicIndexLens
33
using JSON: JSON
44

5-
const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedOptic}
5+
# nb. ComposedFunction is the same as Accessors.ComposedOptic
6+
const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction}
67

78
"""
89
VarName{sym}(optic=identity)
@@ -31,10 +32,11 @@ julia> @varname x[:, 1][1+1]
3132
x[:, 1][2]
3233
```
3334
"""
34-
struct VarName{sym,T}
35+
struct VarName{sym,T<:ALLOWED_OPTICS}
3536
optic::T
3637

3738
function VarName{sym}(optic=identity) where {sym}
39+
optic = normalise(optic)
3840
if !is_static_optic(typeof(optic))
3941
throw(
4042
ArgumentError(
@@ -53,42 +55,68 @@ Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and
5355
one or a composition of `DynamicIndexLens`; and undefined otherwise.
5456
"""
5557
is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true
56-
function is_static_optic(::Type{ComposedOptic{LO,LI}}) where {LO,LI}
58+
function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI}
5759
return is_static_optic(LO) && is_static_optic(LI)
5860
end
5961
is_static_optic(::Type{<:DynamicIndexLens}) = false
6062

61-
# A bit of backwards compatibility.
62-
VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2optic(indexing))
63-
6463
"""
65-
VarName(vn::VarName, optic)
66-
VarName(vn::VarName, indexing::Tuple)
64+
normalise(optic)
6765
68-
Return a copy of `vn` with a new index `optic`/`indexing`.
66+
Enforce that compositions of optics are always nested in the same way, in that
67+
a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for
68+
example,
6969
7070
```jldoctest; setup=:(using Accessors)
71-
julia> VarName(@varname(x[1][2:3]), Accessors.IndexLens((2,)))
72-
x[2]
71+
julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a)
72+
(@o _.a.b.c)
7373
74-
julia> VarName(@varname(x[1][2:3]), ((2,),))
75-
x[2]
74+
julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a))
75+
(@o _.c) ∘ ((@o _.a.b))
7676
77-
julia> VarName(@varname(x[1][2:3]))
78-
x
77+
julia> op1 == op2
78+
false
79+
80+
julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c
81+
true
7982
```
80-
"""
81-
VarName(vn::VarName, optic=identity) = VarName{getsym(vn)}(optic)
8283
83-
function VarName(vn::VarName, indexing::Tuple)
84-
return VarName{getsym(vn)}(tupleindex2optic(indexing))
85-
end
84+
This function also removes redundant `identity` optics from ComposedFunctions:
85+
86+
```jldoctest; setup=:(using Accessors)
87+
julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a)
88+
(@o identity(_.a).b)
8689
87-
tupleindex2optic(indexing::Tuple{}) = identity
88-
tupleindex2optic(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) # TODO: rest?
89-
function tupleindex2optic(indexing::Tuple)
90-
return IndexLens(first(indexing)) tupleindex2optic(indexing[2:end])
90+
julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a))
91+
(@o _.b) ∘ ((@o identity(_.a)))
92+
93+
julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b
94+
true
95+
```
96+
"""
97+
function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer}
98+
# `o` is currently (outer ∘ (inner_outer ∘ inner_inner)).
99+
# We want to change this to:
100+
# o = (outer ∘ inner_outer) ∘ inner_inner
101+
inner_inner = o.inner.inner
102+
inner_outer = o.inner.outer
103+
# Recursively call normalise because inner_inner could itself be a
104+
# ComposedFunction
105+
return normalise((o.outer inner_outer) inner_inner)
106+
end
107+
function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer})
108+
# strip outer identity
109+
return normalise(o.outer)
110+
end
111+
function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner})
112+
# strip inner identity
113+
return normalise(o.inner)
91114
end
115+
normalise(o::ComposedFunction) = normalise(o.outer) o.inner
116+
normalise(o::ALLOWED_OPTICS) = o
117+
# These two methods are needed to avoid method ambiguity.
118+
normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner)
119+
normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity
92120

93121
"""
94122
getsym(vn::VarName)
@@ -105,7 +133,7 @@ julia> getsym(@varname(y))
105133
:y
106134
```
107135
"""
108-
getsym(vn::VarName{sym}) where {sym} = sym
136+
getsym(::VarName{sym}) where {sym} = sym
109137

110138
"""
111139
getoptic(vn::VarName)
@@ -154,15 +182,8 @@ function Accessors.set(obj, vn::VarName{sym}, value) where {sym}
154182
end
155183

156184
# Allow compositions with optic.
157-
function Base.:(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym}
158-
vn_optic = getoptic(vn)
159-
if vn_optic == identity
160-
return VarName{sym}(optic)
161-
elseif optic == identity
162-
return vn
163-
else
164-
return VarName{sym}(optic vn_optic)
165-
end
185+
function Base.:(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym}
186+
return VarName{sym}(optic getoptic(vn))
166187
end
167188

168189
Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h)
@@ -299,17 +320,17 @@ subsumes(::typeof(identity), ::typeof(identity)) = true
299320
subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true
300321
subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false
301322

302-
function subsumes(t::ComposedOptic, u::ComposedOptic)
323+
function subsumes(t::ComposedFunction, u::ComposedFunction)
303324
return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
304325
end
305326

306327
# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a
307328
# leaf of the "lens-tree".
308-
subsumes(t::ComposedOptic, u::PropertyLens) = false
329+
subsumes(t::ComposedFunction, u::PropertyLens) = false
309330
# Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is
310331
# subsumed by `t`, since this would mean that the rest of the composition is also subsumed
311332
# by `t`.
312-
subsumes(t::PropertyLens, u::ComposedOptic) = subsumes(t, u.inner)
333+
subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner)
313334

314335
# For `PropertyLens` either they have the same `name` and thus they are indeed the same.
315336
subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true
@@ -321,8 +342,8 @@ subsumes(t::PropertyLens, u::PropertyLens) = false
321342
# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])`
322343
# (but neither did old implementation).
323344
function subsumes(
324-
t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
325-
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
345+
t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}},
346+
u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}},
326347
)
327348
return subsumes_indices(t, u)
328349
end
@@ -415,7 +436,7 @@ The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input.
415436
"""
416437
combine_indices(optic::ALLOWED_OPTICS) = (), optic
417438
combine_indices(optic::IndexLens) = (optic.indices,), nothing
418-
function combine_indices(optic::ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens})
439+
function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens})
419440
indices, next = combine_indices(optic.outer)
420441
return (optic.inner.indices, indices...), next
421442
end
@@ -505,9 +526,9 @@ concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x)
505526
function concretize(I::IndexLens, x)
506527
return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
507528
end
508-
function concretize(I::ComposedOptic, x)
529+
function concretize(I::ComposedFunction, x)
509530
x_inner = I.inner(x) # TODO: get view here
510-
return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x))
531+
return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x))
511532
end
512533

513534
"""
@@ -533,7 +554,7 @@ julia> # The underlying value is concretized, though:
533554
ConcretizedSlice(Base.OneTo(100))
534555
```
535556
"""
536-
concretize(vn::VarName, x) = VarName(vn, concretize(getoptic(vn), x))
557+
concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x))
537558

538559
"""
539560
@varname(expr, concretize=false)
@@ -872,7 +893,7 @@ function optic_to_dict(::PropertyLens{sym}) where {sym}
872893
return Dict("type" => "property", "field" => String(sym))
873894
end
874895
optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices))
875-
function optic_to_dict(c::ComposedOptic)
896+
function optic_to_dict(c::ComposedFunction)
876897
return Dict(
877898
"type" => "composed",
878899
"outer" => optic_to_dict(c.outer),
@@ -1036,32 +1057,34 @@ ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarN
10361057
function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
10371058
return VarName{sym}()
10381059
end
1039-
function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
1040-
return optic_to_vn(o.outer)
1041-
end
10421060
function optic_to_vn(
10431061
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
10441062
) where {Outer,sym}
10451063
return VarName{sym}(o.outer)
10461064
end
1065+
optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o))
10471066
function optic_to_vn(@nospecialize(o))
10481067
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
10491068
throw(ArgumentError(msg))
10501069
end
10511070

10521071
unprefix_optic(o, ::typeof(identity)) = o # Base case
10531072
function unprefix_optic(optic, optic_prefix)
1073+
# Technically `unprefix_optic` only receives optics that were part of
1074+
# VarNames, so the optics should already be normalised (in the inner
1075+
# constructor of the VarName). However I guess it doesn't hurt to do it
1076+
# again to be safe.
1077+
optic = normalise(optic)
1078+
optic_prefix = normalise(optic_prefix)
10541079
# strip one layer of the optic and check for equality
1055-
inner = _inner(_strip_identity(optic))
1056-
inner_prefix = _inner(_strip_identity(optic_prefix))
1080+
inner = _inner(optic)
1081+
inner_prefix = _inner(optic_prefix)
10571082
if inner != inner_prefix
10581083
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
10591084
throw(ArgumentError(msg))
10601085
end
10611086
# recurse
1062-
return unprefix_optic(
1063-
_outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix))
1064-
)
1087+
return unprefix_optic(_outer(optic), _outer(optic_prefix))
10651088
end
10661089

10671090
"""
@@ -1115,16 +1138,6 @@ y[1].x.a
11151138
function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix}
11161139
optic_vn = getoptic(vn)
11171140
optic_prefix = getoptic(prefix)
1118-
# Special case `identity` to avoid having ComposedFunctions with identity
1119-
if optic_vn == identity
1120-
new_inner_optic_vn = PropertyLens{sym_vn}()
1121-
else
1122-
new_inner_optic_vn = optic_vn PropertyLens{sym_vn}()
1123-
end
1124-
if optic_prefix == identity
1125-
new_optic_vn = new_inner_optic_vn
1126-
else
1127-
new_optic_vn = new_inner_optic_vn optic_prefix
1128-
end
1141+
new_optic_vn = optic_vn PropertyLens{sym_vn}() optic_prefix
11291142
return VarName{sym_prefix}(new_optic_vn)
11301143
end

test/deprecations.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/runtests.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
# Activate test environment on older Julia versions
2-
if VERSION < v"1.2"
3-
using Pkg: Pkg
4-
Pkg.activate(@__DIR__)
5-
Pkg.develop(Pkg.PackageSpec(; path=dirname(@__DIR__)))
6-
Pkg.instantiate()
7-
end
8-
91
using AbstractPPL
102
using Documenter
113
using Test
@@ -14,7 +6,6 @@ const GROUP = get(ENV, "GROUP", "All")
146

157
@testset "AbstractPPL.jl" begin
168
if GROUP == "All" || GROUP == "Tests"
17-
include("deprecations.jl")
189
include("varname.jl")
1910
include("abstractprobprog.jl")
2011
end

0 commit comments

Comments
 (0)