Skip to content

Commit d6f8138

Browse files
authored
Merge pull request #2 from JuliaMath/teh/sumsquares
Check and fix the `sumsquares` demo
2 parents 314697a + 7622839 commit d6f8138

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

src/CheckedArithmetic.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,16 @@ macro check(expr)
106106
for i = 2:length(expr.args)
107107
safeexpr.args[i] = Expr(:call, :(CheckedArithmetic.safearg), expr.args[i])
108108
end
109-
return esc(quote
110-
val = $expr
111-
valcmp = CheckedArithmetic.safeconvert(typeof(val), $safeexpr)
112-
@test val == valcmp
113-
return val
114-
end)
109+
return quote
110+
local val = $(esc(expr))
111+
local valcmp = CheckedArithmetic.safeconvert(typeof(val), $(esc(safeexpr)))
112+
if ismissing(val) && ismissing(valcmp)
113+
val
114+
else
115+
val == valcmp || error(val, " is not equal to ", valcmp)
116+
val
117+
end
118+
end
115119
end
116120

117121
"""
@@ -172,12 +176,12 @@ safearg(ref::Ref) = Ref(safearg(ref[]))
172176
safearg(d::Dict) = Dict(safearg(p) for p in d)
173177
safearg(d::Base.EnvDict) = d
174178
safearg(d::Base.ImmutableDict) = Base.ImmutableDict(safearg(p) for p in d)
175-
safearg(d::Iterators.Pairs) = Iterators.Pairs(safearg(p) for p in d)
179+
safearg(d::Iterators.Pairs) = Iterators.Pairs(safearg(d.data), d.itr)
176180
safearg(d::IdDict) = IdDict(safearg(p) for p in d)
177181
safearg(d::WeakKeyDict) = WeakKeyDict(k=>safearg(v) for (k, v) in d) # do not convert keys
178182
## AbstractSets
179-
safearg(s::Set) = Set(f(x) for x in s)
180-
safearg(s::Base.IdSet) = Base.IdSet(f(x) for x in s)
183+
safearg(s::Set) = Set(safearg(x) for x in s)
184+
safearg(s::Base.IdSet) = Base.IdSet(safearg(x) for x in s)
181185
safearg(s::BitSet) = s
182186

183187
# Other common types
@@ -243,6 +247,7 @@ Base.@pure accumulatortype(op::Function, T1::Type, T2::Type, T3::Type...) =
243247
accumulatortype(op, promote_type(T1, T2, T3...))
244248
Base.@pure accumulatortype(T1::Type, T2::Type, T3::Type...) =
245249
accumulatortype(*, T1, T2, T3...)
250+
accumulatortype(::Type{T}) where T = accumulatortype(*, T)
246251

247252
const SignPreserving = Union{typeof(+), typeof(*)}
248253
const ArithmeticOp = Union{SignPreserving,typeof(-)}

test/runtests.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using CheckedArithmetic
2-
using Test
2+
using Test, Dates
33

44
@test isempty(detect_ambiguities(CheckedArithmetic, Base, Core))
55

@@ -8,6 +8,14 @@ using Test
88
minus(x, y) = x - y
99
end
1010

11+
function sumsquares(A::AbstractArray)
12+
s = zero(accumulatortype(eltype(A)))
13+
for a in A
14+
s += acc(a)^2
15+
end
16+
return s
17+
end
18+
1119
@testset "CheckedArithmetic.jl" begin
1220
@testset "@checked" begin
1321
@test @checked(abs(Int8(-2))) === Int8(2)
@@ -41,7 +49,7 @@ end
4149
@test_throws OverflowError minus(0x20, 0x30)
4250
end
4351

44-
@testset "check" begin
52+
@testset "@check" begin
4553
@test @check(3+5) == 8
4654
@test_throws InexactError @check(0xf0+0x15)
4755
@test @check([3]+[5]) == [8]
@@ -54,10 +62,10 @@ end
5462
@test @check(times2(Dict("a"=>7))) == Dict("a"=>14)
5563
@test_throws InexactError @check(times2(Dict("a"=>0xf0)))
5664
for item in Any["hi", :hi, (3,), (silly="hi",), trues(3), [true], 1:3, 1:2:5,
57-
LinRange(1, 3, 3), StepRangeLen(1.0, 3.0, 3), 0x01:0x03, Ref(2),
65+
LinRange(1, 3, 3), StepRangeLen(1.0, 3.0, 3), 0x01:0x03,
5866
pairs((silly="hi",)), Set([1,3]), BitSet(7), nothing, missing,
5967
Some(nothing), 'c', MIME("text/plain"), IOBuffer(), r"\d+",
60-
Channel(), CartesianIndex(1, 3), Base.UUID(0), `ls`,
68+
Channel(7), CartesianIndex(1, 3), Base.UUID(0), `ls`,
6169
sum, Base, Val(3), Task(()->1), Base.Order.Forward, Timer(0.1),
6270
now()]
6371
@test @check(identity(item)) === item
@@ -69,5 +77,7 @@ end
6977
@test acc(-, 0x02) === 2
7078
@test accumulatortype(-, UInt8) === Int
7179
@test accumulatortype(*, Int16, Float16) === Float64
80+
81+
@test sumsquares(0x00:0xff) == sumsquares(0:255)
7282
end
7383
end

0 commit comments

Comments
 (0)