Skip to content

Commit b1ef2b7

Browse files
committed
Unitful compatibility
- Add separate type parameter for the angle - Promote types for combinations of AbstractFloat and Integer
1 parent 00f11e2 commit b1ef2b7

File tree

4 files changed

+117
-15
lines changed

4 files changed

+117
-15
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ StaticArrays = "0.11,0.12"
1313
julia = "1"
1414

1515
[extras]
16+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1617
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819

1920
[targets]
20-
test = ["Test", "ForwardDiff"]
21+
test = ["Test", "ForwardDiff", "Unitful"]

src/coordinatesystems.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
### 2D Coordinate systems ###
33
#############################
44
"""
5-
`Polar{T}(r::T, θ::T)` - 2D polar coordinates
5+
`Polar{T,A}(r::T, θ::A)` - 2D polar coordinates
66
"""
7-
struct Polar{T}
7+
struct Polar{T,A}
88
r::T
9-
θ::T
9+
θ::A
1010
end
11+
Polar(r::T, θ::A) where {T<:AbstractFloat, A<:Integer} = Polar(promote(r, θ)...)
12+
Polar(r::T, θ::A) where {T<:Integer, A<:AbstractFloat} = Polar(promote(r, θ)...)
1113
Base.show(io::IO, x::Polar) = print(io, "Polar(r=$(x.r), θ=$(x.θ) rad)")
1214
Base.isapprox(p1::Polar, p2::Polar; kwargs...) = isapprox(p1.r, p2.r; kwargs...) && isapprox(p1.θ, p2.θ; kwargs...)
13-
Base.eltype(::Polar{T}) where {T} = T
14-
Base.eltype(::Type{Polar{T}}) where {T} = T
15+
Base.eltype(::Polar{T,A}) where {T,A} = promote_type(T, A)
16+
Base.eltype(::Type{Polar{T,A}}) where {T,A} = promote_type(T, A)
1517

1618
"`PolarFromCartesian()` - transformation from `AbstractVector` of length 2 to `Polar` type"
1719
struct PolarFromCartesian <: Transformation; end
@@ -67,28 +69,34 @@ Base.convert(::Type{Polar}, v::AbstractVector) = PolarFromCartesian()(v)
6769
"""
6870
Spherical(r, θ, ϕ) - 3D spherical coordinates
6971
"""
70-
struct Spherical{T}
72+
struct Spherical{T,A}
7173
r::T
72-
θ::T
73-
ϕ::T
74+
θ::A
75+
ϕ::A
7476
end
77+
Spherical(r::T, θ::A, ϕ::A) where {T<:AbstractFloat, A<:Integer} = Spherical(promote(r, θ, ϕ)...)
78+
Spherical(r::T, θ::A, ϕ::A) where {T<:Integer, A<:AbstractFloat} = Spherical(promote(r, θ, ϕ)...)
7579
Base.show(io::IO, x::Spherical) = print(io, "Spherical(r=$(x.r), θ=$(x.θ) rad, ϕ=$(x.ϕ) rad)")
7680
Base.isapprox(p1::Spherical, p2::Spherical; kwargs...) = isapprox(p1.r, p2.r; kwargs...) && isapprox(p1.θ, p2.θ; kwargs...) && isapprox(p1.ϕ, p2.ϕ; kwargs...)
77-
Base.eltype(::Spherical{T}) where {T} = T
78-
Base.eltype(::Type{Spherical{T}}) where {T} = T
81+
Base.eltype(::Spherical{T,A}) where {T,A} = promote_type(T, A)
82+
Base.eltype(::Type{Spherical{T,A}}) where {T,A} = promote_type(T, A)
7983

8084
"""
8185
Cylindrical(r, θ, z) - 3D cylindrical coordinates
8286
"""
83-
struct Cylindrical{T}
87+
struct Cylindrical{T,A}
8488
r::T
85-
θ::T
89+
θ::A
8690
z::T
8791
end
92+
Cylindrical(r::T1, θ::A, z::T2) where {T1<:AbstractFloat, T2<:Integer, A<:Integer} = Cylindrical(promote(r, θ, z)...)
93+
Cylindrical(r::T1, θ::A, z::T2) where {T1<:Integer, T2<:AbstractFloat, A<:Integer} = Cylindrical(promote(r, θ, z)...)
94+
Cylindrical(r::T1, θ::A, z::T2) where {T1<:AbstractFloat, T2<:Integer, A<:AbstractFloat} = Cylindrical(promote(r, θ, z)...)
95+
Cylindrical(r::T1, θ::A, z::T2) where {T1<:Integer, T2<:AbstractFloat, A<:AbstractFloat} = Cylindrical(promote(r, θ, z)...)
8896
Base.show(io::IO, x::Cylindrical) = print(io, "Cylindrical(r=$(x.r), θ=$(x.θ) rad, z=$(x.z))")
8997
Base.isapprox(p1::Cylindrical, p2::Cylindrical; kwargs...) = isapprox(p1.r, p2.r; kwargs...) && isapprox(p1.θ, p2.θ; kwargs...) && isapprox(p1.z, p2.z; kwargs...)
90-
Base.eltype(::Cylindrical{T}) where {T} = T
91-
Base.eltype(::Type{Cylindrical{T}}) where {T} = T
98+
Base.eltype(::Cylindrical{T,A}) where {T,A} = promote_type(T, A)
99+
Base.eltype(::Type{Cylindrical{T,A}}) where {T,A} = promote_type(T, A)
92100

