Skip to content

Commit 1fced24

Browse files
Merge branch 'JuliaSymbolics:master' into adjoint_symbolics
2 parents ba291c6 + aab293a commit 1fced24

File tree

5 files changed

+43
-26
lines changed

5 files changed

+43
-26
lines changed

Project.toml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "3.2.0"
4+
version = "3.5.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -26,15 +26,21 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2626
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2727
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
2828

29+
[weakdeps]
30+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
31+
32+
[extensions]
33+
SymbolicUtilsLabelledArraysExt = "LabelledArrays"
34+
2935
[compat]
3036
AbstractTrees = "0.4"
3137
Bijections = "0.1.2"
3238
ChainRulesCore = "1"
3339
Combinatorics = "1.0"
34-
ConstructionBase = "1.1"
40+
ConstructionBase = "1.5.7"
3541
DataStructures = "0.18"
3642
DocStringExtensions = "0.8, 0.9"
37-
DynamicPolynomials = "0.5"
43+
DynamicPolynomials = "0.5, 0.6"
3844
IfElse = "0.1"
3945
LabelledArrays = "1.5"
4046
MultivariatePolynomials = "0.5"
@@ -51,6 +57,7 @@ julia = "1.3"
5157
[extras]
5258
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5359
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
60+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
5461
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5562
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
5663
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -59,4 +66,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5966
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6067

6168
[targets]
62-
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]
69+
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]

ext/SymbolicUtilsLabelledArraysExt.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module SymbolicUtilsLabelledArraysExt
2+
3+
using LabelledArrays
4+
using LabelledArrays.StaticArrays
5+
using SymbolicUtils
6+
7+
@inline function SymbolicUtils.Code.create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
8+
a = SymbolicUtils.Code.create_array(SArray, T, nd, d, elems...)
9+
if nfields(dims) === ndims(A)
10+
similar_type(A, eltype(a), Size(dims))(a)
11+
else
12+
a
13+
end
14+
end
15+
16+
@inline function SymbolicUtils.Code.create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
17+
data = SymbolicUtils.Code.create_array(Array, T, nd, d, elems...)
18+
if nfields(dims) === ndims(A)
19+
LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data)
20+
else
21+
data
22+
end
23+
end
24+
25+
end

src/code.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Code
22

3-
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
3+
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
44

55
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
66
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
@@ -578,25 +578,6 @@ end
578578
MArray{Tuple{dims...}, T}(elems...)
579579
end
580580

581-
## LabelledArrays
582-
@inline function create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
583-
a = create_array(SArray, T, nd, d, elems...)
584-
if nfields(dims) === ndims(A)
585-
similar_type(A, eltype(a), Size(dims))(a)
586-
else
587-
a
588-
end
589-
end
590-
591-
@inline function create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
592-
data = create_array(Array, T, nd, d, elems...)
593-
if nfields(dims) === ndims(A)
594-
LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data)
595-
else
596-
data
597-
end
598-
end
599-
600581
## We use a separate type for Sparse Arrays to sidestep the need for
601582
## iszero to be defined on the expression type
602583
@matchable struct MakeSparseArray{S<:AbstractSparseArray}

src/types.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ isdiv(x) = isa_SymType(Val(:Div), x)
225225

226226
Base.isequal(::Symbolic, x) = false
227227
Base.isequal(x, ::Symbolic) = false
228+
Base.isequal(::Symbolic, ::Missing) = false
229+
Base.isequal(::Missing, ::Symbolic) = false
228230
Base.isequal(::Symbolic, ::Symbolic) = false
229231
coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
230232
function _allarequal(xs, ys)::Bool
@@ -996,10 +998,9 @@ variable. So, `h(1, g)` will fail and `h(1, f)` will work.
996998
"""
997999
macro syms(xs...)
9981000
defs = map(xs) do x
999-
n, t = _name_type(x)
1000-
T = esc(t)
10011001
nt = _name_type(x)
10021002
n, t = nt.name, nt.type
1003+
T = esc(t)
10031004
:($(esc(n)) = Sym{$T}($(Expr(:quote, n))))
10041005
end
10051006
Expr(:block, defs...,

test/basics.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ end
312312
@syms a b c
313313
@test isequal(a + b, a + b + 0.01 - 0.01)
314314
@test isequal(a + NaN, a + NaN)
315+
316+
@test !isequal(a, missing)
317+
@test !isequal(missing, b)
315318
end
316319

317320
@testset "subtyping" begin

0 commit comments

Comments
 (0)