Skip to content

Commit 5818173

Browse files
authored
Broadcasting (#644)
* broadcasting, adapted from Diffractor PR68 * many small upgrades * fixup tuplecast * re-organise split bc, add forward mode * fix tests * add Yota to downstream tests * fix an ambiguity * fix tests on 1.6 * testing * improve unbroadcast * change generic rule to use BroadcastStyle * debug * rename with unzip * fix for 1.6 * test bugs * version * tidy unzipped * add some GPU tests * remove fallback unbroadcast method * re-instate the error which breaks Revise
1 parent d53d8d8 commit 5818173

File tree

13 files changed

+877
-6
lines changed

13 files changed

+877
-6
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ jobs:
1616
os: [ubuntu-latest]
1717
package:
1818
# - {user: dpsanders, repo: ReversePropagation.jl}
19+
- {user: dfdx, repo: Yota.jl}
1920
- {user: FluxML, repo: Zygote.jl}
2021
# Diffractor needs to run on Julia nightly
2122
# include:

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.42.0"
3+
version = "1.43.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1617

1718
[compat]
1819
ChainRulesCore = "1.15.3"
@@ -25,6 +26,7 @@ JLArrays = "0.1"
2526
JuliaInterpreter = "0.8,0.9"
2627
RealDot = "0.1"
2728
StaticArrays = "1.2"
29+
StructArrays = "0.6.11"
2830
julia = "1.6"
2931

3032
[extras]

src/ChainRules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
44
using ChainRulesCore
55
using Compat
66
using Distributed
7+
using GPUArraysCore: AbstractGPUArrayStyle
78
using IrrationalConstants: logtwo, logten
89
using LinearAlgebra
910
using LinearAlgebra.BLAS
1011
using Random
1112
using RealDot: realdot
1213
using SparseArrays
1314
using Statistics
15+
using StructArrays
1416

1517
# Basically everything this package does is overloading these, so we make an exception
1618
# to the normal rule of only overload via `ChainRulesCore.rrule`.
@@ -22,6 +24,9 @@ using ChainRulesCore: derivatives_given_output
2224
# numbers that we know commute under multiplication
2325
const CommutativeMulNumber = Union{Real,Complex}
2426

27+
# StructArrays
28+
include("unzipped.jl")
29+
2530
include("rulesets/Core/core.jl")
2631

2732
include("rulesets/Base/utils.jl")
@@ -34,6 +39,7 @@ include("rulesets/Base/arraymath.jl")
3439
include("rulesets/Base/indexing.jl")
3540
include("rulesets/Base/sort.jl")
3641
include("rulesets/Base/mapreduce.jl")
42+
include("rulesets/Base/broadcast.jl")
3743

3844
include("rulesets/Distributed/nondiff.jl")
3945

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
7272
return (T(x, y), Complex_pullback)
7373
end
7474

75+
@scalar_rule complex(x) true
76+
7577
# `hypot`
7678

7779
@scalar_rule hypot(x::Real) sign(x)

0 commit comments

Comments
 (0)