Skip to content

Commit 69cf971

Browse files
Merge pull request #1256 from kapple19/#1085
Fix false positive result of passing a `Function` to `derivative`
2 parents b75c55c + 6fcede0 commit 69cf971

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/diff.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,9 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1
371371
derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat!(collect(args), i)...)
372372
derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0
373373

374+
derivative(f::Function, x::Num) = derivative(f(x), x)
375+
derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Num, typeof(x)) |> throw
376+
374377
function count_order(x)
375378
@assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!"
376379
n = 1

test/diff.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,23 @@ let
378378
Dt = Differential(t)^0
379379
@test isequal(Dt, identity)
380380
test_equal(Dt(t + 2t^2), t + 2t^2)
381-
end
381+
end
382+
383+
# Check `Function` inputs for derivative (#1085)
384+
let
385+
@variables x
386+
@testset for f in [sqrt, sin, acos, exp, cis]
387+
@test isequal(
388+
Symbolics.derivative(f, x),
389+
Symbolics.derivative(f(x), x)
390+
)
391+
end
392+
end
393+
394+
# Check `Function` inputs throw for non-Num second input (#1085)
395+
let
396+
@testset for f in [sqrt, sin, acos, exp, cis]
397+
@test_throws TypeError Symbolics.derivative(f, rand())
398+
@test_throws TypeError Symbolics.derivative(f, Val(rand(Int)))
399+
end
400+
end

0 commit comments

Comments
 (0)