Skip to content

Commit 32ca651

Browse files
authored
Merge pull request #218 from S-D-R/constant_constructors
Constructors for ConstantStruct and multidimensional ConstantArrays
2 parents b2f9c5f + f71fdaf commit 32ca651

File tree

5 files changed

+308
-73
lines changed

5 files changed

+308
-73
lines changed

COVERAGE.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,10 @@ Core
432432
- [ ] LLVMConstString
433433
- [ ] LLVMIsConstantString
434434
- [ ] LLVMGetAsString
435-
- [ ] LLVMConstStructInContext
436-
- [ ] LLVMConstStruct
437-
- [ ] LLVMConstArray
438-
- [ ] LLVMConstNamedStruct
435+
- [x] LLVMConstStructInContext
436+
- [x] LLVMConstStruct
437+
- [x] LLVMConstArray
438+
- [x] LLVMConstNamedStruct
439439
- [ ] LLVMGetElementAsConstant
440440
- [ ] LLVMConstVector
441441

src/core/module.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,12 @@ set_used!(mod::Module, values::GlobalVariable...) = nothing
7777
set_compiler_used!(mod::Module, values::GlobalVariable...) = nothing
7878
end
7979

80+
8081
## type iteration
8182

8283
export types
8384

84-
struct ModuleTypeDict <: AbstractDict{String,LLVMType}
85-
mod::Module
86-
end
87-
88-
types(mod::Module) = ModuleTypeDict(mod)
89-
90-
function Base.haskey(iter::ModuleTypeDict, name::String)
91-
return API.LLVMGetTypeByName(iter.mod, name) != C_NULL
92-
end
93-
94-
function Base.getindex(iter::ModuleTypeDict, name::String)
95-
objref = API.LLVMGetTypeByName(iter.mod, name)
96-
objref == C_NULL && throw(KeyError(name))
97-
return LLVMType(objref)
98-
end
85+
@deprecate types(mod::Module) types(context(mod))
9986

10087

10188
## metadata iteration

src/core/type.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,31 @@ identify(::Type{LLVMType}, ::Val{API.LLVMTokenTypeKind}) = TokenType
272272

273273
TokenType(ctx::Context) =
274274
TokenType(API.LLVMTokenTypeInContext(ctx))
275+
276+
277+
## type iteration
278+
279+
export types
280+
281+
struct ContextTypeDict <: AbstractDict{String,LLVMType}
282+
ctx::Context
283+
end
284+
285+
# FIXME: remove on LLVM 12
286+
function LLVMGetTypeByName2(ctx::Context, name)
287+
Module("dummy", ctx) do mod
288+
API.LLVMGetTypeByName(mod, name)
289+
end
290+
end
291+
292+
types(ctx::Context) = ContextTypeDict(ctx)
293+
294+
function Base.haskey(iter::ContextTypeDict, name::String)
295+
return LLVMGetTypeByName2(iter.ctx, name) != C_NULL
296+
end
297+
298+
function Base.getindex(iter::ContextTypeDict, name::String)
299+
objref = LLVMGetTypeByName2(iter.ctx, name)
300+
objref == C_NULL && throw(KeyError(name))
301+
return LLVMType(objref)
302+
end

src/core/value/constant.jl

Lines changed: 153 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ const WideInteger = Union{Int64, UInt64}
4646
ConstantInt(typ::IntegerType, val::WideInteger, signed=false) =
4747
ConstantInt(API.LLVMConstInt(typ, reinterpret(Culonglong, val),
4848
convert(Bool, signed)))
49-
const SmallInteger = Union{Int8, Int16, Int32, UInt8, UInt16, UInt32}
49+
const SmallInteger = Union{Core.Bool, Int8, Int16, Int32, UInt8, UInt16, UInt32}
5050
ConstantInt(typ::IntegerType, val::SmallInteger, signed=false) =
5151
ConstantInt(typ, convert(Int64, val), signed)
5252

