Skip to content

Commit cd54f7f

Browse files
committed
Assorted fixes and return type adjustments
1 parent 920f3cf commit cd54f7f

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,17 @@ end
279279
nparts(g::DGraph) = with_state(g, nparts)
280280
nparts(g::DGraphState) = length(g.parts)
281281
Base.eltype(::DGraph{T}) where T = T
282-
Graphs.edgetype(::DGraph{T}) where T = Tuple{T,T}
282+
Graphs.edgetype(::DGraph{T}) where T = Edge{T}
283283
Graphs.nv(g::DGraph) = with_state(g, nv)::Int
284284
function Graphs.nv(g::DGraphState)
285285
if !isempty(g.parts_nv)
286-
return last(g.parts_nv).stop
286+
return Int(last(g.parts_nv).stop)
287287
else
288288
return 0
289289
end
290290
end
291291
Graphs.ne(g::DGraph) = with_state(g, ne)::Int
292-
Graphs.ne(g::DGraphState) = sum(g.parts_ne; init=0) + sum(g.bg_adjs_ne_src; init=0)
292+
Graphs.ne(g::DGraphState) = Int(sum(g.parts_ne; init=0) + sum(g.bg_adjs_ne_src; init=0))
293293
Graphs.has_vertex(g::DGraph, v::Integer) = 1 <= v <= nv(g)
294294
Graphs.has_edge(g::DGraph, edge::Tuple) = has_edge(g, edge[1], edge[2])
295295
Graphs.has_edge(g::DGraph, src::Integer, dst::Integer) =
@@ -303,15 +303,17 @@ function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where
303303
if src_part_idx == dst_part_idx
304304
# The edge will be within a graph partition
305305
part = g.parts[src_part_idx]
306-
return exec_fast(has_edge, part, src, dst)
306+
src_shift = src - (g.parts_nv[src_part_idx].start - 1)
307+
dst_shift = dst - (g.parts_nv[dst_part_idx].start - 1)
308+
return exec_fast(has_edge, part, src_shift, dst_shift)
307309
else
308310
# The edge will be in an AdjList
309311
adj = g.bg_adjs[src_part_idx]
310312
return exec_fast(has_edge, adj, src, dst)
311313
end
312314
end
313315
Graphs.is_directed(::DGraph{T,D}) where {T,D} = D
314-
Graphs.vertices(g::DGraph) = Base.OneTo(nv(g))
316+
Graphs.vertices(g::DGraph{T}) where T = Base.OneTo{T}(nv(g))
315317
Graphs.edges(g::DGraph) = DGraphEdgeIter(g)
316318
edges_with_metadata(f, g::DGraph) = DGraphEdgeIter(g; metadata=true, meta_f=f)
317319
edges_with_weights(g::DGraph) = edges_with_metadata(weights, g)
@@ -357,8 +359,8 @@ function add_partition!(g::DGraphState{T,D}, n::Integer) where {T,D}
357359
if n < 1
358360
throw(ArgumentError("n must be >= 1"))
359361
end
360-
push!(g.parts, Dagger.spawn(n) do n
361-
D ? SimpleDiGraph(n) : SimpleGraph(n)
362+
push!(g.parts, Dagger.spawn(T, n) do T, n
363+
D ? SimpleDiGraph{T}(n) : SimpleGraph{T}(n)
362364
end)
363365
num_v = nv(g)
364366
push!(g.parts_nv, (num_v+1):(num_v+n))
@@ -451,8 +453,8 @@ function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where
451453
# Edge spans two partitions
452454
src_bg_adj = g.bg_adjs[src_part_idx]
453455
dst_bg_adj = g.bg_adjs[dst_part_idx]
454-
src_t = exec_fast(add_edge!, src_bg_adj, (src, dst); fetch=false)
455-
dst_t = exec_fast(add_edge!, dst_bg_adj, (src, dst); fetch=false)
456+
src_t = exec_fast_nofetch(add_edge!, src_bg_adj, (src, dst))
457+
dst_t = exec_fast_nofetch(add_edge!, dst_bg_adj, (src, dst))
456458
if !fetch(src_t) || !fetch(dst_t)
457459
return false
458460
end
@@ -529,13 +531,13 @@ end
529531
edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) =
530532
iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx
531533
Graphs.inneighbors(g::DGraph, v::Integer) = with_state(g, inneighbors, v)
532-
function Graphs.inneighbors(g::DGraphState, v::Integer)
534+
function Graphs.inneighbors(g::DGraphState{T}, v::Integer) where T
533535
part_idx = findfirst(span->v in span, g.parts_nv)
534536
if part_idx === nothing
535537
throw(BoundsError(g, v))
536538
end
537539

538-
neighbors = Int[]
540+
neighbors = T[]
539541
shift = g.parts_nv[part_idx].start - 1
540542

541543
# Check against local edges
@@ -549,13 +551,13 @@ function Graphs.inneighbors(g::DGraphState, v::Integer)
549551
return neighbors
550552
end
551553
Graphs.outneighbors(g::DGraph, v::Integer) = with_state(g, outneighbors, v)
552-
function Graphs.outneighbors(g::DGraphState, v::Integer)
554+
function Graphs.outneighbors(g::DGraphState{T}, v::Integer) where T
553555
part_idx = findfirst(span->v in span, g.parts_nv)
554556
if part_idx === nothing
555557
throw(BoundsError(g, v))
556558
end
557559

558-
neighbors = Int[]
560+
neighbors = T[]
559561
shift = g.parts_nv[part_idx].start - 1
560562

561563
# Check against local edges

0 commit comments

Comments
 (0)