Skip to content

Commit 810c0a0

Browse files
authored
rework type parameter stripping (#78)
Fixes #69
1 parent 10434f3 commit 810c0a0

File tree

1 file changed

+51
-20
lines changed

1 file changed

+51
-20
lines changed

src/FixedSizeArray.jl

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ end
6868
Base.size(a::FixedSizeArray) = a.size
6969

7070
function Base.similar(::T, ::Type{E}, size::NTuple{N,Int}) where {T<:FixedSizeArray,E,N}
71-
with_replaced_parameters(DenseArray, T, Val(E), Val(N))(undef, size)
71+
spec = TypeParametersElementTypeAndDimensionality()
72+
S = val_parameter(with_stripped_type_parameters(spec, T)){E, N}
73+
S(undef, size)
7274
end
7375

7476
Base.isassigned(a::FixedSizeArray, i::Int) = isassigned(a.mem, i)
@@ -106,7 +108,8 @@ end
106108
# broadcasting
107109

108110
function Base.BroadcastStyle(::Type{T}) where {T<:FixedSizeArray}
109-
Broadcast.ArrayStyle{stripped_type(DenseArray, T)}()
111+
spec = TypeParametersElementTypeAndDimensionality()
112+
Broadcast.ArrayStyle{val_parameter(with_stripped_type_parameters(spec, T))}()
110113
end
111114

112115
function Base.similar(
@@ -118,30 +121,58 @@ end
118121

119122
# helper functions
120123

121-
normalized_type(::Type{T}) where {T} = T
124+
val_parameter(::Val{P}) where {P} = P
122125

123-
function stripped_type_unchecked(::Type{DenseVector}, ::Type{<:GenericMemory{K,<:Any,AS}}) where {K,AS}
124-
GenericMemory{K,<:Any,AS}
125-
end
126+
struct TypeParametersElementType end
127+
struct TypeParametersElementTypeAndDimensionality end
128+
129+
"""
130+
with_stripped_type_parameters_unchecked(spec, t::Type)::Val{s}
126131
127-
Base.@assume_effects :consistent function stripped_type_unchecked(
128-
::Type{DenseArray}, ::Type{<:FixedSizeArray{<:Any,<:Any,V}},
129-
) where {V}
130-
U = stripped_type(DenseVector, V)
131-
FixedSizeArray{E,N,U{E}} where {E,N}
132+
An implementation detail of [`with_stripped_type_parameters`](@ref). Don't call
133+
directly.
134+
"""
135+
function with_stripped_type_parameters_unchecked end
136+
137+
function with_stripped_type_parameters_unchecked(::TypeParametersElementType, ::Type{<:(GenericMemory{K, T, AS} where {T})}) where {K, AS}
138+
s = GenericMemory{K, T, AS} where {T}
139+
Val{s}()
132140
end
133141

134-
function stripped_type(::Type{T}, ::Type{S}) where {T,S<:T}
135-
ret = stripped_type_unchecked(T, S)::Type{<:T}::UnionAll
136-
S::Type{<:ret}
137-
normalized_type(ret) # ensure `UnionAll` type variable order is normalized
142+
# `Base.@assume_effects :consistent` is a workaround for:
143+
# https://github.com/JuliaLang/julia/issues/56966
144+
Base.@assume_effects :consistent function with_stripped_type_parameters_unchecked(::TypeParametersElementTypeAndDimensionality, ::Type{<:(FixedSizeArray{T, N, Mem} where {T, N})}) where {Mem}
145+
spec_mem = TypeParametersElementType()
146+
mem_v = with_stripped_type_parameters(spec_mem, Mem)
147+
mem = val_parameter(mem_v)
148+
s = FixedSizeArray{T, N, mem{T}} where {T, N}
149+
Val{s}()
138150
end
139151

140-
function with_replaced_parameters(::Type{T}, ::Type{S}, ::Val{P1}, ::Val{P2}) where {T,S<:T,P1,P2}
141-
t = T{P1,P2}::Type{<:T}
142-
s = stripped_type(T, S)
143-
S::Type{<:s}
144-
s{P1,P2}::Type{<:s}::Type{<:T}::Type{<:t}
152+
"""
153+
with_stripped_type_parameters(spec, t::Type)::Val{s}
154+
155+
The type `s` is a `UnionAll` supertype of `t`:
156+
157+
```julia
158+
(s isa UnionAll) && (t <: s)
159+
```
160+
161+
Furthermore, `s` has type variables in place of the type parameters specified
162+
via `spec`.
163+
164+
NB: `Val{s}()` is returned instead of `s` so the method would be *consistent*
165+
from the point of view of Julia's effect inference, enabling constant folding.
166+
167+
NB: this function is supposed to only have the one method. To add
168+
functionality, add methods to [`with_stripped_type_parameters_unchecked`](@ref).
169+
"""
170+
function with_stripped_type_parameters(spec, t::Type)
171+
ret = with_stripped_type_parameters_unchecked(spec, t)
172+
s = val_parameter(ret)
173+
s = s::UnionAll
174+
s = s::(Type{T} where {T>:t})
175+
Val{s}()
145176
end
146177

147178
dimension_count_of(::Base.SizeUnknown) = 1

0 commit comments

Comments
 (0)