Skip to content

Commit 61bda17

Browse files
Merge pull request #449 from FluxML/cl/gather2
add `gather(src, IJK...)`
2 parents 0b64dc1 + 40b2848 commit 61bda17

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/gather.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ or multiple `dst` columns.
5454
See [`gather!`](@ref) for an in-place version.
5555
5656
# Examples
57+
5758
```jldoctest
5859
julia> NNlib.gather([1,20,300,4000], [2,4,2])
5960
3-element Vector{Int64}:
@@ -83,5 +84,38 @@ function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::A
8384
y = gather!(dst, src, idx)
8485
src_size = size(src)
8586
gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())
86-
y, gather!_pullback
87+
return y, gather!_pullback
88+
end
89+
90+
"""
91+
gather(src, IJK...)
92+
93+
Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and
94+
call `gather` on it: `gather(src, CartesianIndex.(IJK...))`.
95+
96+
# Examples
97+
98+
```jldoctest
99+
julia> src = reshape([1:15;], 3, 5)
100+
3×5 Matrix{Int64}:
101+
1 4 7 10 13
102+
2 5 8 11 14
103+
3 6 9 12 15
104+
105+
julia> NNlib.gather(src, [1, 2], [2, 4])
106+
2-element Vector{Int64}:
107+
4
108+
11
109+
```
110+
"""
111+
function gather(src::AbstractArray{Tsrc, Nsrc},
112+
I::AbstractVector{<:Integer},
113+
J::AbstractVector{<:Integer},
114+
Ks::AbstractVector{<:Integer}...) where {Nsrc, Tsrc}
115+
116+
return gather(src, to_cartesian_index(I, J, Ks...))
87117
end
118+
119+
to_cartesian_index(IJK...) = CartesianIndex.(IJK...)
120+
121+
@non_differentiable to_cartesian_index(::Any...)

test/gather.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,14 @@ end
149149
gradtest(xs -> gather!(dst, xs, index), src)
150150
gradtest(xs -> gather(xs, index), src)
151151
end
152+
153+
@testset "gather(src, IJK...)" begin
154+
x = reshape([1:15;], 3, 5)
155+
156+
y = gather(x, [1,2], [2,4])
157+
@test y == [4, 11]
158+
159+
@test gather(x, [1, 2]) == [1 4
160+
2 5
161+
3 6]
162+
end

0 commit comments

Comments
 (0)