5353
function ConstantInt(typ::IntegerType, val::Integer, signed=false)
54-
valbits = ceil(Int, log2(abs(val))) + 1
54+
valbits = ceil(Int, log2(abs(val))) + 1 # FIXME: doesn't work for val=0
5555
numwords = ceil(Int, valbits / 64)
5656
words = Vector{Culonglong}(undef, numwords)
5757
for i in 1:numwords
@@ -67,12 +67,18 @@ function ConstantInt(val::T, ctx::Context) where T<:SizeableInteger
6767
return ConstantInt(typ, val, T<:Signed)
6868
end
6969

70+
# Booleans are encoded with a single bit, so we can't use sizeof
71+
ConstantInt(val::Core.Bool, ctx::Context) = ConstantInt(Int1Type(ctx), val ? 1 : 0)
72+
7073
Base.convert(::Type{T}, val::ConstantInt) where {T<:Unsigned} =
7174
convert(T, API.LLVMConstIntGetZExtValue(val))
7275

7376
Base.convert(::Type{T}, val::ConstantInt) where {T<:Signed} =
7477
convert(T, API.LLVMConstIntGetSExtValue(val))
7578

79+
# Booleans aren't Signed or Unsigned
80+
Base.convert(::Type{Core.Bool}, val::ConstantInt) = convert(Int, val) != 0
81+
7682

7783
@checked struct ConstantFP <: Constant
7884
ref::API.LLVMValueRef
@@ -93,7 +99,7 @@ Base.convert(::Type{T}, val::ConstantFP) where {T<:AbstractFloat} =
9399
convert(T, API.LLVMConstRealGetDouble(val, Ref{API.LLVMBool}()))
94100

95101

96-
## aggregate
102+
## aggregate zero
97103

98104
export ConstantAggregateZero
99105

@@ -102,53 +108,174 @@ export ConstantAggregateZero
102108
end
103109
identify(::Type{Value}, ::Val{API.LLVMConstantAggregateZeroValueKind}) = ConstantAggregateZero
104110

105-
# there currently seems to be no function in the LLVM-C interface which returns a
106-
# ConstantAggregateZero value directly, but values can occur through calls to LLVMConstNull
107-
111+
# array interface
112+
# FIXME: can we reuse the ::ConstantArray functionality with ConstantAggregateZero values?
113+
# probably works fine if we just get rid of the refcheck
114+
Base.eltype(caz::ConstantAggregateZero) = eltype(llvmtype(caz))
115+
Base.size(caz::ConstantAggregateZero) = (0,)
116+
Base.length(caz::ConstantAggregateZero) = 0
117+
Base.axes(caz::ConstantAggregateZero) = (Base.OneTo(0),)
118+
Base.collect(caz::ConstantAggregateZero) = Value[]
108119

109-
## constant expressions
110120

111-
export ConstantExpr, ConstantAggregate, ConstantArray, ConstantStruct, ConstantVector, InlineAsm
112-
113-
@checked struct ConstantExpr <: Constant
114-
ref::API.LLVMValueRef
115-
end
116-
identify(::Type{Value}, ::Val{API.LLVMConstantExprValueKind}) = ConstantExpr
121+
## regular aggregate
117122

118123
abstract type ConstantAggregate <: Constant end
119124

125+
# arrays
126+
120127
@checked struct ConstantArray <: ConstantAggregate
121128
ref::API.LLVMValueRef
122129
end
123130
identify(::Type{Value}, ::Val{API.LLVMConstantArrayValueKind}) = ConstantArray
124131
identify(::Type{Value}, ::Val{API.LLVMConstantDataArrayValueKind}) = ConstantArray
125132

126-
ConstantArray(typ::LLVMType, data::Vector{T}) where {T<:Constant} =
127-
ConstantArray(API.LLVMConstArray(typ, data, length(data)))
128-
ConstantArray(typ::IntegerType, data::Vector{T}) where {T<:Integer} =
129-
ConstantArray(typ, map(x->ConstantInt(convert(T,x),context(typ)), data))
130-
ConstantArray(typ::FloatingPointType, data::Vector{T}) where {T<:AbstractFloat} =
131-
ConstantArray(typ, map(x->ConstantFP(convert(T,x),context(typ)), data))
133+
ConstantArrayOrAggregateZero(value) = Value(value)::Union{ConstantArray,ConstantAggregateZero}
132134

