Skip to content

Commit 6a8e18b

Browse files
devmotionsethaxenwilltebbuttmzgubic
authored
Add realdot (#2)
* Add missing complex tests and rules (#216) * Fix indentation * Test \ on complex inputs * Test ^ on complex inputs * Test identity on complex inputs * Test muladd on complex inputs * Test binary functions on complex inputs * Test functions on complex inputs * Release type constraint on exp * Add _realconjtimes * Use _realconjtimes in abs/abs2 rules * Add complex rule for hypot * Add generic rule for adjoint * Add generic rule for real * Add generic rule for imag * Add complex rule for hypot * Add rules/tests for Complex * Test frule for identity * Add missing angle test * Make inline just in case * Unify abs rules * Introduce _imagconjtimes utility function * Unify angle rules * Unify sign rules * Multiply by correct variable * Fix argument order * Bump ChainRulesTestUtils version number * Restrict to Complex * Use muladd * Update src/rulesets/Base/fastmath_able.jl Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> * rename differentials (#413) * rename DoesNotExist * rename Composite * bump version and compat * rename Zero * remove typos * reexport deprecated types manually * Rename to `realconjtimes` and `imagconjtimes` and export them * Add tests * Fix tests with Julia 1.0 * Rename to `realdot` and `imagdot` * Add dispatch for real arrays * Update src/utils.jl Co-authored-by: Seth Axen <seth.axen@gmail.com> * Generalize `::Complex` to `::Number` * Rename `utils.jl` to `complex_math.jl` * Remove `imagdot` * Add `realdot` * Update README * Apply suggestions from code review Co-authored-by: Seth Axen <seth.axen@gmail.com> * Update README.md * Update src/RealDot.jl Co-authored-by: Seth Axen <seth.axen@gmail.com> * Add test with quaternions * Fix quaternion multiplication Co-authored-by: Seth Axen <seth.axen@gmail.com> Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent ebc598d commit 6a8e18b

File tree

4 files changed

+95
-2
lines changed

4 files changed

+95
-2
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
33
authors = ["David Widmann"]
44
version = "0.1.0"
55

6+
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
69
[compat]
710
julia = "1"
811

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,19 @@
44
[![Coverage](https://codecov.io/gh/devmotion/RealDot.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/devmotion/RealDot.jl)
55
[![Coverage](https://coveralls.io/repos/github/devmotion/RealDot.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/RealDot.jl?branch=main)
66
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
7+
8+
This package only contains and exports a single function `realdot(x, y)`.
9+
It computes `real(LinearAlgebra.dot(x, y))` while avoiding computing the imaginary part of `LinearAlgebra.dot(x, y)` if possible.
10+
11+
The real dot product is useful when one treats complex numbers as embedded in a real vector space.
12+
For example, take two complex arrays `x` and `y`.
13+
Their real dot product is `real(dot(x, y)) == dot(real(x), real(y)) + dot(imag(x), imag(y))`.
14+
This is the same result one would get by reinterpreting the arrays as real arrays:
15+
```julia
16+
xreal = reinterpret(real(eltype(x)), x)
17+
yreal = reinterpret(real(eltype(y)), y)
18+
real(dot(x, y)) == dot(xreal, yreal)
19+
```
20+
21+
In particular, this function can be useful if you define pullbacks for non-holomorphic functions (see e.g. [this discussion in the ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/pull/474)).
22+
It was implemented initially in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) in [this PR](https://github.com/JuliaDiff/ChainRules.jl/pull/216) as `_realconjtimes`.

src/RealDot.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
module RealDot
22

3-
# Write your package code here.
3+
using LinearAlgebra: LinearAlgebra
4+
5+
export realdot
6+
7+
"""
8+
realdot(x, y)
9+
10+
Compute `real(dot(x, y))` while avoiding computing the imaginary part if possible.
11+
12+
This function can be useful if you work with derivatives of functions on complex
13+
numbers. In particular, this computation shows up in pullbacks for non-holomorphic
14+
functions.
15+
"""
16+
@inline realdot(x, y) = real(LinearAlgebra.dot(x, y))
17+
@inline realdot(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y))
18+
@inline realdot(x::Real, y::Number) = x * real(y)
19+
@inline realdot(x::Number, y::Real) = real(x) * y
20+
@inline realdot(x::Real, y::Real) = x * y
421

522
end

test/runtests.jl

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,63 @@
11
using RealDot
2+
using LinearAlgebra
23
using Test
34

5+
# struct need to be defined outside of tests for julia 1.0 compat
6+
# custom complex number (tests fallback definition)
7+
struct CustomComplex{T} <: Number
8+
re::T
9+
im::T
10+
end
11+
12+
Base.real(x::CustomComplex) = x.re
13+
Base.imag(x::CustomComplex) = x.im
14+
15+
Base.conj(x::CustomComplex) = CustomComplex(x.re, -x.im)
16+
17+
function Base.:*(x::CustomComplex, y::Union{Real,Complex})
18+
return CustomComplex(reim(Complex(reim(x)...) * y)...)
19+
end
20+
Base.:*(x::Union{Real,Complex}, y::CustomComplex) = y * x
21+
function Base.:*(x::CustomComplex, y::CustomComplex)
22+
return CustomComplex(reim(Complex(reim(x)...) * Complex(reim(y)...))...)
23+
end
24+
25+
# custom quaternion to test definition for hypercomplex numbers
26+
# adapted from Quaternions.jl
27+
struct Quaternion{T<:Real} <: Number
28+
s::T
29+
v1::T
30+
v2::T
31+
v3::T
32+
end
33+
34+
Base.real(q::Quaternion) = q.s
35+
Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3)
36+
37+
function Base.:*(q::Quaternion, w::Quaternion)
38+
return Quaternion(
39+
q.s * w.s - q.v1 * w.v1 - q.v2 * w.v2 - q.v3 * w.v3,
40+
q.s * w.v1 + q.v1 * w.s + q.v2 * w.v3 - q.v3 * w.v2,
41+
q.s * w.v2 - q.v1 * w.v3 + q.v2 * w.s + q.v3 * w.v1,
42+
q.s * w.v3 + q.v1 * w.v2 - q.v2 * w.v1 + q.v3 * w.s,
43+
)
44+
end
45+
46+
function Base.:*(q::Quaternion, w::Union{Real,Complex,CustomComplex})
47+
a, b = reim(w)
48+
return q * Quaternion(a, b, zero(a), zero(a))
49+
end
50+
Base.:*(w::Union{Real,Complex,CustomComplex}, q::Quaternion) = conj(conj(q) * conj(w))
51+
452
@testset "RealDot.jl" begin
5-
# Write your tests here.
53+
scalars = (
54+
randn(), randn(ComplexF64), CustomComplex(randn(2)...), Quaternion(randn(4)...)
55+
)
56+
arrays = (randn(10), randn(ComplexF64, 10))
57+
58+
for inputs in (scalars, arrays)
59+
for x in inputs, y in inputs
60+
@test realdot(x, y) == real(dot(x, y))
61+
end
62+
end
663
end

0 commit comments

Comments
 (0)