@@ -54,6 +54,7 @@ or multiple `dst` columns.
54
54
See [`gather!`](@ref) for an in-place version.
55
55
56
56
# Examples
57
+
57
58
```jldoctest
58
59
julia> NNlib.gather([1,20,300,4000], [2,4,2])
59
60
3-element Vector{Int64}:
@@ -83,5 +84,38 @@ function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::A
83
84
y = gather! (dst, src, idx)
84
85
src_size = size (src)
85
86
gather!_pullback (Δ) = (NoTangent (), NoTangent (), ∇gather_src (unthunk (Δ), src_size, idx), NoTangent ())
86
- y, gather!_pullback
87
+ return y, gather!_pullback
88
+ end
89
+
90
+ """
91
+ gather(src, IJK...)
92
+
93
+ Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and
94
+ call `gather` on it: `gather(src, CartesianIndex.(IJK...))`.
95
+
96
+ # Examples
97
+
98
+ ```jldoctest
99
+ julia> src = reshape([1:15;], 3, 5)
100
+ 3×5 Matrix{Int64}:
101
+ 1 4 7 10 13
102
+ 2 5 8 11 14
103
+ 3 6 9 12 15
104
+
105
+ julia> NNlib.gather(src, [1, 2], [2, 4])
106
+ 2-element Vector{Int64}:
107
+ 4
108
+ 11
109
+ ```
110
+ """
111
+ function gather (src:: AbstractArray{Tsrc, Nsrc} ,
112
+ I:: AbstractVector{<:Integer} ,
113
+ J:: AbstractVector{<:Integer} ,
114
+ Ks:: AbstractVector{<:Integer} ...) where {Nsrc, Tsrc}
115
+
116
+ return gather (src, to_cartesian_index (I, J, Ks... ))
87
117
end
118
+
119
+ to_cartesian_index (IJK... ) = CartesianIndex .(IJK... )
120
+
121
+ @non_differentiable to_cartesian_index (:: Any... )
0 commit comments