diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 98a56e259..999234f8c 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -67,7 +67,7 @@ Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backi Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp) Base.convert(::Type{<:Dict}, comp::Composite{<:Dict, <:Dict}) = backing(comp) -Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx) +Base.getindex(comp::Composite, idx) = unthunk(getindex(backing(comp), idx)) # for Tuple Base.getproperty(comp::Composite, idx::Int) = unthunk(getproperty(backing(comp), idx)) @@ -82,7 +82,20 @@ end Base.keys(comp::Composite) = keys(backing(comp)) Base.propertynames(comp::Composite) = propertynames(backing(comp)) -Base.iterate(comp::Composite, args...) = iterate(backing(comp), args...) +function Base.iterate(comp::Composite, args...) + out = iterate(backing(comp), args...) + if out isa Nothing + return out + else + element, next_state = out + if comp isa Composite{<:Dict, <:Dict} + return (Pair(element.first, unthunk(element.second)), next_state) + else + return (unthunk(element), next_state) + end + end +end + Base.length(comp::Composite) = length(backing(comp)) Base.eltype(::Type{<:Composite{<:Any, T}}) where T = eltype(T) diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index cb872e9a3..f34d3537f 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -55,6 +55,8 @@ end @test getproperty(Composite{Tuple{Float64,}}(2.0), 1) == 2.0 @test getproperty(Composite{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 @test getproperty(Composite{Tuple{Float64,}}(a=(@thunk 2.0^2),), :a) == 4.0 + @test getindex(Composite{Tuple{Float64,}}(@thunk 2.0^2), 1) == 4.0 + @test getindex(Composite{Tuple{Float64,}}(a=(@thunk 2.0^2),), 1) == 4.0 @test length(Composite{Foo}(x=2.5)) == 1 @test length(Composite{Tuple{Float64,}}(2.0)) == 1 @@ -65,6 +67,9 @@ end # Testing iterate via collect @test collect(Composite{Foo}(x=2.5)) == [2.5] @test collect(Composite{Tuple{Float64,}}(2.0)) == [2.0] + @test collect(Float64, Composite{Tuple{Float64,}}(@thunk 2.0^2)) == [4.0] + @test collect(Float64, Composite{Tuple{Float64,}}(a=(@thunk 2.0^2),)) == [4.0] + @test collect(Pair{String, Float64}, Composite{Dict}(Dict([("a", @thunk 2.0^2)]))) == [Pair("a", 4.0)] end @testset "unset properties" begin