Skip to content

Don't deepcopy accs #948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@
the constant approach will be used.
"""
use_closure(::ADTypes.AbstractADType) = true
use_closure(::ADTypes.AutoEnzyme) = false

Check warning on line 308 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L308

Added line #L308 was not covered by tests

"""
getmodel(f)
Expand Down
19 changes: 8 additions & 11 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@
)
end
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
return VarInfo(nt, deepcopy(vi.accs))
return VarInfo(nt, vi.accs)
end
function typed_varinfo(vi::NTVarInfo)
# This function preserves the behaviour of typed_varinfo(vi) where vi is
Expand Down Expand Up @@ -348,7 +348,7 @@
"""
function untyped_vector_varinfo(vi::UntypedVarInfo)
md = metadata_to_varnamedvector(vi.metadata)
return VarInfo(md, deepcopy(vi.accs))
return VarInfo(md, vi.accs)
end
function untyped_vector_varinfo(
rng::Random.AbstractRNG,
Expand Down Expand Up @@ -391,12 +391,12 @@
"""
function typed_vector_varinfo(vi::NTVarInfo)
md = map(metadata_to_varnamedvector, vi.metadata)
return VarInfo(md, deepcopy(vi.accs))
return VarInfo(md, vi.accs)

Check warning on line 394 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L394

Added line #L394 was not covered by tests
end
function typed_vector_varinfo(vi::UntypedVectorVarInfo)
new_metas = group_by_symbol(vi.metadata)
nt = NamedTuple(new_metas)
return VarInfo(nt, deepcopy(vi.accs))
return VarInfo(nt, vi.accs)
end
function typed_vector_varinfo(
rng::Random.AbstractRNG,
Expand Down Expand Up @@ -447,10 +447,7 @@
# The below line is finicky for type stability. For instance, assigning the eltype to
# convert to into an intermediate variable makes this unstable (constant propagation)
# fails. Take care when editing.
accs = map(
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc),
deepcopy(getaccs(vi)),
)
accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), vi.accs)
return VarInfo(md, accs)
end

Expand Down Expand Up @@ -533,7 +530,7 @@

function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
metadata = subset(varinfo.metadata, vns)
return VarInfo(metadata, deepcopy(varinfo.accs))
return VarInfo(metadata, varinfo.accs)
end

function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
Expand Down Expand Up @@ -622,7 +619,7 @@

function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
return VarInfo(metadata, deepcopy(varinfo_right.accs))
return VarInfo(metadata, varinfo_right.accs)
end

function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
Expand Down Expand Up @@ -1014,7 +1011,7 @@
istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans")

getaccs(vi::VarInfo) = vi.accs
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.metadata, accs)

# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
isempty(vi::VarInfo) = _isempty(vi.metadata)
Expand Down
Loading