Skip to content

Commit aece705

Browse files
committed
make LabelledArrays an extension
1 parent ea34caf commit aece705

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2626
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2727
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
2828

29+
[weakdeps]
30+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
31+
32+
[extensions]
33+
SymbolicUtilsLabelledArraysExt = "LabelledArrays"
34+
2935
[compat]
3036
AbstractTrees = "0.4"
3137
Bijections = "0.1.2"
@@ -51,6 +57,7 @@ julia = "1.3"
5157
[extras]
5258
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5359
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
60+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
5461
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5562
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
5663
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -59,4 +66,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5966
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6067

6168
[targets]
62-
test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]
69+
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]

ext/SymbolicUtilsLabelledArraysExt.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module SymbolicUtilsLabelledArraysExt
2+
3+
using LabelledArrays
4+
using LabelledArrays.StaticArrays
5+
using SymbolicUtils
6+
7+
@inline function SymbolicUtils.Code.create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
8+
a = SymbolicUtils.Code.create_array(SArray, T, nd, d, elems...)
9+
if nfields(dims) === ndims(A)
10+
similar_type(A, eltype(a), Size(dims))(a)
11+
else
12+
a
13+
end
14+
end
15+
16+
@inline function SymbolicUtils.Code.create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
17+
data = SymbolicUtils.Code.create_array(Array, T, nd, d, elems...)
18+
if nfields(dims) === ndims(A)
19+
LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data)
20+
else
21+
data
22+
end
23+
end
24+
25+
end

src/code.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Code
22

3-
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
3+
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
44

55
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
66
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
@@ -578,25 +578,6 @@ end
578578
MArray{Tuple{dims...}, T}(elems...)
579579
end
580580

581-
## LabelledArrays
582-
@inline function create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
583-
a = create_array(SArray, T, nd, d, elems...)
584-
if nfields(dims) === ndims(A)
585-
similar_type(A, eltype(a), Size(dims))(a)
586-
else
587-
a
588-
end
589-
end
590-
591-
@inline function create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims}
592-
data = create_array(Array, T, nd, d, elems...)
593-
if nfields(dims) === ndims(A)
594-
LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data)
595-
else
596-
data
597-
end
598-
end
599-
600581
## We use a separate type for Sparse Arrays to sidestep the need for
601582
## iszero to be defined on the expression type
602583
@matchable struct MakeSparseArray{S<:AbstractSparseArray}

0 commit comments

Comments
 (0)