@@ -59,32 +59,62 @@ function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dep
59
59
ndrange = nothing
60
60
end
61
61
62
- t = __run (obj, ndrange, iterspace, args, dependencies)
62
+ t = Threads . @spawn __run (obj, ndrange, iterspace, args, dependencies)
63
63
return CPUEvent (t)
64
64
end
65
65
66
- # Inference barrier
66
+ # Inference barriers
67
67
function __run (obj, ndrange, iterspace, args, dependencies)
68
- return Threads. @spawn begin
69
- __waitall (CPU (), dependencies, yield)
70
- @sync begin
71
- # TODO : how do we use the information that the iteration space maps perfectly to
72
- # the ndrange without incurring a 2x compilation overhead
73
- # if dynamic
74
- for block in iterspace
75
- let ctx = mkcontextdynamic (obj, block, ndrange, iterspace)
76
- Threads. @spawn Cassette. overdub (ctx, obj. f, args... )
77
- end
78
- end
79
- # else
80
- # for block in iterspace
81
- # let ctx = mkcontext(obj, blocks, ndrange, iterspace)
82
- # Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
83
- # end
84
- # end
85
- # end
68
+ __waitall (CPU (), dependencies, yield)
69
+ N = length (iterspace)
70
+ Nthreads = Threads. nthreads ()
71
+ if Nthreads == 1
72
+ len, rem = N, 0
73
+ else
74
+ len, rem = divrem (N, Nthreads)
75
+ end
76
+ # not enough iterations for all the threads?
77
+ if len == 0
78
+ Nthreads = N
79
+ len, rem = 1 , 0
80
+ end
81
+ if Nthreads == 1
82
+ __thread_run (1 , len, rem, obj, ndrange, iterspace, args)
83
+ else
84
+ @sync for tid in 1 : Nthreads
85
+ Threads. @spawn __thread_run (tid, len, rem, obj, ndrange, iterspace, args)
86
+ end
87
+ end
88
+ return nothing
89
+ end
90
+
91
+ function __thread_run (tid, len, rem, obj, ndrange, iterspace, args)
92
+ # compute this thread's iterations
93
+ f = 1 + ((tid- 1 ) * len)
94
+ l = f + len - 1
95
+ # distribute remaining iterations evenly
96
+ if rem > 0
97
+ if tid <= rem
98
+ f = f + (tid- 1 )
99
+ l = l + tid
100
+ else
101
+ f = f + rem
102
+ l = l + rem
86
103
end
87
104
end
105
+ # run this thread's iterations
106
+ for i = f: l
107
+ block = @inbounds blocks (iterspace)[i]
108
+ # TODO : how do we use the information that the iteration space maps perfectly to
109
+ # the ndrange without incurring a 2x compilation overhead
110
+ # if dynamic
111
+ ctx = mkcontextdynamic (obj, block, ndrange, iterspace)
112
+ Cassette. overdub (ctx, obj. f, args... )
113
+ # else
114
+ # ctx = mkcontext(obj, blocks, ndrange, iterspace)
115
+ # Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
116
+ end
117
+ return nothing
88
118
end
89
119
90
120
Cassette. @context CPUCtx
0 commit comments