Skip to content

Commit 23269d8

Browse files
feat: add Code.create_array method for TrackedArray in ReverseDiffExt
1 parent 3303acf commit 23269d8

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "3.5.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -28,12 +29,15 @@ Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
2829

2930
[weakdeps]
3031
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
32+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3133

3234
[extensions]
3335
SymbolicUtilsLabelledArraysExt = "LabelledArrays"
36+
SymbolicUtilsReverseDiffExt = "ReverseDiff"
3437

3538
[compat]
3639
AbstractTrees = "0.4"
40+
ArrayInterface = "7.8"
3741
Bijections = "0.1.2"
3842
ChainRulesCore = "1"
3943
Combinatorics = "1.0"
@@ -45,6 +49,7 @@ IfElse = "0.1"
4549
LabelledArrays = "1.5"
4650
MultivariatePolynomials = "0.5"
4751
NaNMath = "0.3, 1"
52+
ReverseDiff = "1"
4853
Setfield = "0.7, 0.8, 1"
4954
SpecialFunctions = "0.10, 1.0, 2"
5055
StaticArrays = "0.12, 1.0"

ext/SymbolicUtilsReverseDiffExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module SymbolicUtilsReverseDiffExt
2+
3+
using ReverseDiff
4+
using SymbolicUtils
5+
6+
@inline function SymbolicUtils.Code.create_array(::Type{<:ReverseDiff.TrackedArray}, T, v1::Val, v2::Val{dims}, elems...) where dims
7+
SymbolicUtils.ArrayInterface.aos_to_soa(SymbolicUtils.Code.create_array(Array, T, v1, v2, elems...))
8+
end
9+
10+
end

src/SymbolicUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ using ConstructionBase
1818
using TermInterface
1919
import TermInterface: iscall, isexpr, head, children,
2020
operation, arguments, metadata, maketerm, sorted_arguments
21+
# For ReverseDiffExt
22+
import ArrayInterface
2123

2224
Base.@deprecate istree iscall
2325
export istree, operation, arguments, sorted_arguments, similarterm, iscall

0 commit comments

Comments
 (0)