Skip to content

Commit 74756eb

Browse files
Merge pull request #1384 from devmotion/dw/hessian_sparsity
Assume unknown functions are non-linear in `hessian_sparsity`
2 parents ce8b3f6 + 3185d30 commit 74756eb

File tree

3 files changed

+123
-38
lines changed

3 files changed

+123
-38
lines changed

src/diff.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -646,24 +646,13 @@ let
646646
linearity_rules = [
647647
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
648648
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
649-
@rule (~f)(~x::(!isidx)) => _scalar
650649

651-
@rule (~f)(~x::isidx) => if haslinearity_1(~f)
652-
combine_terms_1(linearity_1(~f), ~x)
653-
else
654-
error("Function of unknown linearity used: ", ~f)
655-
end
650+
@rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar
656651
@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)
657-
@rule (~f)(~x, ~y) => begin
658-
if haslinearity_2(~f)
659-
a = isidx(~x) ? ~x : _scalar
660-
b = isidx(~y) ? ~y : _scalar
661-
combine_terms_2(linearity_2(~f), a, b)
662-
else
663-
error("Function of unknown linearity used: ", ~f)
664-
end
665-
end
666-
@rule ~x::issym => 0]
652+
@rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar)
653+
654+
@rule ~x::issym => 0
655+
]
667656
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm))
668657

669658
global hessian_sparsity
@@ -696,9 +685,8 @@ let
696685
@assert !(expr isa AbstractArray)
697686
expr = value(expr)
698687
u = map(value, vars)
699-
idx(i) = TermCombination(Set([Dict(i=>1)]))
700-
dict = Dict(u .=> idx.(1:length(u)))
701-
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; maketerm=basic_mkterm)(expr)
688+
dict = Dict(ui => TermCombination(Set([Dict(i=>1)])) for (i, ui) in enumerate(u))
689+
f = Rewriters.Prewalk(x-> get(dict, x, x); maketerm=basic_mkterm)(expr)
702690
lp = linearity_propagator(f)
703691
S = _sparse(lp, length(u))
704692
S = full ? S : tril(S)

src/linearity.jl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,45 @@
11
using SpecialFunctions
22
import Base.Broadcast
33

4-
5-
const linearity_known_1 = IdDict{Function,Bool}()
6-
const linearity_known_2 = IdDict{Function,Bool}()
7-
84
const linearity_map_1 = IdDict{Function, Bool}()
95
const linearity_map_2 = IdDict{Function, Tuple{Bool, Bool, Bool}}()
106

117
# 1-arg
12-
138
const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]
149

1510
const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh, slog, ssqrt, scbrt]
1611

17-
# We store 3 bools even for 1-arg functions for type stability
18-
const three_trues = (true, true, true)
1912
for f in monadic_linear
20-
linearity_known_1[f] = true
2113
linearity_map_1[f] = true
2214
end
2315

2416
for f in monadic_nonlinear
25-
linearity_known_1[f] = true
2617
linearity_map_1[f] = false
2718
end
2819

2920
# 2-arg
3021
for f in [+, rem2pi, -, >, isless, <, isequal, max, min, convert, <=, >=]
31-
linearity_known_2[f] = true
3222
linearity_map_2[f] = (true, true, true)
3323
end
3424

3525
for f in [*]
36-
linearity_known_2[f] = true
3726
linearity_map_2[f] = (true, true, false)
3827
end
3928

4029
for f in [/]
41-
linearity_known_2[f] = true
4230
linearity_map_2[f] = (true, false, false)
4331
end
4432
for f in [\]
45-
linearity_known_2[f] = true
4633
linearity_map_2[f] = (false, true, false)
4734
end
4835

4936
for f in [hypot, atan, mod, rem, lbeta, ^, beta]
50-
linearity_known_2[f] = true
5137
linearity_map_2[f] = (false, false, false)
5238
end
5339

54-
haslinearity_1(@nospecialize(f)) = get(linearity_known_1, f, false)
55-
haslinearity_2(@nospecialize(f)) = get(linearity_known_2, f, false)
56-
57-
linearity_1(@nospecialize(f)) = linearity_map_1[f]
58-
linearity_2(@nospecialize(f)) = linearity_map_2[f]
40+
# Fallback assumption: Function is not linear, i.e., derivatives are non-zero
41+
linearity_1(@nospecialize(f)) = get(linearity_map_1, f, false)
42+
linearity_2(@nospecialize(f)) = get(linearity_map_2, f, (false, false, false))
5943

6044
# TermCombination datastructure
6145

