Skip to content

Commit 27f9fcf

Browse files
authored
Merge pull request #345 from JuliaGPU/tb/metal
Metal simplifications
2 parents 6bdecc9 + 52e98a2 commit 27f9fcf

File tree

4 files changed

+77
-365
lines changed

4 files changed

+77
-365
lines changed

.buildkite/pipeline.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@ steps:
4343
if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft
4444
timeout_in_minutes: 60
4545

46+
- label: "Metal.jl"
47+
plugins:
48+
- JuliaCI/julia#v1:
49+
version: 1.8
50+
- JuliaCI/julia-coverage#v1:
51+
codecov: true
52+
command: |
53+
julia -e 'using Pkg;
54+
55+
println("--- :julia: Instantiating project");
56+
Pkg.develop(PackageSpec(path=pwd()));
57+
Pkg.add(PackageSpec(name="Metal", rev="main",
58+
url="https://github.com/JuliaGPU/Metal.jl.git"));
59+
Pkg.build();
60+
61+
println("+++ :julia: Running tests");
62+
Pkg.test("Metal"; coverage=true);'
63+
agents:
64+
queue: "juliagpu"
65+
metal: "*"
66+
if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft
67+
timeout_in_minutes: 120
68+
4669
# - label: "AMDGPU.jl"
4770
# plugins:
4871
# - JuliaCI/julia#v1:

src/interface.jl

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,14 @@ struct CompilerJob{T,P,F}
139139
params::P
140140
entry_abi::Symbol
141141

142-
# metadata gathered during compilation
143-
meta::Dict{Symbol,Any}
144-
145-
function CompilerJob(target::AbstractCompilerTarget, source::FunctionSpec,
146-
params::AbstractCompilerParams, entry_abi::Symbol)
142+
function CompilerJob(target::AbstractCompilerTarget, source::FunctionSpec, params::AbstractCompilerParams, entry_abi::Symbol)
147143
if entry_abi (:specfunc, :func)
148144
error("Unknown entry_abi=$entry_abi")
149145
end
150-
new{typeof(target), typeof(params), typeof(source)}(
151-
target, source, params, entry_abi, Dict{Symbol,Any}())
146+
new{typeof(target), typeof(params), typeof(source)}(target, source, params, entry_abi)
152147
end
153148
end
154-
CompilerJob(target::AbstractCompilerTarget, source::FunctionSpec,
155-
params::AbstractCompilerParams; entry_abi=:specfunc) =
149+
CompilerJob(target::AbstractCompilerTarget, source::FunctionSpec, params::AbstractCompilerParams; entry_abi=:specfunc) =
156150
CompilerJob(target, source, params, entry_abi)
157151

158152
Base.similar(@nospecialize(job::CompilerJob), @nospecialize(source::FunctionSpec)) =
@@ -162,20 +156,13 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{T})) where {T}
162156
print(io, "CompilerJob of ", job.source, " for ", T)
163157
end
164158

165-
# make it possible to key on CompilerJobs, while ignoring the metadata
166159
function Base.hash(job::CompilerJob, h::UInt)
167160
h = hash(job.target, h)
168161
h = hash(job.source, h)
169162
h = hash(job.params, h)
170163
h = hash(job.entry_abi, h)
171164
h
172165
end
173-
function Base.isequal(a::CompilerJob, b::CompilerJob)
174-
a.target == b.target &&
175-
a.source == b.source &&
176-
a.params == b.params &&
177-
a.entry_abi == b.entry_abi
178-
end
179166

180167

181168
## contexts

src/irgen.jl

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -312,42 +312,31 @@ end
312312
GHOST # not passed
313313
end
314314

315-
function method_argnames(m::Method)
316-
argnames = ccall(:jl_uncompress_argnames, Vector{Symbol}, (Any,), m.slot_syms)
317-
isempty(argnames) && return argnames
318-
return argnames[1:m.nargs]
319-
end
320-
321315
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
322316
source_sig = typed_signature(job)
323317

324318
source_types = [source_sig.parameters...]
325-
source_method = only(method_matches(typed_signature(job); job.source.world)).method
326-
source_arguments = method_argnames(source_method)
327319

328320
codegen_types = parameters(codegen_ft)
329321

330322
args = []
331323
codegen_i = 1
332324
for (source_i, source_typ) in enumerate(source_types)
333-
source_name = source_arguments[min(source_i, length(source_arguments))]
334-
# NOTE: in case of varargs, we have fewer arguments than parameters
335-
336325
if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
337-
push!(args, (cc=GHOST, typ=source_typ, name=source_name))
326+
push!(args, (cc=GHOST, typ=source_typ))
338327
continue
339328
end
340329

341330
codegen_typ = codegen_types[codegen_i]
342331
if codegen_typ isa LLVM.PointerType && !issized(eltype(codegen_typ))
343-
push!(args, (cc=MUT_REF, typ=source_typ, name=source_name,
332+
push!(args, (cc=MUT_REF, typ=source_typ,
344333
codegen=(typ=codegen_typ, i=codegen_i)))
345334
elseif codegen_typ isa LLVM.PointerType && issized(eltype(codegen_typ)) &&
346335
!(source_typ <: Ptr) && !(source_typ <: Core.LLVMPtr)
347-
push!(args, (cc=BITS_REF, typ=source_typ, name=source_name,
336+
push!(args, (cc=BITS_REF, typ=source_typ,
348337
codegen=(typ=codegen_typ, i=codegen_i)))
349338
else
350-
push!(args, (cc=BITS_VALUE, typ=source_typ, name=source_name,
339+
push!(args, (cc=BITS_VALUE, typ=source_typ,
351340
codegen=(typ=codegen_typ, i=codegen_i)))
352341
end
353342
codegen_i += 1
@@ -356,37 +345,6 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
356345
return args
357346
end
358347

359-
function classify_fields(julia::DataType, llvm::LLVMType)
360-
nfields = fieldcount(julia)
361-
fieldoffsets = [fieldoffset(julia, i) for i in 1:nfields]
362-
fieldsizes = similar(fieldoffsets)
363-
for i in 1:nfields
364-
field_start = fieldoffsets[i]
365-
field_end = i == nfields ? sizeof(julia) : fieldoffsets[i+1]
366-
fieldsizes[i] = field_end - field_start
367-
end
368-
fieldsizes
369-
370-
args = []
371-
codegen_i = 1
372-
for source_i in 1:nfields
373-
source_name = fieldname(julia, source_i)
374-
source_typ = fieldtype(julia, source_i)
375-
if fieldsizes[source_i] == 0
376-
push!(args, (; typ=source_typ, name=source_name))
377-
continue
378-
end
379-
380-
# NOTE: a cc doesn't make sense here, so the lack of codegen field should be checked
381-
382-
codegen_typ = elements(llvm)[codegen_i]
383-
push!(args, (typ=source_typ, name=source_name, codegen=(typ=codegen_typ, i=codegen_i)))
384-
codegen_i += 1
385-
end
386-
387-
return args
388-
end
389-
390348
if VERSION >= v"1.7.0-DEV.204"
391349
function is_immutable_datatype(T::Type)
392350
isa(T,DataType) && !Base.ismutabletype(T)

0 commit comments

Comments
 (0)