Skip to content

Commit bea1b81

Browse files
committed
add frules for getfield
1 parent b171c09 commit bea1b81

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Int rather than Int64/Integer is intentional
2-
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
3-
return x.i, ẋ.i
2+
function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol})
3+
return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
4+
end
5+
6+
function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds)
7+
return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
48
end
59

610
"for a given tuple type, returns a Val{N} where N is the length of the tuple"
@@ -21,7 +25,6 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu
2125
dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T))
2226
return (NoTangent(), Tangent{T}(dx...), NoTangent())
2327
end
24-
return x[i], getindex_back_2
2528
end
2629

2730
# Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector),

test/rulesets/Base/indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
@testset "getfield" begin
2+
struct Foo
3+
x::Float64
4+
y::Float64
5+
end
6+
test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false)
7+
8+
test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false)
9+
test_frule(getfield, (; a=1.5, b=2.5), 2)
10+
11+
test_frule(getfield, (1.5, 2.5), 2)
12+
test_frule(getfield, (1.5, 2.5), 2, true)
13+
end
14+
115
@testset "getindex" begin
216
@testset "getindex(::Tuple, ...)" begin
317
x = (1.2, 3.4, 5.6)

0 commit comments

Comments
 (0)