Skip to content

Commit 17e8c37

Browse files
authored
nits from JET (#1095)
* nits from JET * add tests for non-`:generator` comprehension
1 parent 5f7debb commit 17e8c37

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

src/SArray.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,13 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
159159
args = parse_cat_ast(ex)
160160
return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
161161
elseif head === :comprehension
162-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
162+
if length(ex.args) != 1
163163
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
164164
end
165165
ex = ex.args[1]
166+
if !isa(ex, Expr) || (ex::Expr).head != :generator
167+
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
168+
end
166169
n_rng = length(ex.args) - 1
167170
rng_args = (ex.args[i+1].args[1] for i = 1:n_rng)
168171
rngs = Any[Core.eval(mod, ex.args[i+1].args[2]) for i = 1:n_rng]
@@ -174,11 +177,14 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
174177
end
175178
end
176179
elseif head === :typed_comprehension
177-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
180+
if length(ex.args) != 2
178181
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
179182
end
180183
T = esc(ex.args[1])
181184
ex = ex.args[2]
185+
if !isa(ex, Expr) || (ex::Expr).head != :generator
186+
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
187+
end
182188
n_rng = length(ex.args) - 1
183189
rng_args = (ex.args[i+1].args[1] for i = 1:n_rng)
184190
rngs = Any[Core.eval(mod, ex.args[i+1].args[2]) for i = 1:n_rng]

src/SVector.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
3939
len = check_vector_length(size(args))
4040
return :($SV{$len}($tuple($(escall(args)...))))
4141
elseif head === :comprehension
42-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
42+
if length(ex.args) != 1
4343
error("Expected generator in comprehension, e.g. [f(i) for i = 1:3]")
4444
end
4545
ex = ex.args[1]
46+
if !isa(ex, Expr) || (ex::Expr).head != :generator
47+
error("Expected generator in comprehension, e.g. [f(i) for i = 1:3]")
48+
end
4649
if length(ex.args) != 2
4750
error("Use a one-dimensional comprehension for @$SV")
4851
end
@@ -55,11 +58,14 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
5558
end
5659
end
5760
elseif head === :typed_comprehension
58-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
61+
if length(ex.args) != 2
5962
error("Expected generator in typed comprehension, e.g. Float64[f(i) for i = 1:3]")
6063
end
6164
T = esc(ex.args[1])
6265
ex = ex.args[2]
66+
if !isa(ex, Expr) || (ex::Expr).head != :generator
67+
error("Expected generator in typed comprehension, e.g. Float64[f(i) for i = 1:3]")
68+
end
6369
if length(ex.args) != 2
6470
error("Use a one-dimensional comprehension for @$SV")
6571
end

src/matrix_multiply.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria
128128
end
129129

130130
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
131-
S = Size(sa[1], sb[2])
132131
# Heuristic choice for amount of codegen
133132
a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 4 : 1
134133
b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 4 : 1

test/SArray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@
9191
test_expand_error(:(@SArray fill()))
9292
test_expand_error(:(@SArray [1; 2; 3; 4]...))
9393

94+
# (typed-)comprehension LoadError for `ex.args[1].head != :generator`
95+
test_expand_error(:(@SArray [i+j for i in 1:2 for j in 1:2]))
96+
test_expand_error(:(@SArray Int[i+j for i in 1:2 for j in 1:2]))
97+
9498
@test ((@SArray fill(1))::SArray{Tuple{},Int}).data === (1,)
9599
@test ((@SArray ones())::SArray{Tuple{},Float64}).data === (1.,)
96100

0 commit comments

Comments
 (0)