diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index c5586f80f..b0c4a9f66 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -305,6 +305,7 @@ closure approach will be used. By default, this function returns `false`, i.e. the constant approach will be used. """ use_closure(::ADTypes.AbstractADType) = true +use_closure(::ADTypes.AutoEnzyme) = false """ getmodel(f) diff --git a/src/varinfo.jl b/src/varinfo.jl index 20986d1a4..51824844a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -287,7 +287,7 @@ function typed_varinfo(vi::UntypedVarInfo) ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, vi.accs) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -348,7 +348,7 @@ single `VarNamedVector` as its metadata field. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, vi.accs) end function untyped_vector_varinfo( rng::Random.AbstractRNG, @@ -391,12 +391,12 @@ NamedTuple of `VarNamedVector`s as its metadata field. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - return VarInfo(md, deepcopy(vi.accs)) + return VarInfo(md, vi.accs) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) nt = NamedTuple(new_metas) - return VarInfo(nt, deepcopy(vi.accs)) + return VarInfo(nt, vi.accs) end function typed_vector_varinfo( rng::Random.AbstractRNG, @@ -447,10 +447,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # The below line is finicky for type stability. For instance, assigning the eltype to # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. - accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), - deepcopy(getaccs(vi)), - ) + accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), vi.accs) return VarInfo(md, accs) end @@ -533,7 +530,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.accs)) + return VarInfo(metadata, varinfo.accs) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -622,7 +619,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, deepcopy(varinfo_right.accs)) + return VarInfo(metadata, varinfo_right.accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -1014,7 +1011,7 @@ istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") getaccs(vi::VarInfo) = vi.accs -setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.metadata, accs) # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata)