Skip to content

Commit 027ccdc

Browse files
committed
Change length of ThreadSafeVarInfo logps to maxthreadid()
1 parent cac8f9c commit 027ccdc

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Added compatibility with DifferentiationInterface.jl 0.7, and also with JET.jl 0
66

77
The JET compatibility entry should only affect you if you are using DynamicPPL on the Julia 1.12 pre-release.
88

9+
The array of log probabilities stored in `ThreadSafeVarInfo` is now of length `Threads.maxthreadid()`, rather than `Threads.nthreads()`.
10+
911
## 0.36.3
1012

1113
Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo.

src/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
99
logps::L
1010
end
1111
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12-
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
12+
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.maxthreadid()])
1313
end
1414
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1515

test/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
@test threadsafe_vi.varinfo === vi
77
@test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
8-
@test length(threadsafe_vi.logps) == Threads.nthreads()
8+
@test length(threadsafe_vi.logps) == Threads.maxthreadid()
99
@test all(iszero(x[]) for x in threadsafe_vi.logps)
1010
end
1111

0 commit comments

Comments
 (0)