Skip to content

Commit e4b06dc

Browse files
Merge pull request #992 from simeonschaub/offset-idx
properly support offset indices
2 parents a270d66 + b81e67d commit e4b06dc

File tree

6 files changed

+34
-9
lines changed

6 files changed

+34
-9
lines changed

src/array-lib.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ function Base.getindex(x::SymArray, idx...)
2222
meta = metadata(unwrap(x))
2323
if shape(x) !== Unknown() && all(i -> i isa Integer, idx)
2424
II = CartesianIndices(axes(x))
25+
ii = CartesianIndex(idx)
2526
@boundscheck begin
26-
if !checkbounds(Bool, II, idx...)
27+
if !in(ii, II)
2728
throw(BoundsError(x, idx))
2829
end
2930
end
30-
ii = II[idx...]
3131
res = Term{eltype(symtype(x))}(getindex, [x, Tuple(ii)...]; metadata = meta)
3232
elseif all(i -> symtype(i) <: Integer, idx)
3333
shape(x) !== Unknown() && @boundscheck begin

src/arrays.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ function Base.show(io::IO, aop::ArrayOp)
8181
end
8282
end
8383

84+
Base.summary(io::IO, aop::ArrayOp) = Base.array_summary(io, aop, shape(aop))
85+
function Base.showarg(io::IO, aop::ArrayOp, toplevel)
86+
show(io, aop)
87+
toplevel && print(io, "::", typeof(aop))
88+
return nothing
89+
end
90+
8491
symtype(a::ArrayOp{T}) where {T} = T
8592
istree(a::ArrayOp) = true
8693
function operation(a::ArrayOp)
@@ -208,7 +215,9 @@ function make_shape(output_idx, expr, ranges=Dict())
208215
end
209216
mi = matches[i]
210217
@assert !isempty(mi)
211-
return get_extents(mi)
218+
ext = get_extents(mi)
219+
ext isa Unknown && return Unknown()
220+
return Base.OneTo(length(ext))
212221
elseif i isa Integer
213222
return Base.OneTo(1)
214223
end
@@ -526,7 +535,7 @@ end
526535
function axes(A::Union{Arr, SymArray})
527536
s = shape(unwrap(A))
528537
s === Unknown() && error("axes of $A not known")
529-
return map(x->1:length(x), s)
538+
return s
530539
end
531540

532541

test/arrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,19 @@ end
393393
@test !isequal(a, b) && !isequal(b, c) && !isequal(a, c)
394394
@test hash(a) != hash(b) && hash(b) != hash(c) && hash(a) != hash(c)
395395
end
396+
397+
@testset "Offset Indices" begin
398+
@variables k[0:3]
399+
400+
@testset "i = $i" for i in 0:3
401+
sym = unwrap(k[i])
402+
@test operation(sym) === getindex
403+
args = arguments(sym)
404+
@test length(args) == 2
405+
@test args[1] === unwrap(k)
406+
@test args[2] === i
407+
end
408+
409+
@test_throws BoundsError k[-1]
410+
@test_throws BoundsError k[4]
411+
end

test/build_function_tests/stencil-broadcast-inplace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
66
ˍ₋out_1 .= 0
77
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
8-
for (j, j′) = zip(1:4, reset_to_one(1:4))
9-
for (i, i′) = zip(1:4, reset_to_one(1:4))
8+
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
9+
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
1010
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
1111
end
1212
end

test/build_function_tests/stencil-broadcast-outplace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
ˍ₋out_1 = (view)(ˍ₋out, 1:6, 1:6)
55
ˍ₋out_1 .= 0
66
ˍ₋out_2 = (view)(ˍ₋out, 2:5, 2:5)
7-
for (j, j′) = zip(1:4, reset_to_one(1:4))
8-
for (i, i′) = zip(1:4, reset_to_one(1:4))
7+
for (j, j′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
8+
for (i, i′) = zip(Base.OneTo(4), reset_to_one(Base.OneTo(4)))
99
ˍ₋out_2[i′, j′] = (+)(ˍ₋out_2[i′, j′], (+)(1, (getindex)(ˍ₋out_2_input_1, i, j)))
1010
end
1111
end

test/build_function_tests/transpose-inplace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
:(function (x,)
2-
let ˍ₋out = zeros(Float64, map(length, (1:4, 1:4)))
2+
let ˍ₋out = zeros(Float64, map(length, (Base.OneTo(4), Base.OneTo(4))))
33
begin
44
for (j, j′) = zip(1:4, reset_to_one(1:4))
55
for (i, i′) = zip(1:4, reset_to_one(1:4))

0 commit comments

Comments
 (0)