Skip to content

Commit 6843559

Browse files
authored
Merge pull request #749 from JuliaDiff/ox/logging
Add rule for with_logger
2 parents b8adca6 + 99b8141 commit 6843559

File tree

6 files changed

+34
-5
lines changed

6 files changed

+34
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.56.0"
3+
version = "1.57.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl")
4343
include("rulesets/Base/sort.jl")
4444
include("rulesets/Base/mapreduce.jl")
4545
include("rulesets/Base/broadcast.jl")
46+
include("rulesets/Base/CoreLogging.jl")
4647

4748
include("rulesets/Distributed/nondiff.jl")
4849

src/rulesets/Base/CoreLogging.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)
2+
3+
function rrule(
4+
rc::RuleConfig{>:ChainRulesCore.HasReverseMode},
5+
::typeof(Base.CoreLogging.with_logger),
6+
f::Function,
7+
logger::Base.CoreLogging.AbstractLogger,
8+
)
9+
y, f_pb = Base.CoreLogging.with_logger(logger) do
10+
rrule_via_ad(rc, f)
11+
end
12+
with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent())
13+
return y, with_logger_pullback
14+
end
15+
16+
@non_differentiable Base.CoreLogging.current_logger(args...)
17+
@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
18+
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
19+
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
20+
@non_differentiable Base.CoreLogging.handle_message(::Any...)

src/rulesets/Base/nondiff.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,6 @@ end
483483
@non_differentiable Broadcast.result_style(::Any)
484484
@non_differentiable Broadcast.result_style(::Any, ::Any)
485485

486-
@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
487-
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
488-
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
489-
@non_differentiable Base.CoreLogging.handle_message(::Any...)
490486

491487
@non_differentiable Libc.free(::Any)
492488
@non_differentiable Libc.getpid()

test/rulesets/Base/CoreLogging.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)
2+
@testset "CoreLogging.jl" begin
3+
@testset "with_logger" begin
4+
test_rrule(
5+
Base.CoreLogging.with_logger,
6+
() -> 2.0 * 3.0,
7+
Base.CoreLogging.NullLogger();
8+
check_inferred=false,
9+
)
10+
end
11+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ end
5353
test_method_tables() # Check the global method tables are consistent
5454

5555
# Each file puts all tests inside one or more @testset blocks
56+
include_test("rulesets/Base/CoreLogging.jl")
5657
include_test("rulesets/Base/base.jl")
5758
include_test("rulesets/Base/fastmath_able.jl")
5859
include_test("rulesets/Base/evalpoly.jl")

0 commit comments

Comments
 (0)