Skip to content

Commit 6d1c325

Browse files
Merge pull request #730 from AayushSabharwal/as/polyadic-add
refactor: add conservative polyadic addition
2 parents 12ca5e8 + cd5a309 commit 6d1c325

File tree

1 file changed

+66
-26
lines changed

1 file changed

+66
-26
lines changed

src/types.jl

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,39 +1346,79 @@ sub_t(a,b) = promote_symtype(-, symtype(a), symtype(b))
13461346
sub_t(a) = promote_symtype(-, symtype(a))
13471347

13481348
import Base: (+), (-), (*), (//), (/), (\), (^)
1349-
function +(a::SN, b::SN)
1350-
!issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata
1351-
if isadd(a) && isadd(b)
1352-
return Add(add_t(a,b),
1353-
a.coeff + b.coeff,
1354-
_merge(+, a.dict, b.dict, filter=_iszero))
1355-
elseif isadd(a)
1356-
coeff, dict = makeadd(1, 0, b)
1357-
return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero))
1358-
elseif isadd(b)
1359-
return b + a
1360-
end
1361-
coeff, dict = makeadd(1, 0, a, b)
1362-
Add(add_t(a,b), coeff, dict)
1363-
end
1364-
1365-
function +(a::Number, b::SN)
1366-
tmp = unwrap(a)
1367-
if tmp !== a
1368-
return tmp + b
1349+
1350+
function +(a::SN, bs::SN...)
1351+
isempty(bs) && return a
1352+
# entries where `!issafecanon`
1353+
unsafes = SmallV{Any}()
1354+
# coeff and dict of the `Add`
1355+
coeff = 0
1356+
dict = sdict()
1357+
# type of the `Add`
1358+
T = symtype(a)
1359+
1360+
# handle `a` separately
1361+
if issafecanon(+, a)
1362+
if isadd(a)
1363+
coeff = a.coeff
1364+
dict = copy(a.dict)
1365+
elseif ismul(a)
1366+
v = a.coeff
1367+
a′ = Mul(symtype(a), 1, copy(a.dict); metadata = a.metadata)
1368+
dict[a′] = v
1369+
else
1370+
dict[a] = 1
1371+
end
1372+
else
1373+
push!(unsafes, a)
1374+
end
1375+
1376+
for b in bs
1377+
T = promote_symtype(+, T, symtype(b))
1378+
if !issafecanon(+, b)
1379+
push!(unsafes, b)
1380+
continue
1381+
end
1382+
if isadd(b)
1383+
coeff += b.coeff
1384+
for (k, v) in b.dict
1385+
dict[k] = get(dict, k, 0) + v
1386+
end
1387+
elseif ismul(b)
1388+
v = b.coeff
1389+
b′ = Mul(symtype(b), 1, copy(b.dict); metadata = b.metadata)
1390+
dict[b′] = get(dict, b′, 0) + v
1391+
else
1392+
dict[b] = get(dict, b, 0) + 1
1393+
end
1394+
end
1395+
# remove entries multiplied by zero
1396+
filter!(dict) do kvp
1397+
!iszero(kvp[2])
1398+
end
1399+
1400+
result = isempty(dict) ? coeff : Add(T, coeff, dict)
1401+
if !isempty(unsafes)
1402+
push!(unsafes, result)
1403+
result = Term{T}(+, unsafes)
13691404
end
1370-
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
1405+
return result
1406+
end
1407+
1408+
function +(a::Number, b::SN, bs::SN...)
1409+
b = +(b, bs...)
1410+
issafecanon(+, b) || return term(+, a, b)
13711411
iszero(a) && return b
13721412
if isadd(b)
1373-
Add(add_t(a,b), a + b.coeff, b.dict)
1413+
Add(add_t(a, b), a + b.coeff, b.dict)
13741414
else
1375-
Add(add_t(a,b), makeadd(1, a, b)...)
1415+
Add(add_t(a, b), makeadd(1, a, b)...)
13761416
end
13771417
end
13781418

1379-
+(a::SN, b::Number) = b + a
1380-
1381-
+(a::SN) = a
1419+
function +(a::SN, b::Number, bs::SN...)
1420+
return +(b, a, bs...)
1421+
end
13821422

13831423
function -(a::SN)
13841424
!issafecanon(*, a) && return term(-, a)

0 commit comments

Comments
 (0)