-
Notifications
You must be signed in to change notification settings - Fork 4
Init: Refactor NMM mixture, separate generators and mixtures, tests r… #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 7 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
f9bf7c6
Init: Refactor NMM mixture, separate generators and mixtures, tests r…
Engelsgeduld 78a4ac7
Init: Refactor. Add estimator classes and directory structure
andreev-sergej 9d61917
Fix: Fix tests.
andreev-sergej 35646aa
Fix: Switch to one register. Add algorithm purpose key to register.
andreev-sergej 9b94ad3
Complete architecture changes, edit docstrings, refactor tests
Engelsgeduld 9710cea
Fix: black, isort
Engelsgeduld b4fac7d
Fix: params validation refactor, add docstrings
Engelsgeduld 9160868
Fix: Enum purpose key
andreev-sergej 533b136
Fix: Add docstrings
andreev-sergej d7baea4
Fix: Default purpose value
andreev-sergej 032bf44
Fix: Algorithm registry init
andreev-sergej 601fe1c
Fix: Remove unused imports
andreev-sergej File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from src.algorithms.semiparam_algorithms.nvm_semi_param_algorithms.mu_estimation import SemiParametricMuEstimation | ||
from src.register.register import Registry | ||
|
||
ALGORITHM_REGISTRY: Registry = Registry() | ||
ALGORITHM_REGISTRY.register("mu_estimation", "NMVSemiparametric")(SemiParametricMuEstimation) |
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from abc import abstractmethod | ||
from typing import Self | ||
|
||
from numpy import _typing | ||
|
||
from src.algorithms import ALGORITHM_REGISTRY | ||
from src.estimators.estimate_result import EstimateResult | ||
|
||
|
||
class AbstractEstimator: | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
self.algorithm_name = algorithm_name | ||
if params is None: | ||
self.params = dict() | ||
else: | ||
self.params = params | ||
self.estimate_result = EstimateResult() | ||
self._registry = ALGORITHM_REGISTRY | ||
self._purpose = "" | ||
|
||
def get_params(self) -> dict: | ||
return {"algorithm_name": self.algorithm_name, "params": self.params, "estimated_result": self.estimate_result} | ||
|
||
def set_params(self, algorithm_name: str, params: dict | None = None) -> Self: | ||
self.algorithm_name = algorithm_name | ||
if params is None: | ||
self.params = dict() | ||
else: | ||
self.params = params | ||
return self | ||
|
||
def get_available_algorithms(self) -> list[tuple[str, str]]: | ||
return [key for key in self._registry.register_of_names.keys() if key[1] == self._purpose] | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
cls = None | ||
if (self.algorithm_name, self._purpose) in self._registry.register_of_names: | ||
cls = self._registry.dispatch(self.algorithm_name, self._purpose)(sample, **self.params) | ||
if cls is None: | ||
raise ValueError("This algorithm does not exist") | ||
self.estimate_result = cls.algorithm(sample) | ||
return self.estimate_result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class EstimateResult: | ||
value: float = -1 | ||
success: bool = False | ||
message: str = "No message" |
12 changes: 12 additions & 0 deletions
12
src/estimators/parametric/abstract_parametric_estimator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.abstract_estimator import AbstractEstimator | ||
from src.estimators.estimate_result import EstimateResult | ||
|
||
|
||
class AbstractParametricEstimator(AbstractEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.parametric.abstract_parametric_estimator import AbstractParametricEstimator | ||
|
||
|
||
class NMParametricEstimator(AbstractParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NMParametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.parametric.abstract_parametric_estimator import AbstractParametricEstimator | ||
|
||
|
||
class NMVParametricEstimator(AbstractParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NMVParametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.parametric.abstract_parametric_estimator import AbstractParametricEstimator | ||
|
||
|
||
class NVParametricEstimator(AbstractParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NVParametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
14 changes: 14 additions & 0 deletions
14
src/estimators/semiparametric/abstract_semiparametric_estimator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from abc import abstractmethod | ||
|
||
from numpy import _typing | ||
|
||
from src.estimators.abstract_estimator import AbstractEstimator | ||
from src.estimators.estimate_result import EstimateResult | ||
|
||
|
||
class AbstractSemiParametricEstimator(AbstractEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
13 changes: 13 additions & 0 deletions
13
src/estimators/semiparametric/nm_semiparametric_estimator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.semiparametric.abstract_semiparametric_estimator import AbstractSemiParametricEstimator | ||
|
||
|
||
class NMSemiParametricEstimator(AbstractSemiParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NMSemiparametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
13 changes: 13 additions & 0 deletions
13
src/estimators/semiparametric/nmv_semiparametric_estimator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.semiparametric.abstract_semiparametric_estimator import AbstractSemiParametricEstimator | ||
|
||
|
||
class NMVSemiParametricEstimator(AbstractSemiParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NMVSemiparametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
13 changes: 13 additions & 0 deletions
13
src/estimators/semiparametric/nv_semiparametric_estimator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from numpy import _typing | ||
|
||
from src.estimators.estimate_result import EstimateResult | ||
from src.estimators.semiparametric.abstract_semiparametric_estimator import AbstractSemiParametricEstimator | ||
|
||
|
||
class NVSemiParametricEstimator(AbstractSemiParametricEstimator): | ||
def __init__(self, algorithm_name: str, params: dict | None = None) -> None: | ||
super().__init__(algorithm_name, params) | ||
self._purpose = "NVSemiparametric" | ||
|
||
def estimate(self, sample: _typing.ArrayLike) -> EstimateResult: | ||
return super().estimate(sample) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from abc import abstractmethod | ||
|
||
import numpy._typing as tpg | ||
|
||
from src.mixtures.abstract_mixture import AbstractMixtures | ||
|
||
|
||
class AbstractGenerator: | ||
@staticmethod | ||
@abstractmethod | ||
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Докстриги лучше здесь всё же написать |
||
|
||
"""Generate a sample of given size. Classical form of Mixture | ||
|
||
Args: | ||
mixture: NMM | NVM | NMVM | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
""" | ||
|
||
@staticmethod | ||
@abstractmethod | ||
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: ... | ||
|
||
"""Generate a sample of given size. Canonical form of Mixture | ||
|
||
Args: | ||
mixture: NMM | NVM | NMVM | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy._typing as tpg | ||
import scipy | ||
|
||
from src.generators.abstract_generator import AbstractGenerator | ||
from src.mixtures.abstract_mixture import AbstractMixtures | ||
from src.mixtures.nm_mixture import NormalMeanMixtures | ||
|
||
|
||
class NMGenerator(AbstractGenerator): | ||
|
||
@staticmethod | ||
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: | ||
"""Generate a sample of given size. Classical form of NMM | ||
|
||
Args: | ||
mixture: Normal Mean Mixture | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
Raises: | ||
ValueError: If mixture is not a Normal Mean Mixture | ||
|
||
""" | ||
|
||
if not isinstance(mixture, NormalMeanMixtures): | ||
raise ValueError("Mixture must be NormalMeanMixtures") | ||
mixing_values = mixture.params.distribution.rvs(size=size) | ||
normal_values = scipy.stats.norm.rvs(size=size) | ||
return mixture.params.alpha + mixture.params.beta * mixing_values + mixture.params.gamma * normal_values | ||
|
||
@staticmethod | ||
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: | ||
"""Generate a sample of given size. Canonical form of NMM | ||
|
||
Args: | ||
mixture: Normal Mean Mixture | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
Raises: | ||
ValueError: If mixture is not a Normal Mean Mixture | ||
|
||
""" | ||
|
||
if not isinstance(mixture, NormalMeanMixtures): | ||
raise ValueError("Mixture must be NormalMeanMixtures") | ||
mixing_values = mixture.params.distribution.rvs(size=size) | ||
normal_values = scipy.stats.norm.rvs(size=size) | ||
return mixing_values + mixture.params.sigma * normal_values |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy._typing as tpg | ||
import scipy | ||
|
||
from src.generators.abstract_generator import AbstractGenerator | ||
from src.mixtures.abstract_mixture import AbstractMixtures | ||
from src.mixtures.nmv_mixture import NormalMeanVarianceMixtures | ||
|
||
|
||
class NMVGenerator(AbstractGenerator): | ||
|
||
@staticmethod | ||
def classical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: | ||
"""Generate a sample of given size. Classical form of NMVM | ||
|
||
Args: | ||
mixture: Normal Mean Variance Mixtures | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
Raises: | ||
ValueError: If mixture type is not Normal Mean Variance Mixtures | ||
|
||
""" | ||
|
||
if not isinstance(mixture, NormalMeanVarianceMixtures): | ||
raise ValueError("Mixture must be NormalMeanMixtures") | ||
mixing_values = mixture.params.distribution.rvs(size=size) | ||
normal_values = scipy.stats.norm.rvs(size=size) | ||
return ( | ||
mixture.params.alpha | ||
+ mixture.params.beta * mixing_values | ||
+ mixture.params.gamma * (mixing_values**0.5) * normal_values | ||
) | ||
|
||
@staticmethod | ||
def canonical_generate(mixture: AbstractMixtures, size: int) -> tpg.NDArray: | ||
"""Generate a sample of given size. Canonical form of NMVM | ||
|
||
Args: | ||
mixture: Normal Mean Variance Mixtures | ||
size: length of sample | ||
|
||
Returns: sample of given size | ||
|
||
Raises: | ||
ValueError: If mixture type is not Normal Mean Variance Mixtures | ||
|
||
""" | ||
|
||
if not isinstance(mixture, NormalMeanVarianceMixtures): | ||
raise ValueError("Mixture must be NormalMeanMixtures") | ||
mixing_values = mixture.params.distribution.rvs(size=size) | ||
normal_values = scipy.stats.norm.rvs(size=size) | ||
return mixture.params.alpha + mixture.params.mu * mixing_values + (mixing_values**0.5) * normal_values |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Вот тут много возьни со строками и словарями. В будущем рекомендую задуматься о рефакторинге, так как поддерживать может быть больно