133-
Base.getindex(ca::ConstantArray, idx::Integer) =
134-
API.LLVMGetElementAsConstant(ca, idx-1)
135-
Base.length(ca::ConstantArray) = length(llvmtype(ca))
135+
# generic constructor taking an array of constants
136+
function ConstantArray(typ::LLVMType, data::AbstractArray{T,N}=T[]) where {T<:Constant,N}
137+
@assert all(x->x==typ, llvmtype.(data))
138+
139+
if N == 1
140+
return ConstantArrayOrAggregateZero(API.LLVMConstArray(typ, Array(data), length(data)))
141+
end
142+
143+
if VERSION >= v"1.1"
144+
ca_vec = map(x->ConstantArray(typ, x), eachslice(data, dims=1))
145+
else
146+
ca_vec = map(x->ConstantArray(typ, x), (view(data, i, ntuple(d->(:), N-1)...) for i in axes(data, 1)))
147+
end
148+
ca_typ = llvmtype(first(ca_vec))
149+
150+
return ConstantArray(API.LLVMConstArray(ca_typ, ca_vec, length(ca_vec)))
151+
end
152+
153+
# shorthands with arrays of plain Julia data
154+
# FIXME: duplicates the ConstantInt/ConstantFP conversion rules
155+
ConstantArray(data::AbstractArray{T,N}, ctx::Context=GlobalContext()) where {T<:Integer,N} =
156+
ConstantArray(IntType(sizeof(T)*8, ctx), ConstantInt.(data, Ref(ctx)))
157+
ConstantArray(data::AbstractArray{Core.Bool,N}, ctx::Context=GlobalContext()) where {N} =
158+
ConstantArray(Int1Type(ctx), ConstantInt.(data, Ref(ctx)))
159+
ConstantArray(data::AbstractArray{Float16,N}, ctx::Context=GlobalContext()) where {N} =
160+
ConstantArray(HalfType(ctx), ConstantFP.(data, Ref(ctx)))
161+
ConstantArray(data::AbstractArray{Float32,N}, ctx::Context=GlobalContext()) where {N} =
162+
ConstantArray(FloatType(ctx), ConstantFP.(data, Ref(ctx)))
163+
ConstantArray(data::AbstractArray{Float64,N}, ctx::Context=GlobalContext()) where {N} =
164+
ConstantArray(DoubleType(ctx), ConstantFP.(data, Ref(ctx)))
165+
166+
# convert back to known array types
167+
function Base.collect(ca::ConstantArray)
168+
constants = Array{Value}(undef, size(ca))
169+
for I in CartesianIndices(size(ca))
170+
@inbounds constants[I] = ca[Tuple(I)...]
171+
end
172+
return constants
173+
end
174+
175+
# array interface
136176
Base.eltype(ca::ConstantArray) = eltype(llvmtype(ca))
137-
Base.convert(::Type{Array{T,1}}, ca::ConstantArray) where {T<:Integer} =
138-
[convert(T,ConstantInt(ca[i])) for i in 1:length(ca)]
139-
Base.convert(::Type{Array{T,1}}, ca::ConstantArray) where {T<:AbstractFloat} =
140-
[convert(T,ConstantFP(ca[i])) for i in 1:length(ca)]
177+
function Base.size(ca::ConstantArray)
178+
dims = Int[]
179+
typ = llvmtype(ca)
180+
while typ isa ArrayType
181+
push!(dims, length(typ))
182+
typ = eltype(typ)
183+
end
184+
return Tuple(dims)
185+
end
186+
Base.length(ca::ConstantArray) = prod(size(ca))
187+
Base.axes(ca::ConstantArray) = Base.OneTo.(size(ca))
188+
189+
function Base.getindex(ca::ConstantArray, idx::Integer...)
190+
# multidimensional arrays are represented by arrays of arrays,
191+
# which we need to 'peel back' by looking at the operand sets.
192+
# for the final dimension, we use LLVMGetElementAsConstant
193+
@boundscheck Base.checkbounds_indices(Base.Bool, axes(ca), idx) ||
194+
throw(BoundsError(ca, idx))
195+
I = CartesianIndices(size(ca))[idx...]
196+
for i in Tuple(I)
197+
if isempty(operands(ca))
198+
ca = LLVM.Value(API.LLVMGetElementAsConstant(ca, i-1))
199+
else
200+
ca = (Base.@_propagate_inbounds_meta; operands(ca)[i])
201+
end
202+
end
203+
return ca
204+
end
205+
206+
# structs
141207

