Skip to content

Commit ace794f

Browse files
Pietro Vertechilcw
andauthored
make compatible with CUDA kernels (#114)
* Avoid dynamic code in get_ith This uses a generated function to avoid a dynamic call for `getindex` allowing it to be called in a CUDAnative kernel. * Make foreachfield a static function This allows `setindex!` on `StructArray`s to be used in `CUDAnative` kernels. * Use Adapt for converting CPU to GPU StructArrays * make getindex generated and revert foreach changes * simpler foreachfield * restore foreachfield for an arbitrary number of variables * remove generated get_ith * test adapt * remove extra file * fix rebasing issues Co-authored-by: Lucas C Wilcox <lucas@swirlee.com>
1 parent 00e42b1 commit ace794f

File tree

5 files changed

+34
-10
lines changed

5 files changed

+34
-10
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@ uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
33
version = "0.4.2"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
78
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
89

910
[compat]
11+
Adapt = "1"
1012
DataAPI = "1"
1113
Tables = "1"
1214
julia = "1"
1315

1416
[extras]
17+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1518
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1619
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
1720
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1821
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
1922

2023
[targets]
21-
test = ["Test", "OffsetArrays", "PooledArrays", "WeakRefStrings"]
24+
test = ["Test", "GPUArrays", "OffsetArrays", "PooledArrays", "WeakRefStrings"]

src/StructArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,8 @@ include("groupjoin.jl")
1616
include("lazy.jl")
1717
include("tables.jl")
1818

19+
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
20+
import Adapt
21+
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)
22+
1923
end # module

src/structarray.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,11 @@ Base.axes(s::StructArray) = axes(fieldarrays(s)[1])
134134
Base.axes(s::StructArray{<:Any, <:Any, <:EmptyTup}) = (1:0,)
135135

136136
get_ith(cols::NamedTuple, I...) = get_ith(Tuple(cols), I...)
137-
function get_ith(cols::NTuple{N, Any}, I...) where N
138-
ntuple(N) do i
139-
@inbounds res = getfield(cols, i)[I...]
140-
return res
141-
end
137+
function get_ith(cols::Tuple, I...)
138+
@inbounds r = first(cols)[I...]
139+
return (r, get_ith(Base.tail(cols), I...)...)
142140
end
141+
get_ith(::Tuple{}, I...) = ()
143142

144143
Base.@propagate_inbounds function Base.getindex(x::StructArray{T, <:Any, <:Any, CartesianIndex{N}}, I::Vararg{Int, N}) where {T, N}
145144
cols = fieldarrays(x)

src/utils.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,25 @@ else
2828
const _getproperty = getproperty
2929
end
3030

31-
function _foreachfield(names, xs)
31+
function _foreachfield(names, L)
32+
vars = ntuple(i -> gensym(), L)
3233
exprs = Expr[]
34+
for (i, v) in enumerate(vars)
35+
push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i)))
36+
end
3337
for field in names
3438
sym = QuoteNode(field)
35-
args = [Expr(:call, :_getproperty, :(getfield(xs, $j)), sym) for j in 1:length(xs)]
39+
args = [Expr(:call, :_getproperty, var, sym) for var in vars]
3640
push!(exprs, Expr(:call, :f, args...))
3741
end
3842
push!(exprs, :(return nothing))
3943
return Expr(:block, exprs...)
4044
end
4145

42-
@generated foreachfield(::Type{<:NamedTuple{names}}, f, xs...) where {names} = _foreachfield(names, xs)
43-
@generated foreachfield(::Type{<:NTuple{N, Any}}, f, xs...) where {N} = _foreachfield(Base.OneTo(N), xs)
46+
@generated foreachfield(::Type{<:NamedTuple{names}}, f, xs::Vararg{Any, L}) where {names, L} =
47+
_foreachfield(names, L)
48+
@generated foreachfield(::Type{<:NTuple{N, Any}}, f, xs::Vararg{Any, L}) where {N, L} =
49+
_foreachfield(Base.OneTo(N), L)
4450

4551
foreachfield(f, x::T, xs...) where {T} = foreachfield(staticschema(T), f, x, xs...)
4652

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using StructArrays: staticschema, iscompatible, _promote_typejoin, append!!
33
using OffsetArrays: OffsetArray
44
import Tables, PooledArrays, WeakRefStrings
55
using DataAPI: refarray, refvalue
6+
using Adapt: adapt
7+
import GPUArrays
68
using Test
79

810
@testset "index" begin
@@ -700,3 +702,13 @@ end
700702
@test vcat(dest, StructVector(makeitr())) == append!!(copy(dest), makeitr())
701703
end
702704
end
705+
706+
@testset "adapt" begin
707+
s = StructArray(a = 1:10, b = StructArray(c = 1:10, d = 1:10))
708+
t = adapt(Array, s)
709+
@test propertynames(t) == (:a, :b)
710+
@test s == t
711+
@test t.a isa Array
712+
@test t.b.c isa Array
713+
@test t.b.d isa Array
714+
end

0 commit comments

Comments
 (0)