diff --git a/aeon/base/_base.py b/aeon/base/_base.py index 5a336c7397..539127aa26 100644 --- a/aeon/base/_base.py +++ b/aeon/base/_base.py @@ -21,19 +21,20 @@ class BaseAeonEstimator(BaseEstimator, ABC): Contains the following methods: - - reset estimator to post-init - reset(keep) - - clone stimator (copy) - clone(random_state) - - inspect tags (class method) - get_class_tags() - - inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default, - raise_error) - - inspect tags (all) - get_tags() - - inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error) - - setting dynamic tags - set_tags(**tag_dict) - - get fitted parameters - get_fitted_params(deep) + - reset estimator to post-init - ``reset(keep)`` + - clone estimator (copy) - ``clone(random_state)`` + - inspect tags (class method) - ``get_class_tags()`` + - inspect tags (one tag, class) - ``get_class_tag(tag_name, + tag_value_default, raise_error)`` + - inspect tags (all) - ``get_tags()`` + - inspect tags (one tag) - ``get_tag(tag_name, raise_error, + tag_value_default)`` + - setting dynamic tags - ``set_tags(**tag_dict)`` + - get fitted parameters - ``get_fitted_params(deep)`` All estimators have the attribute: - - fitted state flag - is_fitted + - fitted state flag - ``is_fitted`` """ _tags = { @@ -59,14 +60,14 @@ def reset(self, keep=None): """ Reset the object to a clean post-init state. - After a ``self.reset()`` call, self is equal or similar in value to + After a ``self.reset()`` call, ``self`` is equal or similar in value to ``type(self)(**self.get_params(deep=False))``, assuming no other attributes were kept using ``keep``.
Detailed behaviour: removes any object attributes, except: hyper-parameters (arguments of ``__init__``) - object attributes containing double-underscores, i.e., the string "__" + object attributes containing double-underscores, i.e., the string ``__`` runs ``__init__`` with current values of hyperparameters (result of ``get_params``) @@ -78,9 +79,9 @@ class and object methods, class attributes Parameters ---------- keep : None, str, or list of str, default=None - If None, all attributes are removed except hyperparameters. - If str, only the attribute with this name is kept. - If list of str, only the attributes with these names are kept. + If ``None``, all attributes are removed except hyperparameters. + If ``str``, only the attribute with this name is kept. + If ``list`` of ``str``, only the attributes with these names are kept. Returns ------- @@ -125,15 +126,18 @@ def clone(self, random_state=None): Obtain a clone of the object with the same hyperparameters. A clone is a different object without shared references, in post-init state. - This function is equivalent to returning ``sklearn.clone`` of self. + This function is equivalent to returning ``sklearn.clone`` of ``self``. Equal in value to ``type(self)(**self.get_params(deep=False))``. Parameters ---------- random_state : int, RandomState instance, or None, default=None - Sets the random state of the clone. If None, the random state is not set. - If int, random_state is the seed used by the random number generator. - If RandomState instance, random_state is the random number generator. + Sets the random state of the clone. If ``None``, the random state is not + set. + If ``int``, ``random_state`` is the seed used by the random number + generator. + If ``RandomState`` instance, ``random_state`` is the random number + generator. Returns ------- @@ -187,7 +191,7 @@ def get_class_tag( tag_name : str Name of tag value. raise_error : bool, default=True - Whether a ValueError is raised when the tag is not found. 
+ Whether a ``ValueError`` is raised when the tag is not found. tag_value_default : any type, default=None Default/fallback value if tag is not found and error is not raised. @@ -195,13 +199,13 @@ def get_class_tag( ------- tag_value Value of the ``tag_name`` tag in cls. - If not found, returns an error if ``raise_error`` is True, otherwise it + If not found, returns an error if ``raise_error`` is ``True``, otherwise it returns ``tag_value_default``. Raises ------ ValueError - if ``raise_error`` is True and ``tag_name`` is not in + if ``raise_error`` is ``True`` and ``tag_name`` is not in ``self.get_tags().keys()`` Examples @@ -247,7 +251,7 @@ def get_tag(self, tag_name, raise_error=True, tag_value_default=None): tag_name : str Name of tag to be retrieved. raise_error : bool, default=True - Whether a ValueError is raised when the tag is not found. + Whether a ``ValueError`` is raised when the tag is not found. tag_value_default : any type, default=None Default/fallback value if tag is not found and error is not raised. @@ -255,7 +259,7 @@ def get_tag(self, tag_name, raise_error=True, tag_value_default=None): ------- tag_value Value of the ``tag_name`` tag in self. - If not found, returns an error if ``raise_error`` is True, otherwise it + If not found, returns an error if ``raise_error`` is ``True``, otherwise it returns ``tag_value_default``. Raises @@ -292,7 +296,7 @@ def set_tags(self, **tag_dict): Returns ------- self : object - Reference to self. + Reference to ``self``. """ tag_update = deepcopy(tag_dict) self._tags_dynamic.update(tag_update) @@ -307,7 +311,7 @@ def get_fitted_params(self, deep=True): Parameters ---------- deep : bool, default=True - If True, will return the fitted parameters for this estimator and + If ``True``, will return the fitted parameters for this estimator and contained subobjects that are estimators. 
Returns @@ -354,7 +358,7 @@ def _check_is_fitted(self): if not self.is_fitted: raise NotFittedError( f"This instance of {self.__class__.__name__} has not " - f"been fitted yet; please call `fit` first." + f"been fitted yet; please call ``fit`` first." ) @classmethod @@ -366,14 +370,15 @@ def _get_test_params(cls, parameter_set="default"): ---------- parameter_set : str, default="default" Name of the set of test parameters to return, for use in tests. If no - special parameters are defined for a value, will return `"default"` set. + special parameters are defined for a value, will return ``"default"`` set. Returns ------- params : dict or list of dict, default = {} - Parameters to create testing instances of the class. Each dict are - parameters to construct an "interesting" test instance, i.e., - `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + Parameters to create testing instances of the class. Each ``dict`` holds + parameters to construct an "interesting" test instance, i.e., + ``MyClass(**params)`` or ``MyClass(**params[i])`` creates a valid test + instance. """ # default parameters = empty dict return {} @@ -383,23 +388,23 @@ def _create_test_instance(cls, parameter_set="default", return_first=True): """ Construct Estimator instance if possible. - Calls the `_get_test_params` method and returns an instance or list of instances - using the returned dict or list of dict. + Calls the ``_get_test_params`` method and returns an instance or ``list`` + of instances using the returned ``dict`` or ``list`` of ``dict``. Parameters ---------- parameter_set : str, default="default" Name of the set of test parameters to return, for use in tests. If no - special parameters are defined for a value, will return `"default"` set. + special parameters are defined for a value, will return ``"default"`` set. return_first : bool, default=True - If True, return the first instance of the list of instances. - If False, return the list of instances.
+ If ``True``, return the first instance of the list of instances. + If ``False``, return the list of instances. Returns ------- instance : BaseAeonEstimator or list of BaseAeonEstimator - Instance of the class with default parameters. If return_first - is False, returns list of instances. + Instance of the class with default parameters. If ``return_first`` + is ``False``, returns list of instances. """ params = cls._get_test_params(parameter_set=parameter_set) @@ -441,7 +446,7 @@ def _validate_data(self, **kwargs): def get_metadata_routing(self): """Sklearn metadata routing. - Not supported by ``aeon`` estimators. + Not supported by aeon estimators. """ raise NotImplementedError( "aeon estimators do not have a get_metadata_routing method." @@ -449,7 +454,7 @@ def get_metadata_routing(self): @classmethod def _get_default_requests(cls): - """Sklearn metadata request defaults.""" + """``Sklearn`` metadata request defaults.""" from sklearn.utils._metadata_requests import MetadataRequest return MetadataRequest(None) diff --git a/aeon/base/_base_collection.py b/aeon/base/_base_collection.py index 4d7f4b4564..0b1c731043 100644 --- a/aeon/base/_base_collection.py +++ b/aeon/base/_base_collection.py @@ -4,19 +4,19 @@ class name: BaseCollectionEstimator Defining methods: - preprocessing - _preprocess_collection(self, X, store_metadata=True) - input checking - _check_X(self, X) - input conversion - _convert_X(self, X) - shape checking - _check_shape(self, X) + preprocessing - ``_preprocess_collection(self, X, store_metadata=True)`` + input checking - ``_check_X(self, X)`` + input conversion - ``_convert_X(self, X)`` + shape checking - ``_check_shape(self, X)`` Inherited inspection methods: - hyper-parameter inspection - get_params() - fitted parameter inspection - get_fitted_params() + hyper-parameter inspection - ``get_params()`` + fitted parameter inspection - ``get_fitted_params()`` State: - fitted model/strategy - by convention, any attributes ending in "_" - fitted 
state flag - is_fitted (property) - fitted state inspection - check_is_fitted() + fitted model/strategy - by convention, any attributes ending in ``_`` + fitted state flag - ``is_fitted`` (property) + fitted state inspection - ``check_is_fitted()`` """ @@ -80,26 +80,26 @@ def _preprocess_collection(self, X, store_metadata=True): Parameters ---------- X : collection - See aeon.utils.COLLECTIONS_DATA_TYPES for details on aeon supported + See ``aeon.utils.COLLECTIONS_DATA_TYPES`` for details on aeon supported data structures. store_metadata : bool, default=True - Whether to store metadata about X in self.metadata_. + Whether to store metadata about ``X`` in ``self.metadata_``. Returns ------- X : collection - Processed X. A data structure of type self.get_tag("X_inner_type"). + Processed ``X``. A data structure of type ``self.get_tag("X_inner_type")``. Raises ------ ValueError - If X is an invalid type or has characteristics that the estimator cannot + If ``X`` is an invalid type or has characteristics that the estimator cannot handle. See Also -------- _check_X : - Function that checks X is valid before conversion. + Function that checks ``X`` is valid before conversion. _convert_X : Function that converts to inner type. @@ -138,37 +138,39 @@ def _check_X(self, X): Check if the input data is a compatible type, and that this estimator is able to handle the data characteristics. This is done by matching the capabilities of the estimator against the metadata - for X i.e., univariate/multivariate, equal length/unequal length and no missing - values/missing values. + for ``X`` i.e., univariate/multivariate, equal length/unequal length + and no missing values/missing values. Parameters ---------- X : collection - See aeon.utils.COLLECTIONS_DATA_TYPES for details on aeon supported + See ``aeon.utils.COLLECTIONS_DATA_TYPES`` for details on aeon supported data structures.
Returns ------- metadata : dict - Metadata about X, with flags: - metadata["multivariate"] : whether X has more than one channel or not - metadata["missing_values"] : whether X has missing values or not - metadata["unequal_length"] : whether X contains unequal length series. - metadata["n_cases"] : number of cases in X - metadata["n_channels"] : number of channels in X - metadata["n_timepoints"] : number of timepoints in X if equal length, else - None + Metadata about ``X``, with flags: + ``metadata["multivariate"]`` : whether ``X`` has more than one channel or + not + ``metadata["missing_values"]`` : whether ``X`` has missing values or not + ``metadata["unequal_length"]`` : whether ``X`` contains unequal length + series. + ``metadata["n_cases"]`` : number of cases in ``X`` + ``metadata["n_channels"]`` : number of channels in ``X`` + ``metadata["n_timepoints"]`` : number of timepoints in ``X`` if equal + length, else ``None`` Raises ------ ValueError - If X is an invalid type or has characteristics that the estimator cannot + If ``X`` is an invalid type or has characteristics that the estimator cannot handle. See Also -------- _convert_X : - Function that converts X after it has been checked. + Function that converts ``X`` after it has been checked. Examples -------- @@ -209,30 +211,31 @@ def _check_X(self, X): def _convert_X(self, X): """ - Convert X to type defined by tag X_inner_type. + Convert ``X`` to type defined by tag ``X_inner_type``. If the input data is already an allowed type, it is returned unchanged. - If multiple types are allowed by self, then the best one for the type of input - data is selected. So, for example, if X_inner_tag is ["np-list", "numpy3D"] - and an df-list is passed, it will be converted to numpy3D if the series - are equal length, and np-list if the series are unequal length. + If multiple types are allowed by ``self``, then the best + one for the type of input data is selected.
So, for example, if + ``X_inner_type`` is ["np-list", "numpy3D"] and a df-list is passed, it will + be converted to ``numpy3D`` if the series are equal length, and np-list + if the series are unequal length. Parameters ---------- X : collection - See aeon.utils.COLLECTIONS_DATA_TYPES for details on aeon supported + See ``aeon.utils.COLLECTIONS_DATA_TYPES`` for details on aeon supported data structures. Returns ------- X : collection - Converted X. A data structure of type self.get_tag("X_inner_type"). + Converted ``X``. A data structure of type ``self.get_tag("X_inner_type")``. See Also -------- _check_X : - Function that checks X is valid and finds metadata. + Function that checks ``X`` is valid and finds metadata. Examples -------- @@ -275,17 +278,17 @@ def _convert_X(self, X): def _check_shape(self, X): """ - Check that the shape of X is consistent with the data seen in fit. + Check that the shape of ``X`` is consistent with the data seen in fit. Parameters ---------- X : data structure - Must be of type aeon.registry.COLLECTIONS_DATA_TYPES. + Must be of type ``aeon.registry.COLLECTIONS_DATA_TYPES``. Raises ------ ValueError - If the shape of X is not consistent with the data seen in fit. + If the shape of ``X`` is not consistent with the data seen in fit. """ # if metadata is empty, then we have not seen any data in fit. If the estimator # has not been fitted, then _is_fitted should catch this. diff --git a/aeon/base/_base_series.py b/aeon/base/_base_series.py index 6c86940f5b..07af84817b 100644 --- a/aeon/base/_base_series.py +++ b/aeon/base/_base_series.py @@ -38,12 +38,12 @@ class BaseSeriesEstimator(BaseAeonEstimator): is the number of channels): Univariate series: np.ndarray, shape ``(m,)``, ``(m, 1)`` or ``(1, m)`` depending on axis. - This is converted to a 2D np.ndarray internally. + This is converted to a 2D ``np.ndarray`` internally. pd.DataFrame, shape ``(m, 1)`` or ``(1, m)`` depending on axis.
- pd.Series, shape ``(m,)`` is converted to a pd.DataFrame. + pd.Series, shape ``(m,)`` is converted to a ``pd.DataFrame``. Multivariate series: - np.ndarray array, shape ``(m, d)`` or ``(d, m)`` depending on axis. - pd.DataFrame ``(m, d)`` or ``(d, m)`` depending on axis. + ``np.ndarray`` array, shape ``(m, d)`` or ``(d, m)`` depending on axis. + ``pd.DataFrame`` ``(m, d)`` or ``(d, m)`` depending on axis. Parameters ---------- @@ -70,16 +70,17 @@ def __init__(self, axis): super().__init__() def _preprocess_series(self, X, axis, store_metadata): - """Preprocess input X prior to call to fit. + """Preprocess input ``X`` prior to call to fit. - Checks the characteristics of X, store metadata, checks self can handle - the data then convert X to X_inner_type + Checks the characteristics of ``X``, store metadata, checks self can handle + the data then convert ``X`` to X_inner_type Parameters ---------- X: one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES A valid aeon time series data structure. See - aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types. + ``aeon.base._base_series.VALID_SERIES_INPUT_TYPES`` + for aeon supported types. axis: int The time point axis of the input series if it is 2D. If ``axis==0``, it is assumed each column is a time series and each row is a time point. i.e. the @@ -87,12 +88,13 @@ def _preprocess_series(self, X, axis, store_metadata): the time series are in rows, i.e. the shape of the data is ``(n_channels, n_timepoints)``. store_metadata: bool - If True, overwrite metadata with the new metadata from X. + If ``True``, overwrite metadata with the new metadata from X. Returns ------- X: one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES - Input time series with data structure of type self.get_tag("X_inner_type"). + Input time series with data structure of type + ``self.get_tag("X_inner_type")``. 
""" meta = self._check_X(X, axis) if store_metadata: @@ -104,14 +106,15 @@ def _check_X(self, X, axis): Check if the input data is a compatible type, and that this estimator is able to handle the data characteristics. This is done by matching the - capabilities of the estimator against the metadata for X for + capabilities of the estimator against the metadata for ``X`` for univariate/multivariate and no missing values/missing values. Parameters ---------- X: one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES A valid aeon time series data structure. See - aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types. + ``aeon.base._base_series.VALID_SERIES_INPUT_TYPES`` + for aeon supported types. axis: int The time point axis of the input series if it is 2D. If ``axis==0``, it is assumed each column is a time series and each row is a time point. i.e. the @@ -122,10 +125,10 @@ def _check_X(self, X, axis): Returns ------- metadata: dict - Metadata about X, with flags: - metadata["multivariate"]: whether X has more than one channel or not - metadata["n_channels"]: number of channels in X - metadata["missing_values"]: whether X has missing values or not + Metadata about ``X``, with flags: + ``metadata["multivariate"]``: whether ``X`` has more than one channel or not + ``metadata["n_channels"]``: number of channels in ``X`` + ``metadata["missing_values"]``: whether ``X`` has missing values or not """ if axis > 1 or axis < 0: raise ValueError(f"Input axis should be 0 or 1, saw {axis}") @@ -192,21 +195,22 @@ def _check_X(self, X, axis): return metadata def _convert_X(self, X, axis): - """Convert input X to internal estimator datatype. + """Convert input ``X`` to internal estimator datatype. - Converts input X to the internal data type of the estimator using - self.get_tag("X_inner_type"). 1D numpy arrays are converted to 2D, + Converts input ``X`` to the internal data type of the estimator using + ``self.get_tag("X_inner_type")``. 
1D numpy arrays are converted to 2D, and the data will be transposed if the input axis does not match that of the estimator. - Attempting to convert to a pd.Series for multivariate data or estimators will - raise an error. + Attempting to convert to a ``pd.Series`` for multivariate + data or estimators will raise an error. Parameters ---------- X: one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES A valid aeon time series data structure. See - aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types. + ``aeon.base._base_series.VALID_SERIES_INPUT_TYPES`` for aeon + supported types. axis: int The time point axis of the input series if it is 2D. If ``axis==0``, it is assumed each column is a time series and each row is a time point. i.e. the @@ -217,7 +221,8 @@ def _convert_X(self, X, axis): Returns ------- X: one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES - Input time series with data structure of type self.get_tag("X_inner_type"). + Input time series with data structure of type + ``self.get_tag("X_inner_type")``. """ if axis > 1 or axis < 0: raise ValueError(f"Input axis should be 0 or 1, saw {axis}") diff --git a/aeon/base/_compose.py b/aeon/base/_compose.py index 0995e85de6..25e2fc7bfd 100644 --- a/aeon/base/_compose.py +++ b/aeon/base/_compose.py @@ -12,8 +12,8 @@ class ComposableEstimatorMixin(ABC): """Handles parameter management for estimators composed of named estimators. - Parts (i.e. get_params and set_params) adapted or copied from the scikit-learn - ``_BaseComposition`` class in utils/metaestimators.py. + Parts (i.e. ``get_params`` and ``set_params``) adapted or copied from the + ``scikit-learn`` ``_BaseComposition`` class in ``utils/metaestimators.py``. 
""" # Attribute name containing an iterable of processed (str, estimator) tuples @@ -36,7 +36,7 @@ def get_params(self, deep=True): Parameters ---------- deep : bool, default=True - If True, will return the parameters for this estimator and + If ``True``, will return the parameters for this estimator and contained subobjects that are estimators. Returns @@ -113,7 +113,7 @@ def get_fitted_params(self, deep=True): Parameters ---------- deep : bool, default=True - If True, will return the fitted parameters for this estimator and + If ``True``, will return the fitted parameters for this estimator and contained subobjects that are estimators. Returns @@ -150,20 +150,20 @@ def _check_estimators( Parameters ---------- estimators : list - A list of estimators or list of (str, estimator) tuples. + A ``list`` of estimators or ``list`` of (``str``, ``estimator``) tuples. attr_name : str, optional. Default = "steps" Name of checked attribute in error messages class_type : class, tuple of class or None, default=BaseAeonEstimator. Class(es) that all estimators in ``estimators`` are checked to be an instance of. allow_tuples : boolean, default=True. - Whether tuples of (str, estimator) are allowed in ``estimators``. - Generally, the end-state we want is a list of tuples, so this should be True - in most cases. + Whether tuples of (``str``, ``estimator``) are allowed in ``estimators``. + Generally, the end-state we want is a ``list`` of tuples, so this should be + ``True`` in most cases. allow_single_estimators : boolean, default=True. Whether non-tuple estimator classes are allowed in ``estimators``. unique_names : boolean, default=True. - Whether to check that all tuple strings in `estimators` are unique. + Whether to check that all tuple strings in ``estimators`` are unique. invalid_names : str, list of str or None, default=None. Names that are invalid for estimators in ``estimators``. 
@@ -237,18 +237,21 @@ def _convert_estimators(self, estimators, clone_estimators=True): Parameters ---------- - estimators : list of estimators, or list of (str, estimator tuples) - A list of estimators or list of (str, estimator) tuples to be converted. + estimators : list of estimators, or list of (str, estimator) tuples. + A ``list`` of estimators or ``list`` of (``str``, ``estimator``) tuples + to be converted. clone_estimators : boolean, default=True. - Whether to return clone of estimators in ``estimators`` (True) or - references (False). + Whether to return clone of estimators in ``estimators`` (``True``) or + references (``False``). Returns ------- estimator_tuples : list of (str, estimator) tuples - If estimators was a list of (str, estimator) tuples, then identical/cloned + If ``estimators`` was a ``list`` of (``str``, ``estimator``) tuples, then + identical/cloned to ``estimators``. - if was a list of estimators or mixed, then unique str are generated to + If it was a ``list`` of estimators or mixed, then unique ``str`` names + are generated to create tuples. """ cloned_ests = []