142208
@checked struct ConstantStruct <: ConstantAggregate
143209
ref::API.LLVMValueRef
144210
end
145211
identify(::Type{Value}, ::Val{API.LLVMConstantStructValueKind}) = ConstantStruct
146212

213+
ConstantStructOrAggregateZero(value) = Value(value)::Union{ConstantStruct,ConstantAggregateZero}
214+
215+
# anonymous
216+
ConstantStruct(values::Vector{<:Constant}; packed::Core.Bool=false) =
217+
ConstantStructOrAggregateZero(API.LLVMConstStruct(values, length(values), convert(Bool, packed)))
218+
ConstantStruct(values::Vector{<:Constant}, ctx::Context; packed::Core.Bool=false) =
219+
ConstantStructOrAggregateZero(API.LLVMConstStructInContext(ctx, values, length(values), convert(Bool, packed)))
220+
221+
# named
222+
ConstantStruct(typ::StructType, values::Vector{<:Constant}) =
223+
ConstantStructOrAggregateZero(API.LLVMConstNamedStruct(typ, values, length(values)))
224+
225+
# create a ConstantStruct from a Julia object
226+
function ConstantStruct(value::T, ctx::Context=GlobalContext(); name=String(nameof(T)),
227+
anonymous::Core.Bool=false, packed::Core.Bool=false) where {T}
228+
isbitstype(T) || throw(ArgumentError("Can only create a ConstantStruct from an isbits struct"))
229+
isprimitivetype(T) && throw(ArgumentError("Cannot create a ConstantStruct from a primitive value"))
230+
231+
constants = Vector{Constant}()
232+
for fieldname in fieldnames(T)
233+
field = getfield(value, fieldname)
234+
235+
if isa(field, Integer)
236+
push!(constants, ConstantInt(field, ctx))
237+
elseif isa(field, AbstractFloat)
238+
push!(constants, ConstantFP(field, ctx))
239+
else # TODO: nested structs?
240+
throw(ArgumentError("only structs with boolean, integer and floating point fields are allowed"))
241+
end
242+
end
243+
244+
if anonymous
245+
ConstantStruct(constants, ctx; packed=packed)
246+
elseif haskey(types(ctx), name)
247+
typ = types(ctx)[name]
248+
if collect(elements(typ)) != llvmtype.(constants)
249+
throw(ArgumentError("Cannot create struct $name {$(join(llvmtype.(constants), ", "))} as it is already defined in this context as {$(join(elements(typ), ", "))}."))
250+
end
251+
ConstantStruct(typ, constants)
252+
else
253+
typ = StructType(name, ctx)
254+
elements!(typ, llvmtype.(constants))
255+
ConstantStruct(typ, constants)
256+
end
257+
end
258+
259+
# vectors
260+
147261
@checked struct ConstantVector <: ConstantAggregate
148262
ref::API.LLVMValueRef
149263
end
150264
identify(::Type{Value}, ::Val{API.LLVMConstantVectorValueKind}) = ConstantVector
151265

266+
267+
## constant expressions
268+
269+
export ConstantExpr, ConstantAggregate, ConstantArray, ConstantStruct, ConstantVector, InlineAsm
270+
271+
@checked struct ConstantExpr <: Constant
272+
ref::API.LLVMValueRef
273+
end
274+
identify(::Type{Value}, ::Val{API.LLVMConstantExprValueKind}) = ConstantExpr
275+
276+
277+
## inline assembly
278+
152279
@checked struct InlineAsm <: Constant
153280
ref::API.LLVMValueRef
154281
end

0 commit comments

Comments
 (0)