93101
"`SphericalFromCartesian()` - transformation from 3D point to `Spherical` type"
94102
struct SphericalFromCartesian <: Transformation; end

test/coordinatesystems.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,34 @@
9999
partials(xy_gn[2], 1) partials(xy_gn[2], 2) ]
100100
m = transform_deriv(c_from_p, rθ)
101101
@test m m_gn
102+
103+
@testset "Common types" begin
104+
xy = SVector(1.0, 2.0)
105+
xy_i = SVector(1,2)
106+
p1 = Polar(1, 2.0f0)
107+
p2 = Polar(1.0, 2)
108+
= Polar(2.23606797749979, 1.1071487177940904)
109+
110+
@test eltype(p1) == Float32
111+
@test typeof(p1.r) == typeof(p1.θ)
112+
@test eltype(p2) == Float64
113+
@test typeof(p2.r) == typeof(p2.θ)
114+
@test eltype(p_from_c(xy_i)) == Float64
115+
116+
@test p_from_c(xy_i)
117+
@test p_from_c(xy)
118+
@test p_from_c(collect(xy))
119+
@test c_from_p(rθ) xy
120+
end
121+
122+
@testset "Units" begin
123+
xy = SVector(1.0, 2.0)u"m"
124+
= Polar(2.23606797749979u"m", 1.1071487177940904)
125+
126+
@test p_from_c(xy)
127+
@test p_from_c(collect(xy))
128+
@test c_from_p(rθ) xy
129+
end
102130
end
103131

104132
@testset "3D" begin
@@ -435,5 +463,69 @@
435463
# @test isapprox(m, m_gn; atol = 1e-12)
436464
@test m m_gn
437465

466+
@testset "Common types" begin
467+
xyz = SVector(1.0, 2.0, 3.0)
468+
xyz_i = SVector(1, 2, 3)
469+
470+
@testset "Spherical" begin
471+
rθϕ = Spherical(3.7416573867739413, 1.1071487177940904, 0.9302740141154721)
472+
473+
@test s_from_cart(xyz) rθϕ
474+
@test s_from_cart(xyz_i) rθϕ
475+
@test s_from_cart(collect(xyz)) rθϕ
476+
@test cart_from_s(rθϕ) xyz
477+
478+
s1 = Spherical(1, 2.0, 3.0)
479+
s2 = Spherical(1.0, 2, 3)
480+
481+
@test eltype(s1) == Float64
482+
@test typeof(s1.r) == typeof(s1.θ) == typeof(s1.ϕ)
483+
@test eltype(s2) == Float64
484+
@test typeof(s2.r) == typeof(s2.θ) == typeof(s2.ϕ)
485+
@test eltype(s_from_cart(xyz_i)) == Float64
486+
end
487+
488+
@testset "Cylindrical" begin
489+
rθz = Cylindrical(2.23606797749979, 1.1071487177940904, 3.0)
490+
491+
@test cyl_from_cart(xyz) rθz
492+
@test cyl_from_cart(xyz_i) rθz
493+
@test cyl_from_cart(collect(xyz)) rθz
494+
@test cart_from_cyl(rθz) xyz
495+
496+
c1 = Cylindrical(1, 2.0, 3)
497+
c2 = Cylindrical(1.0, 2, 3.0)
498+
c3 = Cylindrical(1, 2, 3)
499+
500+
@test eltype(c1) == Float64
501+
@test typeof(c1.r) == typeof(c1.z)
502+
@test typeof(c1.θ) == Float64
503+
@test eltype(c2) == Float64
504+
@test typeof(c2.r) == typeof(c2.z)
505+
@test typeof(c2.θ) == Int
506+
@test eltype(cyl_from_cart(xyz_i)) == Float64
507+
end
508+
end
509+
510+
@testset "Units" begin
511+
xyz = SVector(1.0, 2.0, 3.0)u"m"
512+
513+
@testset "Shperical" begin
514+
rθϕ = Spherical(3.7416573867739413u"m", 1.1071487177940904, 0.9302740141154721)
515+
516+
@test s_from_cart(xyz) rθϕ
517+
@test typeof(s_from_cart(xyz)) == typeof(rθϕ)
518+
@test s_from_cart(collect(xyz)) rθϕ
519+
@test cart_from_s(rθϕ) xyz
520+
end
521+
@testset "Cylindrical" begin
522+
rθz = Cylindrical(2.23606797749979u"m", 1.1071487177940904, 3.0u"m")
523+
524+
@test cyl_from_cart(xyz) rθz
525+
@test typeof(cyl_from_cart(xyz)) == typeof(rθz)
526+
@test cyl_from_cart(collect(xyz)) rθz
527+
@test cart_from_cyl(rθz) xyz
528+
end
529+
end
438530
end
439531
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using LinearAlgebra
33
using CoordinateTransformations
44
using ForwardDiff: Dual, partials
55
using StaticArrays
6+
using Unitful
67

78
@testset "CoordinateTransformations" begin
89

0 commit comments

Comments
 (0)