Skip to content

Commit b6b1535

Browse files
authored
Reduce allocations of cov(::MultivariateMixture) (#1967)
* Fix MixtureModel with ScalMat covariances * Use `mul!` * Use `LinearAlgebra.copytri!` * Update src/mixtures/mixturemodel.jl * Add more tests
1 parent da3a230 commit b6b1535

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/mixtures/mixturemodel.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,11 @@ function cov(d::MultivariateMixture)
244244
pi = p[i]
245245
if pi > 0.0
246246
c = component(d, i)
247-
# todo: use more in-place operations
248-
md = mean(c) - m
249-
axpy!(pi, md*md', V)
247+
md .= mean(c) .- m
248+
BLAS.syr!('U', Float64(pi), md, V)
250249
end
251250
end
251+
LinearAlgebra.copytri!(V, 'U')
252252
return V
253253
end
254254

test/mixture.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Distributions, Random
22
using Test
3+
using LinearAlgebra
34
using ForwardDiff: Dual
45

56

@@ -252,17 +253,19 @@ end
252253
end
253254

254255
@testset "Testing MultivariatevariateMixture" begin
255-
g_m = MixtureModel(
256-
IsoNormal[ MvNormal([0.0, 0.0], I),
257-
MvNormal([0.2, 1.0], I),
258-
MvNormal([-0.5, -3.0], 1.6 * I) ],
259-
[0.2, 0.5, 0.3])
260-
@test isa(g_m, MixtureModel{Multivariate, Continuous, IsoNormal})
261-
@test length(components(g_m)) == 3
262-
@test length(g_m) == 2
263-
@test insupport(g_m, [0.0, 0.0]) == true
264-
test_mixture(g_m, 1000, 10^6, rng)
265-
test_params(g_m)
256+
for T in (Float32, Float64)
257+
g_m = MixtureModel(
258+
IsoNormal[ MvNormal([0.0, 0.0], I),
259+
MvNormal([0.2, 1.0], I),
260+
MvNormal([-0.5, -3.0], 1.6 * I) ],
261+
T[0.2, 0.5, 0.3])
262+
@test isa(g_m, MixtureModel{Multivariate, Continuous, IsoNormal})
263+
@test length(components(g_m)) == 3
264+
@test length(g_m) == 2
265+
@test insupport(g_m, [0.0, 0.0])
266+
test_mixture(g_m, 1000, 10^6, rng)
267+
test_params(g_m)
268+
end
266269

267270
u1 = Uniform()
268271
u2 = Uniform(1.0, 2.0)

0 commit comments

Comments
 (0)