@@ -59,12 +59,12 @@ function (obj::Kernel{CPU})(args...; ndrange=nothing, workgroupsize=nothing, dep
59
59
ndrange = nothing
60
60
end
61
61
62
- t = Threads. @spawn __run (obj, ndrange, iterspace, args, dependencies)
62
+ t = Threads. @spawn __run (obj, ndrange, iterspace, args, dependencies, Val (dynamic) )
63
63
return CPUEvent (t)
64
64
end
65
65
66
66
# Inference barriers
67
- function __run (obj, ndrange, iterspace, args, dependencies)
67
+ function __run (obj, ndrange, iterspace, args, dependencies, :: Val{dynamic} ) where dynamic
68
68
__waitall (CPU (), dependencies, yield)
69
69
N = length (iterspace)
70
70
Nthreads = Threads. nthreads ()
@@ -79,16 +79,16 @@ function __run(obj, ndrange, iterspace, args, dependencies)
79
79
len, rem = 1 , 0
80
80
end
81
81
if Nthreads == 1
82
- __thread_run (1 , len, rem, obj, ndrange, iterspace, args)
82
+ __thread_run (1 , len, rem, obj, ndrange, iterspace, args, Val (dynamic) )
83
83
else
84
84
@sync for tid in 1 : Nthreads
85
- Threads. @spawn __thread_run (tid, len, rem, obj, ndrange, iterspace, args)
85
+ Threads. @spawn __thread_run (tid, len, rem, obj, ndrange, iterspace, args, Val (dynamic) )
86
86
end
87
87
end
88
88
return nothing
89
89
end
90
90
91
- function __thread_run (tid, len, rem, obj, ndrange, iterspace, args)
91
+ function __thread_run (tid, len, rem, obj, ndrange, iterspace, args, :: Val{dynamic} ) where dynamic
92
92
# compute this thread's iterations
93
93
f = 1 + ((tid- 1 ) * len)
94
94
l = f + len - 1
@@ -105,27 +105,16 @@ function __thread_run(tid, len, rem, obj, ndrange, iterspace, args)
105
105
# run this thread's iterations
106
106
for i = f: l
107
107
block = @inbounds blocks (iterspace)[i]
108
- # TODO : how do we use the information that the iteration space maps perfectly to
109
- # the ndrange without incurring a 2x compilation overhead
110
- # if dynamic
111
- ctx = mkcontextdynamic (obj, block, ndrange, iterspace)
108
+ ctx = mkcontext (obj, block, ndrange, iterspace, Val (dynamic))
112
109
Cassette. overdub (ctx, obj. f, args... )
113
- # else
114
- # ctx = mkcontext(obj, blocks, ndrange, iterspace)
115
- # Threads.@spawn Cassette.overdub(ctx, obj.f, args...)
116
110
end
117
111
return nothing
118
112
end
119
113
120
114
Cassette. @context CPUCtx
121
115
122
- function mkcontext (kernel:: Kernel{CPU} , I, _ndrange, iterspace)
123
- metadata = CompilerMetadata {ndrange(kernel), false} (I, _ndrange, iterspace)
124
- Cassette. disablehooks (CPUCtx (pass = CompilerPass, metadata= metadata))
125
- end
126
-
127
- function mkcontextdynamic (kernel:: Kernel{CPU} , I, _ndrange, iterspace)
128
- metadata = CompilerMetadata {ndrange(kernel), true} (I, _ndrange, iterspace)
116
"""
    mkcontext(kernel::Kernel{CPU}, I, _ndrange, iterspace, ::Val{dynamic})

Build the Cassette context under which a CPU kernel is overdubbed.

`I`, `_ndrange` and `iterspace` are forwarded unchanged to the `CompilerMetadata`
constructor. The metadata type is parameterized by `ndrange(kernel)` and by the
`dynamic` value lifted into the type domain through the trailing `Val` argument.

Returns a `CPUCtx` that uses `CompilerPass`, carries that metadata, and has its
Cassette hooks disabled.
"""
function mkcontext(
    kernel::Kernel{CPU}, I, _ndrange, iterspace, ::Val{dynamic}
) where {dynamic}
    # Fix the metadata type first; `dynamic` comes from the `Val` type parameter.
    MetaT = CompilerMetadata{ndrange(kernel), dynamic}
    ctx = CPUCtx(; pass=CompilerPass, metadata=MetaT(I, _ndrange, iterspace))
    return Cassette.disablehooks(ctx)
end
131
120
0 commit comments