Skip to content

Commit 71d44a9

Browse files
authored
Fix Primal + Composite perf (#81)
* Fix perf and test * Use allocated no BenchmarkTools
1 parent e9996d9 commit 71d44a9

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

src/differentials/composite.jl

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,14 @@ backing(x::Tuple) = x
7777
backing(x::NamedTuple) = x
7878
backing(x::Composite) = getfield(x, :backing)
7979

80-
function backing(x::T)::NamedTuple where T
81-
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
82-
nfields = fieldcount(T)
83-
names = ntuple(ii->fieldname(T, ii), nfields)
84-
types = ntuple(ii->fieldtype(T, ii), nfields)
85-
86-
if @generated
87-
# @btime (()->ChainRulesCore.backing(Foo(1.0, 2.0)))()
88-
## 5.590 ns (1 allocation: 32 bytes)
89-
90-
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
91-
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
92-
else
93-
vals = ntuple(ii->getfield(x, ii), nfields)
94-
return NamedTuple{names, Tuple{types...}}(vals)
95-
end
80+
@generated function backing(x)::NamedTuple
81+
!isstructtype(x) && throw(DomainError(x, "backing can only be use on composite types"))
82+
nfields = fieldcount(x)
83+
names = ntuple(ii->fieldname(x, ii), nfields)
84+
types = ntuple(ii->fieldtype(x, ii), nfields)
85+
86+
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
87+
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
9688
end
9789

9890
"""

test/differentials/composite.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ struct Foo
44
y::Float64
55
end
66

7+
# For testing Primal + Composite performance
8+
struct Bar
9+
x::Float64
10+
end
11+
712
# For testing Composite: it is an invarient of the type that x2 = 2x
813
# so simple addition can not be defined
914
struct StructWithInvariant
@@ -93,6 +98,7 @@ end
9398
@testset "Structs" begin
9499
@test Foo(3.5, 1.5) + Composite{Foo}(x=2.5) == Foo(6.0, 1.5)
95100
@test Composite{Foo}(x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5)
101+
@test (@allocated Bar(0.5) + Composite{Bar}(; x=0.5)) == 0
96102
end
97103

98104
@testset "Tuples" begin

0 commit comments

Comments
 (0)