Skip to content

Commit 0e62078

Browse files
authored
Merge pull request #739 from JuliaDiff/ox/getfield
add frules for getfield
2 parents ec9b281 + a76e0ae commit 0e62078

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Int rather than Int64/Integer is intentional
2-
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
3-
return x.i, ẋ.i
2+
function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol})
3+
return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
4+
end
5+
6+
function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds)
7+
return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
48
end
59

610
"for a given tuple type, returns a Val{N} where N is the length of the tuple"

test/rulesets/Base/indexing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
struct FooTwoField
2+
x::Float64
3+
y::Float64
4+
end
5+
6+
7+
@testset "getfield" begin
8+
test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false)
9+
10+
test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false)
11+
test_frule(getfield, (; a=1.5, b=2.5), 2)
12+
13+
test_frule(getfield, (1.5, 2.5), 2)
14+
test_frule(getfield, (1.5, 2.5), 2, true)
15+
end
16+
117
@testset "getindex" begin
218
@testset "getindex(::Tuple, ...)" begin
319
x = (1.2, 3.4, 5.6)

0 commit comments

Comments
 (0)