Skip to content

Commit a34cb17

Browse files
N5N3mateuszbaran
andauthored
Extend @SArray (nested cat, 1.7 syntax) (#1009)
* Suport nested cat in `@SArray` Use `cat_any` to "cat" all arguments. (No promotion) And better performance (3x faster) Code clean for macro. * Also extend `@SMatix` and `@SVector` Just check the output's shape: 1. Alow missing dimension (Vector isa `n*1` Matrix) 2. And addition size-1 dimension (`m*n*1` Array isa `m*n` Matrix) * Support `MArray` Code reuse. * Some behavior change 1. [1;2] isa Vector 2. [f(...) for ...] has no dim limit. * Add more test * Mark `@SVector [;]` as broken. The constructor is missing for empty `SVector`. Would be fixed in future PRs. Thus just mark it as broken. * Add support to `SA[1;;1]` * Resolve comments: 1. Only test `@SArray [;;]` on master. 2. code clean. 3. support `@SArray fill(1)` `@SArray zeros()` Co-Authored-By: Mateusz Baran <2551062+mateuszbaran@users.noreply.github.com> * Add doc string. * Stop nested cat when meet `[a,b,c]` `a`, `b`, `c` should not be catted. * bump Co-authored-by: Mateusz Baran <2551062+mateuszbaran@users.noreply.github.com>
1 parent f2ac2e6 commit a34cb17

17 files changed

+448
-602
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.4.2"
3+
version = "1.4.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/MArray.jl

Lines changed: 10 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -110,160 +110,17 @@ Base.dataids(ma::MArray) = (UInt(pointer(ma)),)
110110
Base.unsafe_convert(Ptr{T}, pointer_from_objref(a))
111111
end
112112

113-
macro MArray(ex)
114-
if !isa(ex, Expr)
115-
error("Bad input for @MArray")
116-
end
117-
118-
if ex.head == :vect # vector
119-
return esc(Expr(:call, MArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...)))
120-
elseif ex.head == :ref # typed, vector
121-
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
122-
elseif ex.head == :hcat # 1 x n
123-
s1 = 1
124-
s2 = length(ex.args)
125-
return esc(Expr(:call, MArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...)))
126-
elseif ex.head == :typed_hcat # typed, 1 x n
127-
s1 = 1
128-
s2 = length(ex.args) - 1
129-
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
130-
elseif ex.head == :vcat
131-
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
132-
# Validate
133-
s1 = length(ex.args)
134-
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
135-
s2 = minimum(s2s)
136-
if maximum(s2s) != s2
137-
throw(ArgumentError("Rows must be of matching lengths"))
138-
end
139-
140-
exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
141-
return esc(Expr(:call, MArray{Tuple{s1, s2}}, Expr(:tuple, exprs...)))
142-
else # n x 1
143-
return esc(Expr(:call, MArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...)))
144-
end
145-
elseif ex.head == :typed_vcat
146-
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
147-
# Validate
148-
s1 = length(ex.args) - 1
149-
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
150-
s2 = minimum(s2s)
151-
if maximum(s2s) != s2
152-
error("Rows must be of matching lengths")
153-
end
154-
155-
exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
156-
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...)))
157-
else # typed, n x 1
158-
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
159-
end
160-
elseif isa(ex, Expr) && ex.head == :comprehension
161-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
162-
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
163-
end
164-
ex = ex.args[1]
165-
n_rng = length(ex.args) - 1
166-
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng]
167-
rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng]
168-
rng_lengths = map(length, rngs)
169-
170-
f = gensym()
171-
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1])))
172-
173-
# TODO figure out a generic way of doing this...
174-
if n_rng == 1
175-
exprs = [:($f($j1)) for j1 in rngs[1]]
176-
elseif n_rng == 2
177-
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]]
178-
elseif n_rng == 3
179-
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]]
180-
elseif n_rng == 4
181-
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]]
182-
elseif n_rng == 5
183-
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]]
184-
elseif n_rng == 6
185-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]]
186-
elseif n_rng == 7
187-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]]
188-
elseif n_rng == 8
189-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]]
190-
else
191-
error("@MArray only supports up to 8-dimensional comprehensions")
192-
end
193-
194-
return quote
195-
$(esc(f_expr))
196-
$(esc(Expr(:call, Expr(:curly, :MArray, Tuple{rng_lengths...}), Expr(:tuple, exprs...))))
197-
end
198-
elseif isa(ex, Expr) && ex.head == :typed_comprehension
199-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
200-
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
201-
end
202-
T = ex.args[1]
203-
ex = ex.args[2]
204-
n_rng = length(ex.args) - 1
205-
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng]
206-
rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng]
207-
rng_lengths = map(length, rngs)
208-
209-
f = gensym()
210-
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1])))
211-
212-
# TODO figure out a generic way of doing this...
213-
if n_rng == 1
214-
exprs = [:($f($j1)) for j1 in rngs[1]]
215-
elseif n_rng == 2
216-
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]]
217-
elseif n_rng == 3
218-
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]]
219-
elseif n_rng == 4
220-
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]]
221-
elseif n_rng == 5
222-
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]]
223-
elseif n_rng == 6
224-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]]
225-
elseif n_rng == 7
226-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]]
227-
elseif n_rng == 8
228-
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]]
229-
else
230-
error("@MArray only supports up to 8-dimensional comprehensions")
231-
end
232-
233-
return quote
234-
$(esc(f_expr))
235-
$(esc(Expr(:call, Expr(:curly, :MArray, Tuple{rng_lengths...}, T), Expr(:tuple, exprs...))))
236-
end
237-
elseif isa(ex, Expr) && ex.head == :call
238-
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
239-
if length(ex.args) == 1
240-
error("@MArray got bad expression: $(ex.args[1])()")
241-
else
242-
return quote
243-
if isa($(esc(ex.args[2])), DataType)
244-
$(ex.args[1])($(esc(Expr(:curly, MArray, Expr(:curly, Tuple, ex.args[3:end]...), ex.args[2]))))
245-
else
246-
$(ex.args[1])($(esc(Expr(:curly, MArray, Expr(:curly, Tuple, ex.args[2:end]...)))))
247-
end
248-
end
249-
end
250-
elseif ex.args[1] == :fill
251-
if length(ex.args) == 1
252-
error("@MArray got bad expression: $(ex.args[1])()")
253-
elseif length(ex.args) == 2
254-
error("@MArray got bad expression: $(ex.args[1])($(ex.args[2]))")
255-
else
256-
return quote
257-
$(esc(ex.args[1]))($(esc(ex.args[2])), MArray{$(esc(Expr(:curly, Tuple, ex.args[3:end]...)))})
258-
end
259-
end
260-
else
261-
error("@MArray only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
262-
end
263-
else
264-
error("Bad input for @MArray")
265-
end
113+
"""
114+
@MArray [a b; c d]
115+
@MArray [[a, b];[c, d]]
116+
@MArray [i+j for i in 1:2, j in 1:2]
117+
@MArray ones(2, 2, 2)
266118
119+
A convenience macro to construct `MArray` with arbitrary dimension.
120+
See [`@SArray`](@ref) for detailed features.
121+
"""
122+
macro MArray(ex)
123+
esc(static_array_gen(MArray, ex, __module__))
267124
end
268125

269126
function promote_rule(::Type{<:MArray{S,T,N,L}}, ::Type{<:MArray{S,U,N,L}}) where {S,T,U,N,L}

src/MMatrix.jl

Lines changed: 11 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -62,116 +62,15 @@ end
6262
## MMatrix methods ##
6363
#####################
6464

65-
macro MMatrix(ex)
66-
if !isa(ex, Expr)
67-
error("Bad input for @MMatrix")
68-
end
69-
if ex.head == :vect && length(ex.args) == 1 # 1 x 1
70-
return esc(Expr(:call, MMatrix{1, 1}, Expr(:tuple, ex.args[1])))
71-
elseif ex.head == :ref && length(ex.args) == 2 # typed, 1 x 1
72-
return esc(Expr(:call, Expr(:curly, :MMatrix, 1, 1, ex.args[1]), Expr(:tuple, ex.args[2])))
73-
elseif ex.head == :hcat # 1 x n
74-
s1 = 1
75-
s2 = length(ex.args)
76-
return esc(Expr(:call, MMatrix{s1, s2}, Expr(:tuple, ex.args...)))
77-
elseif ex.head == :typed_hcat # typed, 1 x n
78-
s1 = 1
79-
s2 = length(ex.args) - 1
80-
return esc(Expr(:call, Expr(:curly, :MMatrix, s1, s2, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
81-
elseif ex.head == :vcat
82-
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
83-
# Validate
84-
s1 = length(ex.args)
85-
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
86-
s2 = minimum(s2s)
87-
if maximum(s2s) != s2
88-
throw(ArgumentError("Rows must be of matching lengths"))
89-
end
90-
91-
exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
92-
return esc(Expr(:call, MMatrix{s1, s2}, Expr(:tuple, exprs...)))
93-
else # n x 1
94-
return esc(Expr(:call, MMatrix{length(ex.args), 1}, Expr(:tuple, ex.args...)))
95-
end
96-
elseif ex.head == :typed_vcat
97-
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
98-
# Validate
99-
s1 = length(ex.args) - 1
100-
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
101-
s2 = minimum(s2s)
102-
if maximum(s2s) != s2
103-
throw(ArgumentError("Rows must be of matching lengths"))
104-
end
105-
106-
exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
107-
return esc(Expr(:call, Expr(:curly, :MMatrix,s1, s2, ex.args[1]), Expr(:tuple, exprs...)))
108-
else # typed, n x 1
109-
return esc(Expr(:call, Expr(:curly, :MMatrix, length(ex.args)-1, 1, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
110-
end
111-
elseif isa(ex, Expr) && ex.head == :comprehension
112-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
113-
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
114-
end
115-
ex = ex.args[1]
116-
if length(ex.args) != 3
117-
error("Use a 2-dimensional comprehension for @MMatrx")
118-
end
119-
120-
rng1 = Core.eval(__module__, ex.args[2].args[2])
121-
rng2 = Core.eval(__module__, ex.args[3].args[2])
122-
f = gensym()
123-
f_expr = :($f = (($(ex.args[2].args[1]), $(ex.args[3].args[1])) -> $(ex.args[1])))
124-
exprs = [:($f($j1, $j2)) for j1 in rng1, j2 in rng2]
125-
126-
return quote
127-
$(esc(f_expr))
128-
$(esc(Expr(:call, Expr(:curly, :MMatrix, length(rng1), length(rng2)), Expr(:tuple, exprs...))))
129-
end
130-
elseif isa(ex, Expr) && ex.head == :typed_comprehension
131-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
132-
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
133-
end
134-
T = ex.args[1]
135-
ex = ex.args[2]
136-
if length(ex.args) != 3
137-
error("Use a 2-dimensional comprehension for @MMatrx")
138-
end
139-
140-
rng1 = Core.eval(__module__, ex.args[2].args[2])
141-
rng2 = Core.eval(__module__, ex.args[3].args[2])
142-
f = gensym()
143-
f_expr = :($f = (($(ex.args[2].args[1]), $(ex.args[3].args[1])) -> $(ex.args[1])))
144-
exprs = [:($f($j1, $j2)) for j1 in rng1, j2 in rng2]
65+
"""
66+
@MMatrix [a b c d]
67+
@MMatrix [[a, b];[c, d]]
68+
@MMatrix [i+j for i in 1:2, j in 1:2]
69+
@MMatrix ones(2, 2, 2)
14570
146-
return quote
147-
$(esc(f_expr))
148-
$(esc(Expr(:call, Expr(:curly, :MMatrix, length(rng1), length(rng2), T), Expr(:tuple, exprs...))))
149-
end
150-
elseif isa(ex, Expr) && ex.head == :call
151-
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
152-
if length(ex.args) == 3
153-
return quote
154-
$(ex.args[1])(MMatrix{$(esc(ex.args[2])),$(esc(ex.args[3]))})
155-
end
156-
elseif length(ex.args) == 4
157-
return quote
158-
$(ex.args[1])(MMatrix{$(esc(ex.args[3])), $(esc(ex.args[4])), $(esc(ex.args[2]))})
159-
end
160-
else
161-
error("@MMatrix expected a 2-dimensional array expression")
162-
end
163-
elseif ex.args[1] == :fill
164-
if length(ex.args) == 4
165-
return quote
166-
$(esc(ex.args[1]))($(esc(ex.args[2])), MMatrix{$(esc(ex.args[3])), $(esc(ex.args[4]))})
167-
end
168-
else
169-
error("@MMatrix expected a 2-dimensional array expression")
170-
end
171-
else
172-
error("@MMatrix only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
173-
end
174-
else
175-
error("Bad input for @MMatrix")
176-
end
177-
end
71+
A convenience macro to construct `MMatrix`.
72+
See [`@SArray`](@ref) for detailed features.
73+
"""
74+
macro MMatrix(ex)
75+
esc(static_matrix_gen(MMatrix, ex, __module__))
76+
end

src/MVector.jl

Lines changed: 8 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -28,76 +28,16 @@ const MVector{S, T} = MArray{Tuple{S}, T, 1, S}
2828
#####################
2929
## MVector methods ##
3030
#####################
31+
"""
32+
@MVector [a, b, c, d]
33+
@MVector [i for i in 1:2]
34+
@MVector ones(2)
3135
36+
A convenience macro to construct `MVector`.
37+
See [`@SArray`](@ref) for detailed features.
38+
"""
3239
macro MVector(ex)
33-
if isa(ex, Expr) && ex.head == :vect
34-
return esc(Expr(:call, MVector{length(ex.args)}, Expr(:tuple, ex.args...)))
35-
elseif isa(ex, Expr) && ex.head == :ref
36-
return esc(Expr(:call, Expr(:curly, :MVector, length(ex.args[2:end]), ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
37-
elseif isa(ex, Expr) && ex.head == :comprehension
38-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
39-
error("Expected generator in comprehension, e.g. [f(i) for i = 1:3]")
40-
end
41-
ex = ex.args[1]
42-
if length(ex.args) != 2
43-
error("Use a one-dimensional comprehension for @MVector")
44-
end
45-
46-
rng = Core.eval(__module__, ex.args[2].args[2])
47-
f = gensym()
48-
f_expr = :($f = ($(ex.args[2].args[1]) -> $(ex.args[1])))
49-
exprs = [:($f($j)) for j in rng]
50-
51-
return quote
52-
$(esc(f_expr))
53-
$(esc(Expr(:call, Expr(:curly, :MVector, length(rng)), Expr(:tuple, exprs...))))
54-
end
55-
elseif isa(ex, Expr) && ex.head == :typed_comprehension
56-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) !! ex.args[2].head != :generator
57-
error("Expected generator in typed comprehension, e.g. Float64[f(i) for i = 1:3]")
58-
end
59-
T = ex.args[1]
60-
ex = ex.args[2]
61-
if length(ex.args) != 2
62-
error("Use a one-dimensional comprehension for @MVector")
63-
end
64-
65-
rng = Core.eval(__module__, ex.args[2].args[2])
66-
f = gensym()
67-
f_expr = :($f = ($(ex.args[2].args[1]) -> $(ex.args[1])))
68-
exprs = [:($f($j)) for j in rng]
69-
70-
return quote
71-
$(esc(f_expr))
72-
$(esc(Expr(:call, Expr(:curly, :MVector, length(rng), T), Expr(:tuple, exprs...))))
73-
end
74-
elseif isa(ex, Expr) && ex.head == :call
75-
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
76-
if length(ex.args) == 2
77-
return quote
78-
$(esc(ex.args[1]))(MVector{$(esc(ex.args[2]))})
79-
end
80-
elseif length(ex.args) == 3
81-
return quote
82-
$(esc(ex.args[1]))(MVector{$(esc(ex.args[3])), $(esc(ex.args[2]))})
83-
end
84-
else
85-
error("@MVector expected a 1-dimensional array expression")
86-
end
87-
elseif ex.args[1] == :fill
88-
if length(ex.args) == 3
89-
return quote
90-
$(esc(ex.args[1]))($(esc(ex.args[2])), MVector{$(esc(ex.args[3]))})
91-
end
92-
else
93-
error("@MVector expected a 1-dimensional array expression")
94-
end
95-
else
96-
error("@MVector only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
97-
end
98-
else
99-
error("Use @MVector [a,b,c] or @MVector([a,b,c])")
100-
end
40+
esc(static_vector_gen(MVector, ex, __module__))
10141
end
10242

10343
# Named field access for the first four elements, using the conventional field

0 commit comments

Comments
 (0)