Skip to content

Commit 4db80f1

Browse files
authored
Small tweaks to prevent == invalidation (#950)
These improve inferrence to guard against very minor invalidations from packages that extend `==`.
1 parent f26c9ac commit 4db80f1

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

src/SMatrix.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,37 +89,45 @@ macro SMatrix(ex)
8989
s2 = length(ex.args) - 1
9090
return esc(Expr(:call, Expr(:curly, :SMatrix, s1, s2, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
9191
elseif ex.head == :vcat
92-
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
92+
if isa(ex.args[1], Expr) && (ex.args[1]::Expr).head == :row # n x m
9393
# Validate
9494
s1 = length(ex.args)
95-
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
95+
s2s = let s1=s1, ex=ex
96+
map(i -> ((isa(ex.args[i], Expr) && (ex.args[i]::Expr).head == :row) ? length((ex.args[i]::Expr).args) : 1), 1:s1)
97+
end
9698
s2 = minimum(s2s)
9799
if maximum(s2s) != s2
98100
throw(ArgumentError("Rows must be of matching lengths"))
99101
end
100102

101-
exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
103+
exprs = let s1=s1, s2=s2, ex=ex
104+
[ex.args[i].args[j] for i = 1:s1, j = 1:s2]
105+
end
102106
return esc(Expr(:call, SMatrix{s1, s2}, Expr(:tuple, exprs...)))
103107
else # n x 1
104108
return esc(Expr(:call, SMatrix{length(ex.args), 1}, Expr(:tuple, ex.args...)))
105109
end
106110
elseif ex.head == :typed_vcat
107-
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
111+
if isa(ex.args[2], Expr) && (ex.args[2]::Expr).head == :row # typed, n x m
108112
# Validate
109113
s1 = length(ex.args) - 1
110-
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
114+
s2s = let s1=s1, ex=ex
115+
map(i -> ((isa(ex.args[i+1], Expr) && (ex.args[i+1]::Expr).head == :row) ? length((ex.args[i+1]::Expr).args) : 1), 1:s1)
116+
end
111117
s2 = minimum(s2s)
112118
if maximum(s2s) != s2
113119
throw(ArgumentError("Rows must be of matching lengths"))
114120
end
115121

116-
exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
122+
exprs = let s1=s1, s2=s2, ex=ex
123+
[ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
124+
end
117125
return esc(Expr(:call, Expr(:curly, :SMatrix,s1, s2, ex.args[1]), Expr(:tuple, exprs...)))
118126
else # typed, n x 1
119127
return esc(Expr(:call, Expr(:curly, :SMatrix, length(ex.args)-1, 1, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
120128
end
121129
elseif isa(ex, Expr) && ex.head == :comprehension
122-
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
130+
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || (ex.args[1]::Expr).head != :generator
123131
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
124132
end
125133
ex = ex.args[1]
@@ -138,7 +146,7 @@ macro SMatrix(ex)
138146
$(esc(Expr(:call, Expr(:curly, :SMatrix, length(rng1), length(rng2)), Expr(:tuple, exprs...))))
139147
end
140148
elseif isa(ex, Expr) && ex.head == :typed_comprehension
141-
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
149+
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || (ex.args[2]::Expr).head != :generator
142150
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
143151
end
144152
T = ex.args[1]
@@ -158,7 +166,7 @@ macro SMatrix(ex)
158166
$(esc(Expr(:call, Expr(:curly, :SMatrix, length(rng1), length(rng2), T), Expr(:tuple, exprs...))))
159167
end
160168
elseif isa(ex, Expr) && ex.head == :call
161-
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
169+
if ex.args[1] === :zeros || ex.args[1] === :ones || ex.args[1] === :rand || ex.args[1] === :randn || ex.args[1] === :randexp
162170
if length(ex.args) == 3
163171
return quote
164172
$(ex.args[1])(SMatrix{$(esc(ex.args[2])),$(esc(ex.args[3]))})
@@ -170,7 +178,7 @@ macro SMatrix(ex)
170178
else
171179
error("@SMatrix expected a 2-dimensional array expression")
172180
end
173-
elseif ex.args[1] == :fill
181+
elseif ex.args[1] === :fill
174182
if length(ex.args) == 4
175183
return quote
176184
$(esc(ex.args[1]))($(esc(ex.args[2])), SMatrix{$(esc(ex.args[3])), $(esc(ex.args[4]))})

src/matrix_multiply_add.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,11 @@ end
390390

391391
function combine_products(expr_list)
392392
filtered = filter(expr_list) do expr
393-
if expr.head != :call || expr.args[1] != :*
393+
if expr.head != :call || expr.args[1] !== :*
394394
error("expected call to *")
395395
end
396396
for arg in expr.args[2:end]
397-
if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero
397+
if isa(arg, Expr) && arg.head == :call && arg.args[1] === :zero
398398
return false
399399
end
400400
end
@@ -404,7 +404,7 @@ function combine_products(expr_list)
404404
return :(zero(T))
405405
else
406406
return reduce(filtered) do ex1, ex2
407-
if ex2.head != :call || ex2.args[1] != :*
407+
if ex2.head != :call || ex2.args[1] !== :*
408408
error("expected call to *")
409409
end
410410

0 commit comments

Comments
 (0)