Skip to content

Commit dc5b924

Browse files
authored
Merge pull request #57 from JuliaGPU/vc/cpu_perf
only create as many tasks as threads and more inference barriers
2 parents 70cf00f + 7350da1 commit dc5b924

File tree

1 file changed

+50
-20
lines changed

1 file changed

+50
-20
lines changed

src/backends/cpu.jl

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,32 +59,62 @@ function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dep
5959
ndrange = nothing
6060
end
6161

62-
t = __run(obj, ndrange, iterspace, args, dependencies)
62+
t = Threads.@spawn __run(obj, ndrange, iterspace, args, dependencies)
6363
return CPUEvent(t)
6464
end
6565

66-
# Inference barrier
66+
# Inference barriers
6767
function __run(obj, ndrange, iterspace, args, dependencies)
68-
return Threads.@spawn begin
69-
__waitall(CPU(), dependencies, yield)
70-
@sync begin
71-
# TODO: how do we use the information that the iteration space maps perfectly to
72-
# the ndrange without incurring a 2x compilation overhead
73-
# if dynamic
74-
for block in iterspace
75-
let ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
76-
Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
77-
end
78-
end
79-
# else
80-
# for block in iterspace
81-
# let ctx = mkcontext(obj, blocks, ndrange, iterspace)
82-
# Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
83-
# end
84-
# end
85-
# end
68+
__waitall(CPU(), dependencies, yield)
69+
N = length(iterspace)
70+
Nthreads = Threads.nthreads()
71+
if Nthreads == 1
72+
len, rem = N, 0
73+
else
74+
len, rem = divrem(N, Nthreads)
75+
end
76+
# not enough iterations for all the threads?
77+
if len == 0
78+
Nthreads = N
79+
len, rem = 1, 0
80+
end
81+
if Nthreads == 1
82+
__thread_run(1, len, rem, obj, ndrange, iterspace, args)
83+
else
84+
@sync for tid in 1:Nthreads
85+
Threads.@spawn __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
86+
end
87+
end
88+
return nothing
89+
end
90+
91+
function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
92+
# compute this thread's iterations
93+
f = 1 + ((tid-1) * len)
94+
l = f + len - 1
95+
# distribute remaining iterations evenly
96+
if rem > 0
97+
if tid <= rem
98+
f = f + (tid-1)
99+
l = l + tid
100+
else
101+
f = f + rem
102+
l = l + rem
86103
end
87104
end
105+
# run this thread's iterations
106+
for i = f:l
107+
block = @inbounds blocks(iterspace)[i]
108+
# TODO: how do we use the information that the iteration space maps perfectly to
109+
# the ndrange without incurring a 2x compilation overhead
110+
# if dynamic
111+
ctx = mkcontextdynamic(obj, block, ndrange, iterspace)
112+
Cassette.overdub(ctx, obj.f, args...)
113+
# else
114+
# ctx = mkcontext(obj, blocks, ndrange, iterspace)
115+
# Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
116+
end
117+
return nothing
88118
end
89119

90120
Cassette.@context CPUCtx

0 commit comments

Comments
 (0)