-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMixtureDistribution.h
53 lines (48 loc) · 2.33 KB
/
MixtureDistribution.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
#include "tracking_lib/Distributions/BaseDistribution.h"
// ######################################################################################################################
#include "tracking_lib/TTBTypes/TTBTypes.h"
namespace ttb
{
/// This represents a mixture distribution of other arbitrary Distributions.
///
/// The mixture owns its components through `std::unique_ptr<BaseDistribution>` (see `_dists`) and is itself a
/// `BaseDistribution`. The class is `final`, so it cannot be derived from.
///
/// Only declarations are visible in this header. Every remark below that goes beyond what the signatures show is
/// marked as an assumption and should be checked against the implementation file.
class MixtureDistribution final : public BaseDistribution
{
public:
/// Default constructor, takes no components.
/// NOTE(review): presumably yields a mixture without any component - confirm in the implementation.
MixtureDistribution();
/// Constructs the mixture from a single distribution whose ownership is handed over by the caller.
/// `explicit`, so a `std::unique_ptr<BaseDistribution>` is never converted to a mixture implicitly.
explicit MixtureDistribution(std::unique_ptr<BaseDistribution> comp);
/// Constructs the mixture from several distributions whose ownership is handed over by the caller.
explicit MixtureDistribution(std::vector<std::unique_ptr<BaseDistribution>> components);
/// Kind of this distribution as a `DISTRIBUTION_TYPE` value.
[[nodiscard]] DISTRIBUTION_TYPE type() const override;
/// Identifier of this distribution (presumably the `_id` member - TODO confirm).
[[nodiscard]] DistributionId id() const override;
/// Reference point of this distribution.
/// NOTE(review): how a single value is derived for several components is not visible here.
[[nodiscard]] REFERENCE_POINT refPoint() const override;
/// Distributions of this mixture as raw pointers that allow modification of the pointees.
/// NOTE(review): presumably these point into `_dists`, i.e. they are non-owning and only valid as long as the
/// component set is not changed - confirm against the implementation.
[[nodiscard]] std::vector<BaseDistribution*> dists() override;
/// Read-only counterpart of the non-const `dists()`.
[[nodiscard]] std::vector<BaseDistribution const*> dists() const override;
/// Textual representation of the mixture; `prefix` defaults to the empty string.
/// Default arguments of virtual functions are selected by the static type of the call expression, not by the
/// dynamic type. The NOLINT presumably silences the corresponding linter warning - keep the default in sync with
/// the declaration in `BaseDistribution`.
[[nodiscard]] std::string toString(std::string const& prefix = "") const override; // NOLINT
/// Validity check of the mixture; the exact criteria are defined in the implementation.
[[nodiscard]] bool isValid() const override;
/// Covariance of the mixture, returned by const reference.
/// NOTE(review): the returned reference is presumably bound to `_covCache`, filled on demand - TODO confirm,
/// including how long the reference stays valid after a modifying call.
[[nodiscard]] Matrix const& covariance() const override;
/// Mean of the mixture, returned by const reference.
/// NOTE(review): presumably bound to `_meanCache` in the same way as `covariance()` - TODO confirm.
[[nodiscard]] Vector const& mean() const override;
/// Probability density at `x`.
/// NOTE(review): presumably combines the component densities - confirm in the implementation.
[[nodiscard]] double pdf(Vector const& x) const override;
/// Setter overloads, distinguished by parameter type; all arguments are taken by value.
/// NOTE(review): how a value is applied to a mixture with several components, and whether the mean/covariance
/// caches are reset, is not visible in this header.
void set(Vector mean) override;
void set(Matrix cov) override;
void set(double weight) override;
void set(REFERENCE_POINT ref) override;
void set(Vector mean, Matrix cov) override;
/// Prior-id handling: set, query and reset.
/// NOTE(review): no dedicated member for the prior id is declared in this class, so it is presumably kept in the
/// components or in the base class - TODO confirm.
void setPriorId(DistributionId new_id) override;
[[nodiscard]] DistributionId priorId() const override;
void resetPriorId() override;
/// Reduction operations controlled by a weight threshold, a variance threshold, a maximal component count, and a
/// maximal distance evaluated on `comps`, respectively.
/// NOTE(review): the `std::size_t` results presumably count the components affected by the operation - confirm.
/// The results are not marked [[nodiscard]], so callers may ignore them.
std::size_t pruneWeight(double weightThreshold) override;
std::size_t pruneVar(double varThreshold) override;
std::size_t truncate(std::size_t maxComponents) override;
std::size_t mergeComponents(double max_dist, Components const& comps) override;
/// Merges one other distribution into this mixture; ownership of `dist` is handed over by the caller.
void merge(std::unique_ptr<BaseDistribution> dist) override;
/// Merges several other distributions into this mixture; ownership of all elements is handed over by the caller.
void merge(std::vector<std::unique_ptr<BaseDistribution>> others) override;
/// Reference to the best component.
/// NOTE(review): the ranking criterion (presumably the weight) and the behavior for a mixture without components
/// are defined in the implementation - a reference cannot express "no component", so verify the precondition.
[[nodiscard]] BaseDistribution const& bestComponent() const override;
/// Copy of this distribution, owned by the caller and returned through the base-class pointer type.
[[nodiscard]] std::unique_ptr<BaseDistribution> clone() const override;
/// Sum of weights (presumably over all components in `_dists` - TODO confirm).
[[nodiscard]] double sumWeights() const override;
/// Multiplies weights by `factor` (presumably the weight of every component - TODO confirm).
void multiplyWeights(double factor) override;
private:
/// Optional mean / covariance storage. `mutable`, so const member functions are able to write to them;
/// an empty optional presumably means "not computed yet" - TODO confirm the invalidation rules.
mutable std::optional<Vector> _meanCache;
mutable std::optional<Matrix> _covCache;
/// Owned distributions of this mixture.
std::vector<std::unique_ptr<BaseDistribution>> _dists;
/// Identifier of this mixture (presumably the value reported by `id()`).
DistributionId _id;
};
} // namespace ttb