diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 5806c00f31..76d68fd4f0 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -178,14 +178,9 @@ const ZeroMeanDiagNormal{Axes} = MvNormal{Float64,PDiagMat{Float64,Vector{Float6 const ZeroMeanFullNormal{Axes} = MvNormal{Float64,PDMat{Float64,Matrix{Float64}},Zeros{Float64,1,Axes}} ### Construction -function MvNormal(μ::AbstractVector{T}, Σ::AbstractPDMat{T}) where {T<:Real} - size(Σ, 1) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent.")) - MvNormal{T,typeof(Σ), typeof(μ)}(μ, Σ) -end - function MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real}) - R = Base.promote_eltype(μ, Σ) - MvNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ)) + size(Σ, 1) == length(μ) || throw(DimensionMismatch("The dimensions of mu and Sigma are inconsistent.")) + return MvNormal{Base.promote_eltype(μ, Σ),typeof(Σ),typeof(μ)}(μ, Σ) end # constructor with general covariance matrix @@ -197,8 +192,11 @@ Construct a multivariate normal distribution with mean `μ` and covariance matri MvNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real}) = MvNormal(μ, PDMat(Σ)) MvNormal(μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real}) = MvNormal(μ, PDiagMat(Σ.diag)) MvNormal(μ::AbstractVector{<:Real}, Σ::Union{Symmetric{<:Real,<:Diagonal{<:Real}},Hermitian{<:Real,<:Diagonal{<:Real}}}) = MvNormal(μ, PDiagMat(Σ.data.diag)) -MvNormal(μ::AbstractVector{<:Real}, Σ::UniformScaling{<:Real}) = - MvNormal(μ, ScalMat(length(μ), Σ.λ)) +function MvNormal(μ::AbstractVector{<:Real}, Σ::UniformScaling{<:Real}) + # Promote `Bool` (`I`) to avoid surprising covariance element types + λ = Σ isa UniformScaling{Bool} ? promote_type(eltype(μ), Bool)(Σ.λ) : Σ.λ + return MvNormal(μ, ScalMat(length(μ), λ)) +end function MvNormal( μ::AbstractVector{<:Real}, Σ::Diagonal{<:Real,<:FillArrays.AbstractFill{<:Real,1}} ) diff --git a/test/multivariate/mvnormal.jl b/test/multivariate/mvnormal.jl index 4af9275b3f..604064931d 100644 --- a/test/multivariate/mvnormal.jl +++ b/test/multivariate/mvnormal.jl @@ -82,9 +82,9 @@ end C = [4. -2. -1.; -2. 5. -1.; -1. -1. 6.] J = inv(C) h = J \ mu - @test typeof(MvNormal(mu, PDMat(Array{Float32}(C)))) == typeof(MvNormal(mu, PDMat(C))) - @test typeof(MvNormal(mu, Array{Float32}(C))) == typeof(MvNormal(mu, PDMat(C))) - @test typeof(@test_deprecated(MvNormal(mu, 2.0f0))) == typeof(@test_deprecated(MvNormal(mu, 2.0))) + @test MvNormal(mu, PDMat(Array{Float32}(C))) isa MvNormal{Float64, PDMat{Float32, Matrix{Float32}}, Vector{Float64}} + @test MvNormal(mu, Array{Float32}(C)) isa MvNormal{Float64, PDMat{Float32, Matrix{Float32}}, Vector{Float64}} + @test @test_deprecated(MvNormal(mu, 2.0f0)) isa MvNormal{Float64, ScalMat{Float32}, Vector{Float64}} @test typeof(MvNormalCanon(h, PDMat(Array{Float32}(J)))) == typeof(MvNormalCanon(h, PDMat(J))) @test typeof(MvNormalCanon(h, Array{Float32}(J))) == typeof(MvNormalCanon(h, PDMat(J))) @@ -102,9 +102,9 @@ end @test typeof(convert(MvNormalCanon{Float64}, d)) == typeof(MvNormalCanon(mu, h, PDMat(J))) @test typeof(convert(MvNormalCanon{Float64}, d.μ, d.h, d.J)) == typeof(MvNormalCanon(mu, h, PDMat(J))) - @test MvNormal(mu, I) === @test_deprecated(MvNormal(mu, 1)) + @test MvNormal(mu, I) === @test_deprecated(MvNormal(mu, 1.0)) @test MvNormal(mu, 9 * I) === @test_deprecated(MvNormal(mu, 3)) - @test MvNormal(mu, 0.25f0 * I) === @test_deprecated(MvNormal(mu, 0.5)) + @test MvNormal(mu, 0.25f0 * I) === @test_deprecated(MvNormal(mu, 0.5f0)) @test MvNormal(mu, I) === MvNormal(mu, Diagonal(Ones(length(mu)))) @test MvNormal(mu, 9 * I) === MvNormal(mu, Diagonal(Fill(9, length(mu))))