test/diff.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,116 @@ let
407407
@test isequal(expand_derivatives(D(Symbolics.scbrt(1 + x ^ 2))), simplify((2x) / (3Symbolics.scbrt(1 + x^2)^2)))
408408
@test isequal(expand_derivatives(D(Symbolics.slog(1 + x ^ 2))), simplify((2x) / (1 + x ^ 2)))
409409
end
410+
411+
# Hessian sparsity involving unknown functions
412+
let
413+
@variables x₁ x₂ p q[1:1]
414+
expr = 3x₁^2 + 4x₁ * x₂
415+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
416+
417+
expr = 3x₁^2 + 4x₁ * x₂ + p
418+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
419+
420+
# issue 643: example test2_num
421+
expr = 3x₁^2 + 4x₁ * x₂ + q[1]
422+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
423+
424+
# Custom function: By default assumed to be non-linear
425+
myexp(x) = exp(x)
426+
@register_symbolic myexp(x)
427+
expr = 3x₁^2 + 4x₁ * x₂ + myexp(p)
428+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
429+
expr = 3x₁^2 + 4x₁ * x₂ + myexp(x₂)
430+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
431+
432+
mylogaddexp(x, y) = log(exp(x) + exp(y))
433+
@register_symbolic mylogaddexp(x, y)
434+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, 2)
435+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
436+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(3, p)
437+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
438+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, 2)
439+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
440+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, q[1])
441+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
442+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(p, x₂)
443+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
444+
expr = 3x₁^2 + 4x₁ * x₂ + mylogaddexp(x₂, 4)
445+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
446+
447+
# Custom linear function: Possible to extend `Symbolics.linearity_1`/`Symbolics.linearity_2`
448+
myidentity(x) = x
449+
@register_symbolic myidentity(x)
450+
Symbolics.linearity_1(::typeof(myidentity)) = true
451+
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(p)
452+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
453+
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(q[1])
454+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
455+
expr = 3x₁^2 + 4x₁ * x₂ + myidentity(x₂)
456+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
457+
458+
mymul1plog(x, y) = x * (1 + log(y))
459+
@register_symbolic mymul1plog(x, y)
460+
Symbolics.linearity_2(::typeof(mymul1plog)) = (true, false, false)
461+
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(p, q[1])
462+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
463+
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(x₂, q[1])
464+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true false]
465+
expr = 3x₁^2 + 4x₁ * x₂ + mymul1plog(q[1], x₂)
466+
@test Matrix(Symbolics.hessian_sparsity(expr, [x₁, x₂])) == [true true; true true]
467+
end
468+
469+
# issue #555
470+
let
471+
# first example
472+
@variables p[1:1] x[1:1]
473+
p = collect(p)
474+
x = collect(x)
475+
@test collect(Symbolics.sparsehessian(p[1] * x[1], x)) == [0;;]
476+
@test isequal(collect(Symbolics.sparsehessian(p[1] * x[1]^2, x)), [2p[1];;])
477+
478+
# second example
479+
@variables a[1:2]
480+
a = collect(a)
481+
ex = (a[1]+a[2])^2
482+
@test Symbolics.hessian(ex, [a[1]]) == [2;;]
483+
@test collect(Symbolics.sparsehessian(ex, [a[1]])) == [2;;]
484+
@test collect(Symbolics.sparsehessian(ex, a)) == fill(2, 2, 2)
485+
end
486+
487+
# issue #847
488+
let
489+
@variables x[1:2] y[1:2]
490+
x = Symbolics.scalarize(x)
491+
y = Symbolics.scalarize(y)
492+
493+
z = (x[1] + x[2]) * (y[1] + y[2])
494+
@test Symbolics.islinear(z, x)
495+
@test Symbolics.isaffine(z, x)
496+
497+
z = (x[1] + x[2])
498+
@test Symbolics.islinear(z, x)
499+
@test Symbolics.isaffine(z, x)
500+
end
501+
502+
# issue #790
503+
let
504+
c(x) = [sum(x) - 1]
505+
@variables xs[1:2] ys[1:1]
506+
w = Symbolics.scalarize(xs)
507+
v = Symbolics.scalarize(ys)
508+
expr = dot(v, c(w))
509+
@test !Symbolics.islinear(expr, w)
510+
@test Symbolics.isaffine(expr, w)
511+
@test collect(Symbolics.hessian_sparsity(expr, w)) == fill(false, 2, 2)
512+
end
513+
514+
# issue #749
515+
let
516+
@variables x y
517+
@register_symbolic Base.FastMath.exp_fast(x, y)
518+
expr = Base.FastMath.exp_fast(x, y)
519+
@test !Symbolics.islinear(expr, [x, y])
520+
@test !Symbolics.isaffine(expr, [x, y])
521+
@test collect(Symbolics.hessian_sparsity(expr, [x, y])) == fill(true, 2, 2)
522+
end

0 commit comments

Comments
 (0)