-
Notifications
You must be signed in to change notification settings - Fork 6
WIP: Add ChainRules #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 11 commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
d2c7537
Add and test rrule for Woodbury
c586180
Add manifest for Zygote branch
10bc061
Tidy up tests. White space clear
13054b4
Remove ZYgote from deps
d149676
Switch to Zygote master now it supports cr1 until tagged
49f1ee6
Extend Distirbutions for compat reasons
89947f6
Update manifest to use zygote#master until tagged
e81dff8
Add Zygote again to resolve CI issues
5dce0e9
Generate manifest on 1.6
2e31400
Readd diagonal tests after rebase removal
18a4b99
Rebuild manifest on exact CI version...
5f91685
Compat Bounds
bbfe9b2
MR comments1
c79f177
Update src/chainrules.jl
AlexRobson 49f6795
Update src/chainrules.jl
AlexRobson 64e7780
Update test/chainrules.jl
AlexRobson b3ff0c2
Fix up pullback
ca45e00
passing Tangent{Woodbury} and type inference fixed
f6db1da
Use Functor for ProjectTo. Update test
ed3ee75
Remove irrelevant Diagonal testse
0edc740
Merge branch 'ar/chainrules' of https://github.com/invenia/PDMatsExtr…
b6e94f8
Remove merge mess
dde2b46
Rework chainrules tests. Add constructor rrule
73f6bbc
Remove ChainRules from test deps
251d183
White space deletion
31cf7e1
Remove Zygote from deps
a9900c8
Remove ChainRules from extras
f75ed2e
Remove ChainRules from compat
955b2c1
Refactor long line. Add comment
2e7babe
Add projections and thunks into _times_pullback
e1bd5f9
Tangent{T} T<:Woodbury -> just Tangent
9431ead
Add in extra test that test a Matrix input to tangent
8ba6f0d
Add Matrix tangent tests into a seperate test set
ff8d903
Update src/chainrules.jl
AlexRobson eb23510
Update src/chainrules.jl
AlexRobson 9535406
Update test/chainrules.jl
AlexRobson a9cdd38
Add (and remove) comments
eb194c5
Readd CHainRules as a test dep
c7e98c1
ChainRules added explicitely to runtests
8ce1498
Add comment vefore test set
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
# This file is machine-generated - editing it directly is not advised | ||
|
||
[[AbstractFFTs]] | ||
deps = ["LinearAlgebra"] | ||
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" | ||
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" | ||
version = "1.0.1" | ||
|
||
[[ArgTools]] | ||
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" | ||
|
||
[[Artifacts]] | ||
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | ||
|
||
[[Base64]] | ||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" | ||
|
||
[[ChainRules]] | ||
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"] | ||
git-tree-sha1 = "346588c81effb94da6a30c1617e56af6a878e4d6" | ||
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" | ||
version = "1.0.1" | ||
|
||
[[ChainRulesCore]] | ||
deps = ["Compat", "LinearAlgebra", "SparseArrays"] | ||
git-tree-sha1 = "ad613c934ec3a3aa0ff19b91f15a16d56ed404b5" | ||
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
version = "1.0.2" | ||
|
||
[[CommonSubexpressions]] | ||
deps = ["MacroTools", "Test"] | ||
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" | ||
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" | ||
version = "0.3.0" | ||
|
||
[[Compat]] | ||
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] | ||
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" | ||
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
version = "3.31.0" | ||
|
||
[[CompilerSupportLibraries_jll]] | ||
deps = ["Artifacts", "Libdl"] | ||
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" | ||
|
||
[[Dates]] | ||
deps = ["Printf"] | ||
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" | ||
|
||
[[DelimitedFiles]] | ||
deps = ["Mmap"] | ||
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" | ||
|
||
[[DiffResults]] | ||
deps = ["StaticArrays"] | ||
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" | ||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" | ||
version = "1.0.3" | ||
|
||
[[DiffRules]] | ||
deps = ["NaNMath", "Random", "SpecialFunctions"] | ||
git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" | ||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" | ||
version = "1.0.2" | ||
|
||
[[Distributed]] | ||
deps = ["Random", "Serialization", "Sockets"] | ||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" | ||
|
||
[[DocStringExtensions]] | ||
deps = ["LibGit2"] | ||
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" | ||
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
version = "0.8.5" | ||
|
||
[[Downloads]] | ||
deps = ["ArgTools", "LibCURL", "NetworkOptions"] | ||
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" | ||
|
||
[[FillArrays]] | ||
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] | ||
git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7" | ||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
version = "0.12.1" | ||
|
||
[[ForwardDiff]] | ||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] | ||
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6" | ||
uuid = "f6369f11-7733-5829-9624-2563aa707210" | ||
version = "0.10.18" | ||
|
||
[[IRTools]] | ||
deps = ["InteractiveUtils", "MacroTools", "Test"] | ||
git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c" | ||
uuid = "7869d1d1-7146-5819-86e3-90919afe41df" | ||
version = "0.4.3" | ||
|
||
[[InteractiveUtils]] | ||
deps = ["Markdown"] | ||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
|
||
[[JLLWrappers]] | ||
deps = ["Preferences"] | ||
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" | ||
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" | ||
version = "1.3.0" | ||
|
||
[[LibCURL]] | ||
deps = ["LibCURL_jll", "MozillaCACerts_jll"] | ||
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" | ||
|
||
[[LibCURL_jll]] | ||
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] | ||
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" | ||
|
||
[[LibGit2]] | ||
deps = ["Base64", "NetworkOptions", "Printf", "SHA"] | ||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" | ||
|
||
[[LibSSH2_jll]] | ||
deps = ["Artifacts", "Libdl", "MbedTLS_jll"] | ||
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" | ||
|
||
[[Libdl]] | ||
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | ||
|
||
[[LinearAlgebra]] | ||
deps = ["Libdl"] | ||
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
||
[[LogExpFunctions]] | ||
deps = ["DocStringExtensions", "LinearAlgebra"] | ||
git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070" | ||
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" | ||
version = "0.2.5" | ||
|
||
[[Logging]] | ||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" | ||
|
||
[[MacroTools]] | ||
deps = ["Markdown", "Random"] | ||
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" | ||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" | ||
version = "0.5.6" | ||
|
||
[[Markdown]] | ||
deps = ["Base64"] | ||
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" | ||
|
||
[[MbedTLS_jll]] | ||
deps = ["Artifacts", "Libdl"] | ||
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" | ||
|
||
[[Mmap]] | ||
uuid = "a63ad114-7e13-5084-954f-fe012c677804" | ||
|
||
[[MozillaCACerts_jll]] | ||
uuid = "14a3606d-f60d-562e-9121-12d972cd8159" | ||
|
||
[[NaNMath]] | ||
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" | ||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" | ||
version = "0.3.5" | ||
|
||
[[NetworkOptions]] | ||
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" | ||
|
||
[[OpenSpecFun_jll]] | ||
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] | ||
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" | ||
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" | ||
version = "0.5.5+0" | ||
|
||
[[PDMats]] | ||
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] | ||
git-tree-sha1 = "4dd403333bcf0909341cfe57ec115152f937d7d8" | ||
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
version = "0.11.1" | ||
|
||
[[Pkg]] | ||
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] | ||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
|
||
[[Preferences]] | ||
deps = ["TOML"] | ||
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" | ||
uuid = "21216c6a-2e73-6563-6e65-726566657250" | ||
version = "1.2.2" | ||
|
||
[[Printf]] | ||
deps = ["Unicode"] | ||
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
|
||
[[REPL]] | ||
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] | ||
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" | ||
|
||
[[Random]] | ||
deps = ["Serialization"] | ||
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
|
||
[[Requires]] | ||
deps = ["UUIDs"] | ||
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" | ||
uuid = "ae029012-a4dd-5104-9daa-d747884805df" | ||
version = "1.1.3" | ||
|
||
[[SHA]] | ||
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" | ||
|
||
[[Serialization]] | ||
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" | ||
|
||
[[SharedArrays]] | ||
deps = ["Distributed", "Mmap", "Random", "Serialization"] | ||
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" | ||
|
||
[[Sockets]] | ||
uuid = "6462fe0b-24de-5631-8697-dd941f90decc" | ||
|
||
[[SparseArrays]] | ||
deps = ["LinearAlgebra", "Random"] | ||
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
|
||
[[SpecialFunctions]] | ||
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] | ||
git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8" | ||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b" | ||
version = "1.6.0" | ||
|
||
[[StaticArrays]] | ||
deps = ["LinearAlgebra", "Random", "Statistics"] | ||
git-tree-sha1 = "5b2f81eeb66bcfe379947c500aae773c85c31033" | ||
uuid = "90137ffa-7385-5640-81b9-e52037218182" | ||
version = "1.2.8" | ||
|
||
[[Statistics]] | ||
deps = ["LinearAlgebra", "SparseArrays"] | ||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[[SuiteSparse]] | ||
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] | ||
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" | ||
|
||
[[TOML]] | ||
deps = ["Dates"] | ||
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" | ||
|
||
[[Tar]] | ||
deps = ["ArgTools", "SHA"] | ||
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" | ||
|
||
[[Test]] | ||
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] | ||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[[UUIDs]] | ||
deps = ["Random", "SHA"] | ||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" | ||
|
||
[[Unicode]] | ||
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" | ||
|
||
[[Zlib_jll]] | ||
deps = ["Libdl"] | ||
uuid = "83775a58-1f1d-513f-b197-d71354ab007a" | ||
|
||
[[Zygote]] | ||
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] | ||
git-tree-sha1 = "1b6f1725e0e5b70885845329d0a599edf34dbae6" | ||
repo-rev = "master" | ||
repo-url = "https://github.com/FluxML/Zygote.jl.git" | ||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
version = "0.6.18-DEV" | ||
|
||
[[ZygoteRules]] | ||
deps = ["MacroTools"] | ||
git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" | ||
uuid = "700de1a5-db45-46bc-99cf-38207098b444" | ||
version = "0.2.1" | ||
|
||
[[nghttp2_jll]] | ||
deps = ["Artifacts", "Libdl"] | ||
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" | ||
|
||
[[p7zip_jll]] | ||
deps = ["Artifacts", "Libdl"] | ||
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
@non_differentiable validate_woodbury_arguments(A, D, S) | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(*), A::Real, B::WoodburyPDMat{T, TA, TD, TS} | ||
) where {T, TA, TD, TS} | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
project_A = ProjectTo(A) | ||
project_B = ProjectTo(B) | ||
function times_pullback(ȳ::AbstractMatrix) | ||
Ȳ = unthunk(ȳ) | ||
Ā = dot(Ȳ, B) | ||
B̄ = A' * Ȳ | ||
return ( | ||
NoTangent(), | ||
@thunk(project_A(Ā')), | ||
@thunk(project_B(B̄)), | ||
) | ||
end | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function times_pullback(ȳ::Tangent{<:WoodburyPDMat}) | ||
Ȳ = unthunk(ȳ) | ||
mzgubic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Ā = dot(Ȳ.A * Ȳ.D * Ȳ.A' + Ȳ.S, B) | ||
B̄ = Ȳ.A * (A' * Ȳ.D) * Ȳ.A' + A' * Ȳ.S | ||
return ( | ||
NoTangent(), | ||
@thunk(project_A(Ā')), | ||
@thunk(project_B(B̄)), | ||
) | ||
end | ||
return A * B, times_pullback | ||
end | ||
|
||
function ChainRulesCore.ProjectTo(W::WoodburyPDMat) | ||
function dW(W̄) | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Ā(W̄) = ProjectTo(W.A)(collect((W.D * W.A' * W̄' + W.D * W.A' * W̄)')) | ||
D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A) | ||
S̄(W̄) = ProjectTo(W.S)(W̄) | ||
return Tangent{typeof(W)}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄)) | ||
end | ||
return dW | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.