Skip to content

Commit 5c0b782

Browse files
Merge pull request #456 from AayushSabharwal/as/restructure-adjoint
feat: add adjoint for `ArrayInterface.restructure`
2 parents 8594f42 + c7216c3 commit 5c0b782

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1313
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
1414
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
15+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1516
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1617
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1718
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -24,6 +25,7 @@ ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
2425
ArrayInterfaceCUDAExt = "CUDA"
2526
ArrayInterfaceCUDSSExt = "CUDSS"
2627
ArrayInterfaceChainRulesExt = "ChainRules"
28+
ArrayInterfaceChainRulesCoreExt = "ChainRulesCore"
2729
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
2830
ArrayInterfaceReverseDiffExt = "ReverseDiff"
2931
ArrayInterfaceSparseArraysExt = "SparseArrays"
@@ -37,6 +39,8 @@ BlockBandedMatrices = "0.13"
3739
CUDA = "5"
3840
CUDSS = "0.2, 0.3"
3941
ChainRules = "1"
42+
ChainRulesCore = "1"
43+
ChainRulesTestUtils = "1"
4044
GPUArraysCore = "0.1, 0.2"
4145
LinearAlgebra = "1.10"
4246
ReverseDiff = "1"
@@ -51,6 +55,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
5155
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
5256
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5357
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
58+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5459
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
5560
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
5661
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -66,4 +71,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6671
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
6772

6873
[targets]
69-
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"]
74+
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays", "ChainRulesTestUtils"]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module ArrayInterfaceChainRulesCoreExt
2+
3+
import ArrayInterface
4+
import ChainRulesCore
5+
import ChainRulesCore: unthunk, NoTangent, ZeroTangent, ProjectTo, @thunk
6+
7+
function ChainRulesCore.rrule(::typeof(ArrayInterface.restructure), target, src)
8+
projectT = ProjectTo(target)
9+
function restructure_pullback(dt)
10+
dt = unthunk(dt)
11+
12+
= NoTangent()
13+
= ZeroTangent()
14+
= @thunk(projectT(ArrayInterface.restructure(src, dt)))
15+
16+
f̄, t̄, s̄
17+
end
18+
19+
return ArrayInterface.restructure(target, src), restructure_pullback
20+
end
21+
22+
end

test/chainrules.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
using ArrayInterface, ChainRules, Test
2+
using ComponentArrays, ChainRulesTestUtils, StaticArrays
23

34
x = ChainRules.OneElement(3.0, (3, 3), (1:4, 1:4))
45

56
@test !ArrayInterface.can_setindex(x)
67
@test !ArrayInterface.can_setindex(typeof(x))
8+
9+
arr = ComponentArray(a = 1.0, b = [2.0, 3.0], c = (; a = 4.0, b = 5.0), d = SVector{2}(6.0, 7.0))
10+
b = zeros(length(arr))
11+
12+
ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, arr, b)
13+
ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, b, arr)

0 commit comments

Comments
 (0)