Skip to content

WIP: Add ChainRules #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d2c7537
Add and test rrule for Woodbury
Jul 27, 2021
c586180
Add manifest for Zygote branch
Jul 27, 2021
10bc061
Tidy up tests. White space clear
Jul 27, 2021
13054b4
Remove ZYgote from deps
Jul 27, 2021
d149676
Switch to Zygote master now it supports cr1 until tagged
Jul 27, 2021
49f1ee6
Extend Distirbutions for compat reasons
Jul 27, 2021
89947f6
Update manifest to use zygote#master until tagged
Jul 27, 2021
e81dff8
Add Zygote again to resolve CI issues
Jul 27, 2021
5dce0e9
Generate manifest on 1.6
Jul 27, 2021
2e31400
Readd diagonal tests after rebase removal
Jul 28, 2021
18a4b99
Rebuild manifest on exact CI version...
Jul 28, 2021
5f91685
Compat Bounds
Jul 28, 2021
bbfe9b2
MR comments1
Jul 28, 2021
c79f177
Update src/chainrules.jl
AlexRobson Jul 28, 2021
49f6795
Update src/chainrules.jl
AlexRobson Jul 28, 2021
64e7780
Update test/chainrules.jl
AlexRobson Jul 28, 2021
b3ff0c2
Fix up pullback
Jul 28, 2021
ca45e00
passing Tangent{Woodbury} and type inference fixed
Jul 29, 2021
f6db1da
Use Functor for ProjectTo. Update test
Jul 30, 2021
ed3ee75
Remove irrelevant Diagonal testse
Jul 30, 2021
0edc740
Merge branch 'ar/chainrules' of https://github.com/invenia/PDMatsExtr…
Jul 30, 2021
b6e94f8
Remove merge mess
Jul 30, 2021
dde2b46
Rework chainrules tests. Add constructor rrule
Aug 15, 2021
73f6bbc
Remove ChainRules from test deps
Aug 15, 2021
251d183
White space deletion
Aug 15, 2021
31cf7e1
Remove Zygote from deps
Aug 15, 2021
a9900c8
Remove ChainRules from extras
Aug 15, 2021
f75ed2e
Remove ChainRules from compat
Aug 15, 2021
955b2c1
Refactor long line. Add comment
Aug 15, 2021
2e7babe
Add projections and thunks into _times_pullback
Aug 15, 2021
e1bd5f9
Tangent{T} T<:Woodbury -> just Tangent
Aug 15, 2021
9431ead
Add in extra test that test a Matrix input to tangent
Aug 15, 2021
8ba6f0d
Add Matrix tangent tests into a seperate test set
Aug 15, 2021
ff8d903
Update src/chainrules.jl
AlexRobson Aug 16, 2021
eb23510
Update src/chainrules.jl
AlexRobson Aug 16, 2021
9535406
Update test/chainrules.jl
AlexRobson Aug 16, 2021
a9cdd38
Add (and remove) comments
Aug 16, 2021
eb194c5
Readd CHainRules as a test dep
Aug 16, 2021
c7e98c1
ChainRules added explicitely to runtests
Aug 16, 2021
8ce1498
Add comment vefore test set
Aug 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 288 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# This file is machine-generated - editing it directly is not advised

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "346588c81effb94da6a30c1617e56af6a878e4d6"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.0.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "ad613c934ec3a3aa0ff19b91f15a16d56ed404b5"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.0.2"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.31.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.0.2"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.5"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.1"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.18"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.3"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"

[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"

[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"

[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"]
git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.5"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.6"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"

[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"

[[PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "4dd403333bcf0909341cfe57ec115152f937d7d8"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.1"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.2"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.6.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "5b2f81eeb66bcfe379947c500aae773c85c31033"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.8"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"

[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "1b6f1725e0e5b70885845329d0a599edf34dbae6"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.18-DEV"

[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.1"

[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "0.9.17, 0.10"
Distributions = "0.23, 0.24"
ChainRulesCore = "1"
Distributions = "0.23, 0.24, 0.25"
FiniteDifferences = "0.11, 0.12"
PDMats = "0.9, 0.10, 0.11"
Zygote = "0.5.5"
Zygote = "0.6"
julia = "1"

[extras]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -27,4 +30,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
test = ["ChainRulesTestUtils", "ChainRules", "Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
1 change: 1 addition & 0 deletions src/PDMatsExtras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ export submat
include("psd_mat.jl")
include("woodbury_pd_mat.jl")
include("utils.jl")
include("chainrules.jl")

end
40 changes: 40 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@non_differentiable validate_woodbury_arguments(A, D, S)

function ChainRulesCore.rrule(
::typeof(*), A::Real, B::WoodburyPDMat{T, TA, TD, TS}
) where {T, TA, TD, TS}
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ::AbstractMatrix)
Ȳ = unthunk(ȳ)
Ā = dot(Ȳ, B)
B̄ = A' * Ȳ
return (
NoTangent(),
@thunk(project_A(Ā')),
@thunk(project_B(B̄)),
)
end

function times_pullback(ȳ::Tangent{<:WoodburyPDMat})
Ȳ = unthunk(ȳ)
Ā = dot(Ȳ.A * Ȳ.D * Ȳ.A' + Ȳ.S, B)
B̄ = Ȳ.A * (A' * Ȳ.D) * Ȳ.A' + A' * Ȳ.S
return (
NoTangent(),
@thunk(project_A(Ā')),
@thunk(project_B(B̄)),
)
end
return A * B, times_pullback
end

function ChainRulesCore.ProjectTo(W::WoodburyPDMat)
function dW(W̄)
Ā(W̄) = ProjectTo(W.A)(collect((W.D * W.A' * W̄' + W.D * W.A' * W̄)'))
D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A)
S̄(W̄) = ProjectTo(W.S)(W̄)
return Tangent{typeof(W)}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄))
end
return dW
end
6 changes: 3 additions & 3 deletions src/woodbury_pd_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ function validate_woodbury_arguments(A, D, S)
end
end

@non_differentiable validate_woodbury_arguments(A, D, S)

function LinearAlgebra.logdet(W::WoodburyPDMat)
C_S = cholesky(W.S)
B = C_S.U' \ (W.A * cholesky(W.D).U')
Expand All @@ -90,4 +88,6 @@ end
# NOTE: the parameterisation to scale up the Woodbury matrix is not unique. Here we
# implement one way to scale it.
*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
*(c::Real, a::WoodburyPDMat) = a * c
*(c::Real, a::WoodburyPDMat) = a * c
*(c::Diagonal{T}, a::WoodburyPDMat) where {T<:Real} = c * Matrix(a)
*(a::WoodburyPDMat, c::Diagonal{T}) where {T<:Real} = Matrix(a) * c
Loading