Skip to content

Commit 8c78436

Browse files
authored
add test for #43 (#48)
1 parent b03feac commit 8c78436

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,10 @@ julia = "1.7"
2323

2424
[extras]
2525
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
26+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2627
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2728
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2829
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2930

3031
[targets]
31-
test = ["Test", "Symbolics", "ForwardDiff", "Random"]
32+
test = ["Test", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics"]

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ using ChainRules
44
using ChainRulesCore
55
using ChainRules: ZeroTangent, NoTangent
66
using Symbolics
7+
using LinearAlgebra
78

89
using Test
910

@@ -195,4 +196,9 @@ end
195196
# Issue #40 - Symbol type parameters not properly quoted
196197
@test Diffractor.∂⃖recurse{1}()(Val{:transformations})[1] === Val{:transformations}()
197198

199+
# PR #43
200+
loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
201+
x = rand(10, 10)
202+
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}}
203+
198204
include("pinn.jl")

0 commit comments

Comments (0)