Skip to content

Commit b16b285

Browse files
Merge pull request #646 from AayushSabharwal/as/reversediff-ext
feat: add `Code.create_array` method for `TrackedArray` in ReverseDiffExt
2 parents 3303acf + b409ba6 commit b16b285

File tree

5 files changed

+32
-2
lines changed

5 files changed

+32
-2
lines changed

.github/workflows/benchmark_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
- uses: actions/checkout@v2
1717
- uses: julia-actions/setup-julia@v1
1818
with:
19-
version: "1.8"
19+
version: "1"
2020
- uses: julia-actions/cache@v1
2121
- name: Extract Package Name from Project.toml
2222
id: extract-package-name

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "3.5.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -28,12 +29,15 @@ Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
2829

2930
[weakdeps]
3031
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
32+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3133

3234
[extensions]
3335
SymbolicUtilsLabelledArraysExt = "LabelledArrays"
36+
SymbolicUtilsReverseDiffExt = "ReverseDiff"
3437

3538
[compat]
3639
AbstractTrees = "0.4"
40+
ArrayInterface = "7.8"
3741
Bijections = "0.1.2"
3842
ChainRulesCore = "1"
3943
Combinatorics = "1.0"
@@ -45,6 +49,7 @@ IfElse = "0.1"
4549
LabelledArrays = "1.5"
4650
MultivariatePolynomials = "0.5"
4751
NaNMath = "0.3, 1"
52+
ReverseDiff = "1"
4853
Setfield = "0.7, 0.8, 1"
4954
SpecialFunctions = "0.10, 1.0, 2"
5055
StaticArrays = "0.12, 1.0"
@@ -62,8 +67,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6267
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
6368
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6469
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
70+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
6571
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6672
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6773

6874
[targets]
69-
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"]
75+
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "Test", "Zygote"]

ext/SymbolicUtilsReverseDiffExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module SymbolicUtilsReverseDiffExt
2+
3+
using ReverseDiff
4+
using SymbolicUtils
5+
6+
@inline function SymbolicUtils.Code.create_array(::Type{<:ReverseDiff.TrackedArray}, T, v1::Val, v2::Val{dims}, elems...) where dims
7+
SymbolicUtils.ArrayInterface.aos_to_soa(SymbolicUtils.Code.create_array(Array, T, v1, v2, elems...))
8+
end
9+
10+
end

src/SymbolicUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ using ConstructionBase
1818
using TermInterface
1919
import TermInterface: iscall, isexpr, head, children,
2020
operation, arguments, metadata, maketerm, sorted_arguments
21+
# For ReverseDiffExt
22+
import ArrayInterface
2123

2224
Base.@deprecate istree iscall
2325
export istree, operation, arguments, sorted_arguments, similarterm, iscall

test/code.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using SymbolicUtils.Code: LazyState
55
using StaticArrays
66
using LabelledArrays
77
using SparseArrays
8+
using ReverseDiff
89
using LinearAlgebra
910

1011
test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
@@ -158,6 +159,17 @@ nanmath_st.rewrites[:nanmath] = true
158159
@test eval(toexpr(Let([a 1, b 2, arr @SLVector((:a, :b))(@SVector[1,2])],
159160
MakeArray([a+b,a/b], arr)))) === @SLVector((:a, :b))(@SVector [3, 1/2])
160161

162+
trackedarr = eval(toexpr(Let([a ReverseDiff.track(1.0), b 2, arr ReverseDiff.track(ones(2))],
163+
MakeArray([a+b,a/b], arr))))
164+
@test trackedarr isa ReverseDiff.TrackedArray
165+
@test trackedarr == [3, 1/2]
166+
167+
trackedarr = eval(toexpr(Let([a ReverseDiff.track(1.0), b 2, arr ReverseDiff.track(ones(2))],
168+
MakeArray([a b; a+b a/b], arr))))
169+
@test trackedarr isa ReverseDiff.TrackedArray
170+
@test trackedarr == [1 2; 3 1/2]
171+
172+
161173
R1 = eval(toexpr(Let([a 1, b 2, arr @MVector([1,2])],MakeArray([a,b,a+b,a/b], arr))))
162174
@test R1 == (@MVector [1, 2, 3, 1/2]) && R1 isa MVector
163175

0 commit comments

Comments
 (0)