Skip to content

Commit db1eb6b

Browse files
authored
Define getindex on rules (#57)
We define `iterate` on them for convenience, and `getindex` provides a similar convenience for cases where you're not sure whether the rules resulting from a call to `rrule`/`frule` will be a `Tuple`. So instead of writing `partials isa Tuple ? partials[i] : partials`, you can now just write `partials[i]`.
1 parent 314b08a commit db1eb6b

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/rules.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ abstract type AbstractRule end
5454
Base.iterate(rule::AbstractRule) = (rule, nothing)
5555
Base.iterate(::AbstractRule, ::Any) = nothing
5656

57+
# This ensures we don't need to check whether the result of `rrule`/`frule` is a tuple
58+
# in order to get the `i`th rule (assuming it's 1)
59+
Base.getindex(rule::AbstractRule, i::Integer) = i == 1 ? rule : throw(BoundsError())
60+
5761
"""
5862
accumulate(Δ, rule::AbstractRule, args...)
5963

test/rules.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@ cool(x) = x + 1
1212
@test rrx == 2
1313
@test rr(1) == 1
1414
end
15-
@testset "iterating rules" begin
15+
@testset "iterating and indexing rules" begin
1616
_, rule = frule(+, 1)
1717
i = 0
1818
for r in rule
1919
@test r === rule
2020
i += 1
2121
end
2222
@test i == 1 # rules only iterate once, yielding themselves
23+
@test rule[1] == rule
24+
@test_throws BoundsError rule[2]
2325
end
2426
@testset "helper functions" begin
2527
# Hits fallback, since we can't update `Diagonal`s in place

0 commit comments

Comments
 (0)