Skip to content

Commit d2c3f2a

Browse files
committed
add and benchmark typed_hvcat(SA, ::Val, ...)
to explore the benefits of JuliaLang/julia#36719
1 parent 9f2fa89 commit d2c3f2a

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

perf/hvcat_val.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using StaticArray, BenchmarkTools
2+
3+
let
4+
rows, cols = 4, 4
5+
_dims = Expr(:tuple, [cols for _ in 1:rows]...)
6+
7+
for (f, wrap_val) in [(:f1, false), (:f2, true)]
8+
dims = wrap_val ? :(Val{$_dims}()) : _dims
9+
zeros_sa = :(Base.typed_hvcat(SA, $dims, $([0 for _ in 1:rows*cols]...)))
10+
xs = [Symbol(:x, i) for i in 1:rows*cols]
11+
is = [Symbol(:i, i) for i in 1:rows*cols]
12+
is_sa = :(Base.typed_hvcat(SA, $dims, $(is...)))
13+
@eval begin
14+
function $f($(xs...))
15+
r = $zeros_sa
16+
for ($(is...),) in Iterators.product($(xs...))
17+
r += $is_sa
18+
end
19+
r
20+
end
21+
end
22+
end
23+
24+
xs = [:(1:2) for _ in 1:rows*cols]
25+
display(@eval @benchmark f1($(xs...)))
26+
display(@eval @benchmark f2($(xs...)))
27+
end

src/initializers.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,17 @@ end
5959
@inline Base.typed_hvcat(sa::Type{SA}, rows::Dims, xs::Number...) = _SA_typed_hvcat(sa, rows, xs)
6060
@inline Base.typed_hvcat(sa::Type{SA{T}}, rows::Dims, xs::Number...) where T = _SA_typed_hvcat(sa, rows, xs)
6161

62+
@generated function _SA_typed_hvcat(::Type{sa}, ::Val{rows}, xs) where {sa,rows}
63+
M = rows[1]
64+
if any(r->r != M, rows)
65+
# @pure may not throw... probably. See
66+
# https://discourse.julialang.org/t/can-pure-functions-throw-an-error/18459
67+
return :(throw(ArgumentError("SA[...] matrix rows of length $_rows are inconsistent")))
68+
end
69+
msize = Size(M, length(rows))
70+
# hvcat lowering is row major ordering, so we must transpose
71+
:(Base.@_inline_meta; transpose($(similar_type(sa, msize))(xs)))
72+
end
73+
74+
@inline Base.typed_hvcat(sa::Type{SA}, rows::Val{_rows}, xs::Number...) where {_rows} = _SA_typed_hvcat(sa, rows, xs)
75+
@inline Base.typed_hvcat(sa::Type{SA{T}}, rows::Val{_rows}, xs::Number...) where {T,_rows} = _SA_typed_hvcat(sa, rows, xs)

0 commit comments

Comments
 (0)