From ea19cda60799bd0e76c14ad2dbfa6fa8eabd1184 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Wed, 14 May 2025 14:01:56 +0200 Subject: [PATCH 1/8] WIP. filters already working and tested. --- molpipeline/estimators/samplers/__init__.py | 0 .../estimators/samplers/oversampling.py | 129 ++++++++ .../estimators/samplers/stochastic_filter.py | 236 ++++++++++++++ .../estimators/samplers/stochastic_sampler.py | 143 +++++++++ molpipeline/imbalance/__init__.py | 1 + molpipeline/imbalance/oversampling.py | 1 + .../test_estimators/test_samplers/__init__.py | 1 + .../test_samplers/test_oversampling.py | 298 ++++++++++++++++++ .../test_samplers/test_stochastic_filter.py | 267 ++++++++++++++++ .../test_samplers/test_stochastic_sampler.py | 1 + 10 files changed, 1077 insertions(+) create mode 100644 molpipeline/estimators/samplers/__init__.py create mode 100644 molpipeline/estimators/samplers/oversampling.py create mode 100644 molpipeline/estimators/samplers/stochastic_filter.py create mode 100644 molpipeline/estimators/samplers/stochastic_sampler.py create mode 100644 molpipeline/imbalance/__init__.py create mode 100644 molpipeline/imbalance/oversampling.py create mode 100644 tests/test_estimators/test_samplers/__init__.py create mode 100644 tests/test_estimators/test_samplers/test_oversampling.py create mode 100644 tests/test_estimators/test_samplers/test_stochastic_filter.py create mode 100644 tests/test_estimators/test_samplers/test_stochastic_sampler.py diff --git a/molpipeline/estimators/samplers/__init__.py b/molpipeline/estimators/samplers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/molpipeline/estimators/samplers/oversampling.py b/molpipeline/estimators/samplers/oversampling.py new file mode 100644 index 00000000..1cf8b15b --- /dev/null +++ b/molpipeline/estimators/samplers/oversampling.py @@ -0,0 +1,129 @@ +"""Oversampler that selects samples based on group sizes.""" + +import numpy as np +import numpy.typing as npt +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import check_random_state, check_X_y + +from molpipeline.estimators.samplers.stochastic_filter import GroupSizeFilter + + +class GroupRandomOversampler(BaseEstimator, TransformerMixin): + """Random oversampler that operates on groups with inverse size weighting. + + The oversampling strategy picks minority class samples with a probability inverse + to size of the group they are in. With this strategy minority samples from smaller + groups are more often oversampled. The number of samples picked is the difference + between the number of majority class and minority class samples, such that both + there are the same number of samples for each class. Note that, since only minority + class samples of the groups are oversampled, this strategy can lead to an class + imbalance within groups. + + """ + + def __init__(self, random_state: int | None = None) -> None: + """Create new GroupRandomOversampler. + + Parameters + ---------- + random_state : int, RandomState instance or None, default=None + Controls the randomization of the algorithm. + + """ + self.random_state = random_state + + def transform( + self, + X: npt.NDArray[np.float64], # noqa: N803 + y: npt.NDArray[np.float64], + groups: npt.NDArray[np.float64], + return_groups: bool = False, + ) -> ( + tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] + | tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ] + ): + """Transform X and y to oversampled version of the data set. 
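+
+        A minimal usage sketch (shapes and values are illustrative, not taken
+        from the package's tests):
+
+        >>> import numpy as np
+        >>> y = np.array([0, 0, 0, 1, 1])
+        >>> X = np.zeros((5, 2))
+        >>> groups = np.array([1, 1, 2, 2, 3])
+        >>> sampler = GroupRandomOversampler(random_state=0)
+        >>> X_res, y_res = sampler.transform(X, y, groups)
+        >>> len(y_res)  # 3 majority + 2 minority samples -> 1 new sample
+        6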
+ + Parameters + ---------- + X : npt.NDArray[np.float64] + Training data. Will be ignored, only present because of compatibility. + y : npt.NDArray[np.float64] + Target values. Must be binary classification labels. + groups : npt.NDArray[np.float64] + Group labels for the samples. + return_groups : bool, default=False + If True, return the groups of the resampled data set. + + Returns + ------- + tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] + | tuple[ + npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64] + ] + Resampled data set (X, y) or (X, y, groups) if return_groups is True. + + Raises + ------ + ValueError + If the length of y does not match the length of groups. + If y is not binary classification labels. + + """ + _, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) + rng = check_random_state(self.random_state) + + if len(groups) != y.shape[0]: + raise ValueError( + f"Found groups with size {groups.shape[0]}, but expected {y.shape[0]}", + ) + + u_y, u_y_sizes = np.unique( + y, + return_counts=True, + ) + + expected_number_of_classes = 2 + if len(u_y) != expected_number_of_classes: + raise ValueError("Only support oversampling for binary classification.") + + diff = abs(u_y_sizes[0] - u_y_sizes[1]) + minority_class = u_y[np.argmin(u_y_sizes)] + + filter_policy = GroupSizeFilter(groups) + probs = filter_policy.calculate_probabilities(X, y) + # set probability at majority class to zero + probs[y != minority_class] = 0 + + # normalize probabilities over all samples + probs /= np.sum(probs, dtype=np.float64) + + # sample indices + sampled_indices = rng.choice( + len(X), + size=diff, + replace=True, + p=probs, + ) + + x_resampled = np.empty((X.shape[0] + diff, X.shape[1]), dtype=X.dtype) + x_resampled[: X.shape[0]] = X + x_resampled[X.shape[0] :] = X[sampled_indices] + + y_resampled = np.empty(y.shape[0] + diff, dtype=y.dtype) + y_resampled[: y.shape[0]] = y + y_resampled[y.shape[0] :] = y[sampled_indices] + + if return_groups: + groups_resampled = np.empty(groups.shape[0] + diff, dtype=groups.dtype) + groups_resampled[: groups.shape[0]] = groups + groups_resampled[groups.shape[0] :] = groups[sampled_indices] + return x_resampled, y_resampled, groups_resampled + + # TODO optionally shuffle resampled data? technically not necessary + + return x_resampled, y_resampled diff --git a/molpipeline/estimators/samplers/stochastic_filter.py b/molpipeline/estimators/samplers/stochastic_filter.py new file mode 100644 index 00000000..a893637a --- /dev/null +++ b/molpipeline/estimators/samplers/stochastic_filter.py @@ -0,0 +1,236 @@ +"""Stochastic filters for providing filter probabilities for sampling.""" + +from abc import ABC, abstractmethod +from typing import override + +import numpy as np +import numpy.typing as npt + + +class StochasticFilter(ABC): + """Abstract base class for stochastic filter policies. + + A StochasticFilter assigns sampling probabilities to data points based on + implementation-defined criteria. + """ + + @abstractmethod + def calculate_probabilities( + self, + X: npt.NDArray[np.float64], # noqa: N803 + y: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: + """Calculate probability for each sample. + + Parameters + ---------- + X : npt.NDArray[np.float64] + The input samples of shape (n_samples, n_features). + y : npt.NDArray[np.float64] + The target values of shape (n_samples,). 
+ + Returns + ------- + probabilities : npt.NDArray[np.float64] + Probability for each sample of shape (n_samples,) + + """ + + +class GlobalClassBalanceFilter(StochasticFilter): + """Provides probabilities inversely proportional to global class frequencies. + + Samples from minority classes receive higher probability values. + """ + + @override + def calculate_probabilities( + self, + X: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: + """Calculate probability based on inverse class frequencies. + + Note that probabilities over the sample list do not sum to 1. + + Parameters + ---------- + X : npt.NDArray[np.float64] + The input samples of shape (n_samples, n_features). + y : npt.NDArray[np.float64] + The target values of shape (n_samples,). + + Returns + ------- + probabilities : npt.NDArray[np.float64] + Probability for each sample of shape (n_samples,). + + """ + _, class_indices, counts = np.unique( + y, + return_inverse=True, + return_counts=True, + ) + + # calculate inverse class frequencies + inv_freqs = 1.0 / counts + + # normalize to probability distribution + normalized_inv_freqs = inv_freqs / inv_freqs.sum() + + # return probabilities in order of samples in the input + return normalized_inv_freqs[class_indices] + + +class LocalGroupClassBalanceFilter(StochasticFilter): + """Filter that returns probabilities based on class balance within groups. + + Samples from minority classes within their group receive higher probabilities. + """ + + def __init__(self, groups: npt.NDArray[np.float64]) -> None: + """Create a new LocalGroupClassBalanceFilter.""" + self._n_samples = groups.shape[0] + self._unique_groups, self._group_indices = np.unique( + groups, + return_inverse=True, + ) + + @override + def calculate_probabilities( + self, + X: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: + """Calculate probability based on inverse class frequencies within groups. + + Note that probabilities over the sample list do not sum to 1. + + Parameters + ---------- + X : npt.NDArray[np.float64] + The input samples of shape (n_samples, n_features). + y : npt.NDArray[np.float64] + The target values of shape (n_samples,). + + Returns + ------- + probabilities : npt.NDArray[np.float64] + Probability for each sample of shape (n_samples,). + + Raises + ------ + ValueError + If the length of y does not match the length of groups. + + """ + if self._n_samples != len(y): + raise ValueError("Provided y must have the same length as groups") + + n_samples = self._n_samples + unique_groups = self._unique_groups + group_indices = self._group_indices + + probabilities = np.zeros(n_samples) + + prob_sum = 0 + for group_idx, _ in enumerate(unique_groups): + # Find samples belonging to current group + group_mask = group_indices == group_idx + + # get classes for this group and their counts + group_y = y[group_mask] + _, group_class_indices, group_counts = np.unique( + group_y, + return_inverse=True, + return_counts=True, + ) + + # calculate normalized inverse frequencies + group_inv_freqs = 1.0 / group_counts + + # map each sample to its corresponding inverse frequency + probabilities[group_mask] = group_inv_freqs[group_class_indices] + + # we need to sum the inverse frequencies for normalization + prob_sum += group_inv_freqs.sum() + + # Normalize the final probabilities. We normalize over the inverse frequencies + # over all groups. This sets the scale between groups and classes. 
Groups with
+        # fewer samples of a particular class will get higher probabilities than groups
+        # with many samples of that class.
+        if prob_sum > 0:
+            probabilities /= prob_sum
+        else:
+            probabilities = np.ones(n_samples) / n_samples
+
+        return probabilities
+
+
+class GroupSizeFilter(StochasticFilter):
+    """Filter that returns probabilities inversely proportional to group sizes.
+
+    Samples from smaller groups receive higher probabilities.
+    """
+
+    def __init__(self, groups: npt.NDArray[np.float64]) -> None:
+        """Create a new GroupSizeFilter."""
+        self.probabilities = self._compute_group_probs(groups)
+
+    @staticmethod
+    def _compute_group_probs(
+        groups: npt.NDArray[np.float64],
+    ) -> npt.NDArray[np.float64]:
+        """Compute inverse group sizes and normalize them.
+
+        Note that probabilities over the sample list do not sum to 1.
+
+        Parameters
+        ----------
+        groups : npt.NDArray[np.float64]
+            Group labels for the samples of shape (n_samples,).
+
+        Returns
+        -------
+        probabilities : npt.NDArray[np.float64]
+            Probability for each sample of shape (n_samples,).
+
+        """
+        _, group_indices, group_counts = np.unique(
+            groups,
+            return_inverse=True,
+            return_counts=True,
+        )
+        inv_sizes = 1.0 / group_counts
+        normalized_inv_sizes = inv_sizes / inv_sizes.sum()
+        return normalized_inv_sizes[group_indices]
+
+    @override
+    def calculate_probabilities(
+        self,
+        X: npt.NDArray[np.float64],
+        y: npt.NDArray[np.float64],
+    ) -> npt.NDArray[np.float64]:
+        """Calculate probability based on inverse group sizes.
+
+        Parameters
+        ----------
+        X : npt.NDArray[np.float64]
+            The input samples of shape (n_samples, n_features).
+        y : npt.NDArray[np.float64]
+            The target values of shape (n_samples,).
+
+        Returns
+        -------
+        probabilities : npt.NDArray[np.float64]
+            Probability for each sample of shape (n_samples,).
+
+        Raises
+        ------
+        ValueError
+            If the length of y does not match the length of groups.
+
+        """
+        if len(self.probabilities) != len(y):
+            raise ValueError("Provided y must have the same length as groups")
+        return self.probabilities
diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py
new file mode 100644
index 00000000..0ee9fa96
--- /dev/null
+++ b/molpipeline/estimators/samplers/stochastic_sampler.py
@@ -0,0 +1,143 @@
+from collections.abc import Sequence
+
+import numpy as np
+import numpy.typing as npt
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.utils import check_random_state, check_X_y
+
+
+class StochasticSampler(BaseEstimator, TransformerMixin):
+    """Sampler that uses multiple stochastic filters to determine sampling probabilities.
+
+    Parameters
+    ----------
+    filters : list of StochasticFilter
+        List of filter policies to apply.
+    n_samples : int or None, default=None
+        Number of samples to generate. If None, will match the size of X.
+    combination_method : {'product', 'mean', 'min', 'max'}, default='product'
+        Method to combine probabilities from multiple filters.
+    random_state : int, RandomState instance or None, default=None
+        Controls the randomization of the algorithm.
+ + """ + + def __init__( + self, + filters: Sequence[StochasticFilter], + n_samples: int | None = None, + combination_method: str = "product", + random_state: int | None = None, + ): + self.filters = filters + self.n_samples = n_samples + self.combination_method = combination_method + self.random_state = random_state + + # Validate combination method + valid_methods = ["product", "mean", "min", "max"] + if self.combination_method not in valid_methods: + raise ValueError(f"combination_method must be one of {valid_methods}") + + def _combine_probabilities(self, probabilities_list: npt.NDArray) -> npt.NDArray: + """Combine probabilities from multiple filters. + + Parameters + ---------- + probabilities_list : list of array-like + List of probability arrays from each filter. + + Returns + ------- + combined_probs : array-like + Combined probabilities for each sample. + + """ + if not probabilities_list: + raise ValueError("No probabilities to combine") + + if len(probabilities_list) == 1: + return probabilities_list[0] + + if self.combination_method == "product": + combined = np.prod(probabilities_list, axis=0) + + elif self.combination_method == "mean": + combined = np.mean(probabilities_list, axis=0) + + elif self.combination_method == "min": + combined = np.min(probabilities_list, axis=0) + + elif self.combination_method == "max": + combined = np.max(probabilities_list, axis=0) + else: + raise AssertionError("Invalid combination method") + + # # Normalize to ensure probabilities sum to 1 + # if combined.sum() > 0: + # combined = combined / combined.sum() + # else: + # # If all zeros, use uniform distribution + # combined = np.ones_like(combined) / len(combined) + + return combined + + def transform( + self, + X: npt.NDArray, + y: npt.NDArray, + ) -> tuple[npt.NDArray, npt.NDArray]: + """Apply filters and sample from the input data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The input samples. + y : array-like of shape (n_samples,) + The target values. + + Returns + ------- + X_sampled : array-like of shape (n_samples_new, n_features) + The sampled input. + y_sampled : array-like of shape (n_samples_new,) + The sampled targets. 
+ + """ + X, y = check_X_y(X, y, accept_sparse=["csr", "csc"]) + rng = check_random_state(self.random_state) + + # Determine the number of samples to generate + n_samples = self.n_samples if self.n_samples is not None else len(X) + + # Apply each filter to get probabilities + filter_probabilities = np.zeros((len(X), len(self.filters))) + for filter_idx, filter_policy in enumerate(self.filters): + filter_probabilities[:, filter_idx] = filter_policy.calculate_probabilities( + X, + y, + ) + + # Combine probabilities + combined_probabilities = self._combine_probabilities(filter_probabilities) + + # normalize combined probabilities because numpy's choices needs it + if combined_probabilities.sum() > 0: + combined_probabilities /= combined_probabilities.sum() + else: + # if all zeros, use uniform distribution + combined_probabilities = np.ones_like(combined_probabilities) / len( + combined_probabilities, + ) + + sample_probabilities = 1 - combined_probabilities + # Sample indices based on combined probabilities + indices = rng.choice( + len(X), + size=n_samples, + replace=True, + p=sample_probabilities, + ) + + # Return sampled data + return X[indices], y[indices] diff --git a/molpipeline/imbalance/__init__.py b/molpipeline/imbalance/__init__.py new file mode 100644 index 00000000..56051bcf --- /dev/null +++ b/molpipeline/imbalance/__init__.py @@ -0,0 +1 @@ +"""Imbalance module.""" diff --git a/molpipeline/imbalance/oversampling.py b/molpipeline/imbalance/oversampling.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/molpipeline/imbalance/oversampling.py @@ -0,0 +1 @@ + diff --git a/tests/test_estimators/test_samplers/__init__.py b/tests/test_estimators/test_samplers/__init__.py new file mode 100644 index 00000000..7d1794d4 --- /dev/null +++ b/tests/test_estimators/test_samplers/__init__.py @@ -0,0 +1 @@ +"""Test module for samplers.""" diff --git a/tests/test_estimators/test_samplers/test_oversampling.py b/tests/test_estimators/test_samplers/test_oversampling.py new file mode 100644 index 00000000..02b0652c --- /dev/null +++ b/tests/test_estimators/test_samplers/test_oversampling.py @@ -0,0 +1,298 @@ +"""Test oversamplers.""" + +import unittest + +import numpy as np +import numpy.typing as npt + +from molpipeline.estimators.samplers.oversampling import GroupRandomOversampler + + +class TestGroupRandomOversampler(unittest.TestCase): + """Test oversamplers.""" + + def _assert_resampling_results( + self, + x_matrix: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], + x_resampled: npt.NDArray[np.float64], + y_resampled: npt.NDArray[np.float64], + ) -> None: + """Verify resampling results. + + Parameters + ---------- + x_matrix : npt.NDArray[np.float64] + Original features. + y : npt.NDArray[np.float64] + Original target values. + x_resampled : npt.NDArray[np.float64] + Resampled features. + y_resampled : npt.NDArray[np.float64] + Resampled target values. 
+
+        """
+        n_samples = x_matrix.shape[0]
+
+        # test shapes
+        self.assertEqual(x_resampled.shape[1], x_matrix.shape[1])
+        self.assertEqual(x_resampled.shape[0], y_resampled.shape[0])
+
+        y_unique, y_counts = np.unique(y, return_counts=True)
+        y_resampled_unique, y_resampled_counts = np.unique(
+            y_resampled,
+            return_counts=True,
+        )
+
+        self.assertEqual(len(y_unique), len(y_resampled_unique))
+        n_minority_original, n_majority_original = sorted(y_counts)
+        n_expected_new = n_majority_original - n_minority_original
+
+        # test the resampled data has the expected number of samples
+        self.assertEqual(len(y_resampled), n_samples + n_expected_new)
+        # test that the resampled data has the expected number of classes
+        self.assertTrue(np.array_equal(y_unique, np.unique(y_resampled)))
+        # test that the resampled data has balanced classes
+        self.assertEqual(y_resampled_counts[0], y_resampled_counts[1])
+
+        # test that all oversampled samples are copies of original ones
+        u_x_matrix = np.unique(x_matrix, axis=0)
+        u_x_resampled = np.unique(x_resampled, axis=0)
+        self.assertTrue(np.array_equal(u_x_matrix, u_x_resampled))
+
+    def test_oversampling(self) -> None:
+        """Test oversamplers."""
+        sampler = GroupRandomOversampler(random_state=12345)
+
+        y = np.array([0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+        x_matrix = np.zeros((y.shape[0], 2))
+        groups = np.array([1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
+
+        x_resampled, y_resampled = sampler.transform(x_matrix, y, groups)
+        self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled)
+
+    def test_invalid_x_y_size(self) -> None:
+        """Test that an error is raised when X and y have inconsistent sizes."""
+        sampler = GroupRandomOversampler(random_state=42)
+
+        y = np.array([0, 0, 1, 0, 1])
+        # X has different length than y
+        x_matrix = np.zeros((y.shape[0] - 1, 2))
+        groups = np.array([1, 1, 1, 2, 2])
+
+        with self.assertRaises(ValueError) as context:
+            sampler.transform(x_matrix, y, groups)
+
+        self.assertTrue(
+            "Found input variables with inconsistent numbers of samples"
+            in str(context.exception),
+        )
+
+    def test_invalid_group_size(self) -> None:
+        """Test that an error is raised when group size doesn't match sample size."""
+        sampler = GroupRandomOversampler(random_state=42)
+
+        y = np.array([0, 0, 1, 0, 1])
+        x_matrix = np.zeros((y.shape[0], 2))
+        # Different length than X and y
+        groups = np.array([1, 1, 1, 2])
+
+        with self.assertRaises(ValueError) as context:
+            sampler.transform(x_matrix, y, groups)
+
+        self.assertTrue("Found groups with size" in str(context.exception))
+
+    def test_non_binary_classification(self) -> None:
+        """Test that an error is raised for non-binary classification problems."""
+        sampler = GroupRandomOversampler(random_state=42)
+
+        # Multi-class problem with 3 classes
+        y = np.array([0, 1, 2, 0, 1, 2])
+        x_matrix = np.zeros((y.shape[0], 2))
+        groups = np.array([1, 1, 1, 2, 2, 2])
+
+        with self.assertRaises(ValueError):
+            sampler.transform(x_matrix, y, groups)
+
+    def test_minority_class_in_single_group(self) -> None:
+        """Test when all minority class samples appear in only one group."""
+        sampler = GroupRandomOversampler(random_state=42)
+
+        n_samples = 100
+        n_minority = 10
+
+        y = np.zeros(n_samples)
+        y[:n_minority] = 1  # Minority class (class 1) is first 10 samples
+        x_matrix = np.zeros((n_samples, 2))
+
+        # All minority samples are in group 1, majority samples distributed across
+        # groups.
+ groups = np.ones(n_samples) + groups[n_minority : n_minority + 30] = 2 + groups[n_minority + 30 : n_minority + 60] = 3 + groups[n_minority + 60 :] = 4 + + x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) + + def test_only_one_group(self) -> None: + """Test when all samples belong to the same group.""" + sampler = GroupRandomOversampler(random_state=42) + + y = np.array([0, 1, 0, 1, 0, 0]) + n_samples = y.shape[0] + x_matrix = np.zeros((n_samples, 2)) + groups = np.ones(n_samples) # all samples belong to the same group + + x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) + + def test_already_balanced(self) -> None: + """Test when the dataset is already balanced.""" + sampler = GroupRandomOversampler(random_state=42) + + y = np.array([0, 1, 0, 1]) + n_samples = y.shape[0] + x_matrix = np.zeros((n_samples, 2)) + groups = np.array([1, 1, 2, 2]) + + x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) + + def test_deterministic_behavior(self) -> None: + """Test that the same random_state produces consistent results.""" + y = np.array([0, 0, 0, 1, 1]) + x_matrix = np.zeros((y.shape[0], 2)) + groups = np.array([1, 1, 2, 2, 3]) + + # Run with same random_state twice + sampler1 = GroupRandomOversampler(random_state=1234) + sampler2 = GroupRandomOversampler(random_state=1234) + + x_resampled1, y_resampled1 = sampler1.transform(x_matrix, y, groups) + x_resampled2, y_resampled2 = sampler2.transform(x_matrix, y, groups) + + # Results should be identical + self.assertTrue(np.array_equal(x_resampled1, x_resampled2)) + self.assertTrue(np.array_equal(y_resampled1, y_resampled2)) + + def test_with_empty_groups(self) -> None: + """Test behavior when some groups have no samples of the minority class.""" + sampler = GroupRandomOversampler(random_state=42) + + magic_value = 99 + + y = np.array([0, 0, 0, 1, 1]) + x_matrix = np.zeros((y.shape[0], 2)) + x_matrix[2, :] = magic_value # Mark data of group 3 sample + # Group 3 has no minority samples + groups = np.array([1, 2, 3, 1, 2]) + + x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) + + # check that there is only one sample with values of 99 + self.assertEqual((x_resampled == magic_value).all(axis=1).sum(), 1) + + @staticmethod + def _construct_test_data_inverse_probability_sampling() -> tuple[ + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + ]: + """Construct test data for inverse probability sampling. + + Returns + ------- + tuple[npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64], + npt.NDArray[np.float64]] + x_matrix : Feature matrix. + y : Target values. + groups : Group labels for the samples. + expected_probs : Expected sampling probabilities based on inverse group + sizes. 
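+                For the sizes (10, 20, 40) used here, the inverse weights are
+                (1/10, 1/20, 1/40), which normalize to (4/7, 2/7, 1/7); this
+                worked example follows directly from the construction below and
+                is included for clarity.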
+ + """ + # Create a dataset with specific group sizes + # Group 1: 10 samples (5 minority) + # Group 2: 20 samples (5 minority) + # Group 3: 40 samples (5 minority) + # The smaller groups should be sampled more frequently + + # Generate data + n_group1, n_group2, n_group3 = 10, 20, 40 + n_minority_each = 5 + + # Create dataset - majority class is 0, minority class is 1 + y = np.zeros(n_group1 + n_group2 + n_group3) + # Set minority samples in each group + y[:n_minority_each] = 1 # Group 1 minority + y[n_group1 : n_group1 + n_minority_each] = 1 # Group 2 minority + y[n_group1 + n_group2 : n_group1 + n_group2 + n_minority_each] = ( + 1 # Group 3 minority + ) + + x_matrix = np.zeros((len(y), 2)) + + # Assign group labels + groups = np.zeros(len(y)) + groups[:n_group1] = 1 + groups[n_group1 : n_group1 + n_group2] = 2 + groups[n_group1 + n_group2 :] = 3 + + # Calculate expected sampling probabilities based on inverse group sizes + group_sizes = np.array([n_group1, n_group2, n_group3]) + inv_weights = 1.0 / group_sizes + expected_probs = inv_weights / inv_weights.sum() + + return x_matrix, y, groups, expected_probs + + def test_inverse_probability_sampling(self) -> None: + """Test that samples are selected based on inverse group size probabilities.""" + n_runs = 1000 # Large number of runs for statistical significance + + x_matrix, y, groups, expected_probs = ( + self._construct_test_data_inverse_probability_sampling() + ) + + # Track the sampled groups across multiple runs + sampled_counts = np.zeros(3) + + for _ in range(n_runs): + sampler = GroupRandomOversampler(random_state=1234) + _, _, groups_resampled = sampler.transform( + x_matrix, + y, + groups, + return_groups=True, + ) + + # Count newly added samples by group + new_samples = groups_resampled[len(groups) :] + unique_groups, counts = np.unique(new_samples, return_counts=True) + + for group, count in zip(unique_groups, counts, strict=True): + sampled_counts[int(group) - 1] += count + + # Calculate observed sampling probabilities + observed_probs = sampled_counts / sampled_counts.sum() + + # Check that observed probabilities are close to expected probabilities + # Allow some tolerance for statistical variation + for i in range(3): + self.assertAlmostEqual( + observed_probs[i], + expected_probs[i], + delta=0.05, # 5% tolerance + msg=f"Group {i + 1} sampling probability incorrect. 
" + f"Expected ~{expected_probs[i]:.2f}, got {observed_probs[i]:.2f}", + ) + + # smaller groups should have higher sampling rates + self.assertTrue( + observed_probs[0] > observed_probs[1] > observed_probs[2], + "Smaller groups should have higher sampling probabilities", + ) diff --git a/tests/test_estimators/test_samplers/test_stochastic_filter.py b/tests/test_estimators/test_samplers/test_stochastic_filter.py new file mode 100644 index 00000000..c223e2a3 --- /dev/null +++ b/tests/test_estimators/test_samplers/test_stochastic_filter.py @@ -0,0 +1,267 @@ +"""Tests for stochastic filters.""" + +import unittest + +import numpy as np + +from molpipeline.estimators.samplers.stochastic_filter import ( + GlobalClassBalanceFilter, + GroupSizeFilter, + LocalGroupClassBalanceFilter, +) + + +class TestGlobalClassBalanceFilter(unittest.TestCase): + """Test stochastic sampler.""" + + def test_global_class_balance_filter_multiclass(self) -> None: + """Test GlobalClassBalanceFilter with multiple classes.""" + filter_policy = GlobalClassBalanceFilter() + x_matrix = np.zeros((10, 2)) + y = np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 3]) + + # Class counts: class 0=2, class 1=4, class 2=3, class 3=1 + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Calculate expected probabilities: + # inverse frequencies are [1/2, 1/4, 1/3, 1] + # sum = 1/2 + 1/4 + 1/3 + 1 = 25/12 + # normalized values are [1/2, 1/4, 1/3, 1] / (25/12) = [6/25, 3/25, 4/25, 12/25] + expected = np.array( + [ + 6 / 25, + 6 / 25, + 3 / 25, + 3 / 25, + 3 / 25, + 3 / 25, + 4 / 25, + 4 / 25, + 4 / 25, + 12 / 25, + ], + ) + self.assertTrue(np.allclose(probs, expected)) + + def test_global_class_balance_filter_single_class(self) -> None: + """Test GlobalClassBalanceFilter with only one class.""" + filter_policy = GlobalClassBalanceFilter() + x_matrix = np.zeros((5, 2)) + y = np.ones(5) # All samples belong to class 1 + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # With only one class, all samples should have equal probability of 1 + expected = np.ones(5) + self.assertTrue(np.allclose(probs, expected)) + + def test_global_class_balance_filter_extreme_imbalance(self) -> None: + """Test GlobalClassBalanceFilter with extremely imbalanced classes.""" + filter_policy = GlobalClassBalanceFilter() + x_matrix = np.zeros((101, 2)) + y = np.zeros(101) + y[0] = 1 # Only one sample of class 1 + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # inverse frequencies are [1/100, 1] + # sum = 1/100 + 1 = 1.01 + expected = np.array([1 / 1.01, *([(1 / 100) / 1.01] * 100)]) + self.assertTrue(np.allclose(probs, expected)) + + +class TestLocalGroupClassBalanceFilter(unittest.TestCase): + """Test LocalGroupClassBalanceFilter.""" + + def test_multiple_groups_multiple_classes(self) -> None: + """Test probability calculation with multiple groups and classes.""" + # Groups: 3 groups with varying class distribution + # Group 0: class 0=1, class 1=2 (inv_freq: [1, 1/2]) + # Group 1: class 0=2, class 1=1 (inv_freq: [1/2, 1]) + # Group 2: class 0=1, class 1=1 (inv_freq: [1, 1]) + groups = np.array([0, 0, 0, 1, 1, 1, 2, 2]) + y = np.array([0, 1, 1, 0, 0, 1, 0, 1]) + + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((len(groups), 2)) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Expected probabilities after normalization: + # Total sum of inv_freqs: 1 + 1/2 + 1/2 + 1 + 1 + 1 = 5 + # Group 0: [1/5, 1/10, 1/10] + # Group 1: [1/10, 1/10, 1/5] + # Group 2: [1/5, 1/5] + expected = np.array( 
+ [1 / 5, 1 / 10, 1 / 10, 1 / 10, 1 / 10, 1 / 5, 1 / 5, 1 / 5], + ) + self.assertTrue(np.allclose(probs, expected)) + + def test_single_group_multiple_classes(self) -> None: + """Test with a single group containing multiple classes.""" + groups = np.array([1, 1, 1, 1]) + y = np.array([0, 0, 1, 2]) # Class 0=2, Class 1=1, Class 2=1 + + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((4, 2)) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Expected inverse frequencies are [1/2, 1, 1] + # Sum = 1/2 + 1 + 1 = 2.5 + # Normalized values are [1/5, 1/5, 2/5, 2/5] + expected = np.array([1 / 5, 1 / 5, 2 / 5, 2 / 5]) + self.assertTrue(np.allclose(probs, expected)) + + def test_multiple_groups_single_class(self) -> None: + """Test with multiple groups each containing a single class.""" + groups = np.array([1, 1, 2, 2, 3, 3]) + y = np.ones(6) # All samples are class 1 + + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((6, 2)) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Expected inverse frequencies are [1/2, 1/2, 1/2] + # Sum = 1/2 + 1/2 + 1/2 = 1.5 + # Normalized values are [1/3, 1/3, 1/3, 1/3, 1/3, 1/3] + + # With single class in each group, all samples should have equal probability + expected = np.ones(6) / 3 + self.assertTrue(np.allclose(probs, expected)) + + def test_extreme_imbalance_within_groups(self) -> None: + """Test with extreme class imbalance within groups.""" + groups = np.array([1, 1, 1, 1, 1, 2, 2, 2]) + y = np.array([0, 0, 0, 0, 1, 0, 0, 1]) # Group 1: 4:1, Group 2: 2:1 + + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((8, 2)) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Group 1: inv_freqs = [1/4, 1] = [0.25, 1] + # Group 2: inv_freqs = [1/2, 1] = [0.5, 1] + # Sum = 1/4 + 1 + 1/2 + 1 = 2.75 + # Normalized values are [0.25/2.75, 0.25/2.75, 0.25/2.75, 0.25/2.75, 1/2.75, + # 0.5/2.75, 0.5/2.75, 1/2.75] + expected = np.array( + [ + 0.25 / 2.75, + 0.25 / 2.75, + 0.25 / 2.75, + 0.25 / 2.75, + 1 / 2.75, + 0.5 / 2.75, + 0.5 / 2.75, + 1 / 2.75, + ], + ) + self.assertTrue(np.allclose(probs, expected)) + + def test_error_mismatched_lengths(self) -> None: + """Test error handling for mismatched length between groups and y.""" + groups = np.array([1, 1, 2, 2]) + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((5, 2)) + y = np.array([0, 0, 1, 1, 1]) + + with self.assertRaises(ValueError): + filter_policy.calculate_probabilities(x_matrix, y) + + def test_empty_groups(self) -> None: + """Test handling of empty input.""" + groups = np.array([]) + filter_policy = LocalGroupClassBalanceFilter(groups) + x_matrix = np.zeros((0, 2)) + y = np.array([]) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + self.assertEqual(len(probs), 0) + + +class TestGroupSizeFilter(unittest.TestCase): + """Test GroupSizeFilter.""" + + def test_group_size_filter_basic(self) -> None: + """Test GroupSizeFilter with groups of different sizes.""" + # Groups: size 3, 2, 5 + groups = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3, 3]) + filter_policy = GroupSizeFilter(groups) + x_matrix = np.zeros((10, 2)) + y = np.ones(10) # All samples belong to class 1 but doesn't matter + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # Expected inverse sizes are [1/3, 1/2, 1/5] + # Sum = 1/3 + 1/2 + 1/5 = 31/30 + # Normalized values are [1/3, 1/2, 1/5] / (31/30) = [10/31, 15/31, 6/31] + expected = np.array( + [ + 10 / 31, + 10 / 31, + 10 / 31, + 
15 / 31, + 15 / 31, + 6 / 31, + 6 / 31, + 6 / 31, + 6 / 31, + 6 / 31, + ], + ) + self.assertTrue(np.allclose(probs, expected)) + + def test_group_size_filter_single_group(self) -> None: + """Test GroupSizeFilter with only one group.""" + groups = np.ones(5) + filter_policy = GroupSizeFilter(groups) + x_matrix = np.zeros((5, 2)) + y = np.ones(5) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # With one group, all samples should have equal probability + expected = np.ones(5) + self.assertTrue(np.allclose(probs, expected)) + + def test_group_size_filter_equal_sizes(self) -> None: + """Test GroupSizeFilter with equally sized groups.""" + groups = np.array([1, 1, 2, 2, 3, 3]) + filter_policy = GroupSizeFilter(groups) + x_matrix = np.zeros((6, 2)) + y = np.ones(6) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # All groups have equal size, so all samples should have equal probability of + # 1 / 3 + expected = np.zeros(6) + (1 / 3) + self.assertTrue(np.allclose(probs, expected)) + + def test_group_size_filter_extreme_imbalance(self) -> None: + """Test GroupSizeFilter with extremely imbalanced group sizes.""" + groups = np.zeros(101) + groups[0] = 1 # One sample in group 1, 100 samples in group 0 + + filter_policy = GroupSizeFilter(groups) + x_matrix = np.ones((101, 2)) + y = np.zeros(101) + + probs = filter_policy.calculate_probabilities(x_matrix, y) + + # inverse frequencies are [1/100, 1] + # sum = 1/100 + 1 = 1.01 + expected = np.array([1 / 1.01, *([(1 / 100) / 1.01] * 100)]) + self.assertTrue(np.allclose(probs, expected)) + + def test_group_size_filter_error_handling(self) -> None: + """Test GroupSizeFilter error handling for mismatched lengths.""" + groups = np.array([1, 1, 2, 2]) + filter_policy = GroupSizeFilter(groups) + x_matrix = np.zeros((5, 2)) + y = np.ones(5) + + with self.assertRaises(ValueError): + filter_policy.calculate_probabilities(x_matrix, y) diff --git a/tests/test_estimators/test_samplers/test_stochastic_sampler.py b/tests/test_estimators/test_samplers/test_stochastic_sampler.py new file mode 100644 index 00000000..2a7c8e19 --- /dev/null +++ b/tests/test_estimators/test_samplers/test_stochastic_sampler.py @@ -0,0 +1 @@ +"""Tests for StochasticSampler.""" From a7a88753716c32b1b2238c3cebffa95e6b4173e0 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Wed, 14 May 2025 14:07:30 +0200 Subject: [PATCH 2/8] delete unused imbalance dir --- molpipeline/imbalance/__init__.py | 1 - molpipeline/imbalance/oversampling.py | 1 - 2 files changed, 2 deletions(-) delete mode 100644 molpipeline/imbalance/__init__.py delete mode 100644 molpipeline/imbalance/oversampling.py diff --git a/molpipeline/imbalance/__init__.py b/molpipeline/imbalance/__init__.py deleted file mode 100644 index 56051bcf..00000000 --- a/molpipeline/imbalance/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Imbalance module.""" diff --git a/molpipeline/imbalance/oversampling.py b/molpipeline/imbalance/oversampling.py deleted file mode 100644 index 8b137891..00000000 --- a/molpipeline/imbalance/oversampling.py +++ /dev/null @@ -1 +0,0 @@ - From b0b6b7fdd33283d03e16350da2c7a11423ab10ba Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Thu, 15 May 2025 14:29:30 +0200 Subject: [PATCH 3/8] finished oversampling and filters --- molpipeline/estimators/samplers/__init__.py | 1 + .../estimators/samplers/oversampling.py | 64 ++++++++--- .../estimators/samplers/stochastic_filter.py | 32 ++++-- .../estimators/samplers/stochastic_sampler.py | 44 +++---- .../test_samplers/test_oversampling.py | 107 
+++++++++++-----
 5 files changed, 177 insertions(+), 71 deletions(-)

diff --git a/molpipeline/estimators/samplers/__init__.py b/molpipeline/estimators/samplers/__init__.py
index e69de29b..8cd7b5d0 100644
--- a/molpipeline/estimators/samplers/__init__.py
+++ b/molpipeline/estimators/samplers/__init__.py
@@ -0,0 +1 @@
+"""Samplers to sample data with different strategies."""
diff --git a/molpipeline/estimators/samplers/oversampling.py b/molpipeline/estimators/samplers/oversampling.py
index 1cf8b15b..3c628e2f 100644
--- a/molpipeline/estimators/samplers/oversampling.py
+++ b/molpipeline/estimators/samplers/oversampling.py
@@ -9,18 +9,21 @@


 class GroupRandomOversampler(BaseEstimator, TransformerMixin):
-    """Random oversampler that operates on groups with inverse size weighting.
+    """Random oversampler that samples proportional to inverse group sizes.

     The oversampling strategy picks minority class samples with a probability inverse
     to size of the group they are in. With this strategy minority samples from smaller
     groups are more often oversampled. The number of samples picked is the difference
-    between the number of majority class and minority class samples, such that both
-    there are the same number of samples for each class. Note that, since only minority
-    class samples of the groups are oversampled, this strategy can lead to an class
-    imbalance within groups.
+    between the number of majority class and minority class samples, such that the
+    returned data set is balanced, i.e., it has the same number of samples for each
+    class. Note that, since only minority class samples of the groups are oversampled,
+    this strategy can lead to a class imbalance within groups.

     """

+    # Currently, only binary classification is supported.
+    _EXPECTED_NUMBER_OF_CLASSES = 2
+
     def __init__(self, random_state: int | None = None) -> None:
         """Create new GroupRandomOversampler.

@@ -32,9 +35,45 @@ def __init__(self, random_state: int | None = None) -> None:
         """
         self.random_state = random_state

+    @staticmethod
+    def _calculate_probabilities(
+        X: npt.NDArray[np.float64],  # noqa: N803 # pylint: disable=invalid-name
+        y: npt.NDArray[np.float64],
+        groups: npt.NDArray[np.float64],
+        minority_class: int,
+    ) -> npt.NDArray[np.float64]:
+        """Calculate sampling probabilities.
+
+        Parameters
+        ----------
+        X : npt.NDArray[np.float64]
+            Training data. Will be ignored, only present because of compatibility.
+        y : npt.NDArray[np.float64]
+            Target values. Must be binary classification labels.
+        groups : npt.NDArray[np.float64]
+            Group labels for the samples.
+        minority_class : int
+            The minority class label.
+
+        Returns
+        -------
+        npt.NDArray[np.float64]
+            Sampling probabilities for each sample.
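+
+        A worked sketch (values are illustrative): for ``groups = [1, 1, 2]``,
+        ``y = [0, 1, 1]``, and ``minority_class = 1``, the group-size filter
+        assigns ``[1/3, 1/3, 2/3]``; zeroing the majority sample (index 0) and
+        renormalizing yields ``[0, 1/3, 2/3]``.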
+ + """ + filter_policy = GroupSizeFilter(groups) + probs = filter_policy.calculate_probabilities(X, y) + # set probability of majority class(es) to zero + probs[y != minority_class] = 0 + + # normalize probabilities over all samples + probs /= np.sum(probs, dtype=np.float64) + + return probs + def transform( self, - X: npt.NDArray[np.float64], # noqa: N803 + X: npt.NDArray[np.float64], # noqa: N803 # pylint: disable=invalid-name y: npt.NDArray[np.float64], groups: npt.NDArray[np.float64], return_groups: bool = False, @@ -87,20 +126,13 @@ def transform( return_counts=True, ) - expected_number_of_classes = 2 - if len(u_y) != expected_number_of_classes: + if len(u_y) != self._EXPECTED_NUMBER_OF_CLASSES: raise ValueError("Only support oversampling for binary classification.") diff = abs(u_y_sizes[0] - u_y_sizes[1]) minority_class = u_y[np.argmin(u_y_sizes)] - filter_policy = GroupSizeFilter(groups) - probs = filter_policy.calculate_probabilities(X, y) - # set probability at majority class to zero - probs[y != minority_class] = 0 - - # normalize probabilities over all samples - probs /= np.sum(probs, dtype=np.float64) + probs = self._calculate_probabilities(X, y, groups, minority_class) # sample indices sampled_indices = rng.choice( @@ -124,6 +156,4 @@ def transform( groups_resampled[groups.shape[0] :] = groups[sampled_indices] return x_resampled, y_resampled, groups_resampled - # TODO optionally shuffle resampled data? technically not necessary - return x_resampled, y_resampled diff --git a/molpipeline/estimators/samplers/stochastic_filter.py b/molpipeline/estimators/samplers/stochastic_filter.py index a893637a..e0bdfc4a 100644 --- a/molpipeline/estimators/samplers/stochastic_filter.py +++ b/molpipeline/estimators/samplers/stochastic_filter.py @@ -7,7 +7,7 @@ import numpy.typing as npt -class StochasticFilter(ABC): +class StochasticFilter(ABC): # pylint: disable=too-few-public-methods """Abstract base class for stochastic filter policies. A StochasticFilter assigns sampling probabilities to data points based on @@ -17,7 +17,7 @@ class StochasticFilter(ABC): @abstractmethod def calculate_probabilities( self, - X: npt.NDArray[np.float64], # noqa: N803 + X: npt.NDArray[np.float64], # noqa: N803 # pylint: disable=invalid-name y: npt.NDArray[np.float64], ) -> npt.NDArray[np.float64]: """Calculate probability for each sample. @@ -37,7 +37,9 @@ def calculate_probabilities( """ -class GlobalClassBalanceFilter(StochasticFilter): +class GlobalClassBalanceFilter( + StochasticFilter, +): # pylint: disable=too-few-public-methods """Provides probabilities inversely proportional to global class frequencies. Samples from minority classes receive higher probability values. @@ -82,14 +84,23 @@ def calculate_probabilities( return normalized_inv_freqs[class_indices] -class LocalGroupClassBalanceFilter(StochasticFilter): +class LocalGroupClassBalanceFilter( + StochasticFilter, +): # pylint: disable=too-few-public-methods """Filter that returns probabilities based on class balance within groups. Samples from minority classes within their group receive higher probabilities. """ def __init__(self, groups: npt.NDArray[np.float64]) -> None: - """Create a new LocalGroupClassBalanceFilter.""" + """Create a new LocalGroupClassBalanceFilter. + + Parameters + ---------- + groups : npt.NDArray[np.float64] + Group labels for the samples of shape (n_samples,). 
+ + """ self._n_samples = groups.shape[0] self._unique_groups, self._group_indices = np.unique( groups, @@ -167,14 +178,21 @@ def calculate_probabilities( return probabilities -class GroupSizeFilter(StochasticFilter): +class GroupSizeFilter(StochasticFilter): # pylint: disable=too-few-public-methods """Filter that returns probabilities inversely proportional to group sizes. Samples from smaller groups receive higher probabilities. """ def __init__(self, groups: npt.NDArray[np.float64]) -> None: - """Create a new GroupSizeFilter.""" + """Create a new GroupSizeFilter. + + Parameters + ---------- + groups : npt.NDArray[np.float64] + Group labels for the samples of shape (n_samples,). + + """ self.probabilities = self._compute_group_probs(groups) @staticmethod diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py index 0ee9fa96..a93acb4a 100644 --- a/molpipeline/estimators/samplers/stochastic_sampler.py +++ b/molpipeline/estimators/samplers/stochastic_sampler.py @@ -1,3 +1,5 @@ +"""Stochastic Sampler.""" + from collections.abc import Sequence import numpy as np @@ -5,32 +7,31 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils import check_random_state, check_X_y +from molpipeline.estimators.samplers.stochastic_filter import StochasticFilter -class StochasticSampler(BaseEstimator, TransformerMixin): - """Sampler that uses multiple stochastic filters to determine sampling probabilities. - - Parameters - ---------- - filters : list of StochasticFilter - List of filter policies to apply. - n_samples : int or None, default=None - Number of samples to generate. If None, will match the size of X. - combination_method : {'product', 'mean', 'min', 'max'}, default='product' - Method to combine probabilities from multiple filters. - random_state : int, RandomState instance or None, default=None - Controls the randomization of the algorithm. - """ +class StochasticSampler(BaseEstimator, TransformerMixin): + """Sampler that uses stochastic filters for sampling.""" def __init__( self, filters: Sequence[StochasticFilter], - n_samples: int | None = None, combination_method: str = "product", random_state: int | None = None, ): + """Create a new StochasticSampler. + + Parameters + ---------- + filters : list of StochasticFilter + List of filter policies to apply. + combination_method : {'product', 'mean', 'min', 'max'}, default='product' + Method to combine probabilities from multiple filters. + random_state : int, RandomState instance or None, default=None + Controls the randomization of the algorithm. + + """ self.filters = filters - self.n_samples = n_samples self.combination_method = combination_method self.random_state = random_state @@ -39,7 +40,10 @@ def __init__( if self.combination_method not in valid_methods: raise ValueError(f"combination_method must be one of {valid_methods}") - def _combine_probabilities(self, probabilities_list: npt.NDArray) -> npt.NDArray: + def _combine_probabilities( + self, + probabilities_list: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: """Combine probabilities from multiple filters. 
Parameters @@ -84,9 +88,9 @@ def _combine_probabilities(self, probabilities_list: npt.NDArray) -> npt.NDArray def transform( self, - X: npt.NDArray, - y: npt.NDArray, - ) -> tuple[npt.NDArray, npt.NDArray]: + X: npt.NDArray[np.float64], # noqa: N803 # pylint: disable=invalid-name + y: npt.NDArray[np.float64], + ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """Apply filters and sample from the input data. Parameters diff --git a/tests/test_estimators/test_samplers/test_oversampling.py b/tests/test_estimators/test_samplers/test_oversampling.py index 02b0652c..1578675c 100644 --- a/tests/test_estimators/test_samplers/test_oversampling.py +++ b/tests/test_estimators/test_samplers/test_oversampling.py @@ -18,7 +18,7 @@ def _assert_resampling_results( x_resampled: npt.NDArray[np.float64], y_resampled: npt.NDArray[np.float64], ) -> None: - """Verify resampling results. + """Verify invariants of resampling results. Parameters ---------- @@ -68,7 +68,12 @@ def test_oversampling(self) -> None: x_matrix = np.zeros((y.shape[0], 2)) groups = np.array([1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) - x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + ( # pylint: disable=unbalanced-tuple-unpacking + x_resampled, + y_resampled, + ) = sampler.transform( # type: ignore[misc] + x_matrix, y, groups, + ) self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) def test_invalid_x_y_size(self) -> None: @@ -132,7 +137,9 @@ def test_minority_class_in_single_group(self) -> None: groups[n_minority + 30 : n_minority + 60] = 3 groups[n_minority + 60 :] = 4 - x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + x_resampled, y_resampled = ( # pylint: disable=unbalanced-tuple-unpacking + sampler.transform(x_matrix, y, groups) # type: ignore[misc] + ) self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) def test_only_one_group(self) -> None: @@ -144,7 +151,9 @@ def test_only_one_group(self) -> None: x_matrix = np.zeros((n_samples, 2)) groups = np.ones(n_samples) # all samples belong to the same group - x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + x_resampled, y_resampled = ( # pylint: disable=unbalanced-tuple-unpacking + sampler.transform(x_matrix, y, groups) # type: ignore[misc] + ) self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) def test_already_balanced(self) -> None: @@ -156,7 +165,9 @@ def test_already_balanced(self) -> None: x_matrix = np.zeros((n_samples, 2)) groups = np.array([1, 1, 2, 2]) - x_resampled, y_resampled = sampler.transform(x_matrix, y, groups) + x_resampled, y_resampled = ( # pylint: disable=unbalanced-tuple-unpacking + sampler.transform(x_matrix, y, groups) # type: ignore[misc] + ) self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled) def test_deterministic_behavior(self) -> None: @@ -169,8 +180,12 @@ def test_deterministic_behavior(self) -> None: sampler1 = GroupRandomOversampler(random_state=1234) sampler2 = GroupRandomOversampler(random_state=1234) - x_resampled1, y_resampled1 = sampler1.transform(x_matrix, y, groups) - x_resampled2, y_resampled2 = sampler2.transform(x_matrix, y, groups) + x_resampled1, y_resampled1 = ( # pylint: disable=unbalanced-tuple-unpacking + sampler1.transform(x_matrix, y, groups) # type: ignore[misc] + ) + x_resampled2, y_resampled2 = ( # pylint: disable=unbalanced-tuple-unpacking + sampler2.transform(x_matrix, y, groups) # type: ignore[misc] + ) # Results should be identical self.assertTrue(np.array_equal(x_resampled1, 
x_resampled2))
         self.assertTrue(np.array_equal(y_resampled1, y_resampled2))

@@ -188,7 +203,9 @@ def test_with_empty_groups(self) -> None:
         # Group 3 has no minority samples
         groups = np.array([1, 2, 3, 1, 2])

-        x_resampled, y_resampled = sampler.transform(x_matrix, y, groups)
+        x_resampled, y_resampled = (  # pylint: disable=unbalanced-tuple-unpacking
+            sampler.transform(x_matrix, y, groups)  # type: ignore[misc]
+        )
         self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled)

         # check that there is only one sample with values of 99
@@ -239,9 +256,9 @@ def _construct_test_data_inverse_probability_sampling() -> tuple[

         # Assign group labels
         groups = np.zeros(len(y))
-        groups[:n_group1] = 1
-        groups[n_group1 : n_group1 + n_group2] = 2
-        groups[n_group1 + n_group2 :] = 3
+        groups[:n_group1] = 0
+        groups[n_group1 : n_group1 + n_group2] = 1
+        groups[n_group1 + n_group2 :] = 2

         # Calculate expected sampling probabilities based on inverse group sizes
         group_sizes = np.array([n_group1, n_group2, n_group3])
@@ -250,24 +267,46 @@ def _construct_test_data_inverse_probability_sampling() -> tuple[
         inv_weights = 1.0 / group_sizes
         expected_probs = inv_weights / inv_weights.sum()

         return x_matrix, y, groups, expected_probs

-    def test_inverse_probability_sampling(self) -> None:
-        """Test that samples are selected based on inverse group size probabilities."""
-        n_runs = 1000  # Large number of runs for statistical significance
+    @staticmethod
+    def _generate_observed_probs_inverse_probability_sampling(
+        n_runs: int,
+        x_matrix: npt.NDArray[np.float64],
+        y: npt.NDArray[np.float64],
+        groups: npt.NDArray[np.float64],
+    ) -> npt.NDArray[np.float64]:
+        """Generate observed probabilities for the inverse_probability_sampling test.

-        x_matrix, y, groups, expected_probs = (
-            self._construct_test_data_inverse_probability_sampling()
-        )
+        Parameters
+        ----------
+        n_runs : int
+            Number of runs to average the probabilities.
+        x_matrix : npt.NDArray[np.float64]
+            Feature matrix.
+        y : npt.NDArray[np.float64]
+            Target values.
+        groups : npt.NDArray[np.float64]
+            Group labels for the samples.

+        Returns
+        -------
+        npt.NDArray[np.float64]
+            Observed sampling probabilities for each group.
+ + """ # Track the sampled groups across multiple runs sampled_counts = np.zeros(3) - + rng = np.random.default_rng(1234) for _ in range(n_runs): - sampler = GroupRandomOversampler(random_state=1234) - _, _, groups_resampled = sampler.transform( - x_matrix, - y, - groups, - return_groups=True, + sampler = GroupRandomOversampler( + random_state=rng.integers(0, np.iinfo(np.int32).max), + ) + _, _, groups_resampled = ( # pylint: disable=unbalanced-tuple-unpacking + sampler.transform( # type: ignore[misc] + x_matrix, + y, + groups, + return_groups=True, + ) ) # Count newly added samples by group @@ -275,10 +314,24 @@ def test_inverse_probability_sampling(self) -> None: unique_groups, counts = np.unique(new_samples, return_counts=True) for group, count in zip(unique_groups, counts, strict=True): - sampled_counts[int(group) - 1] += count + sampled_counts[int(group)] += count - # Calculate observed sampling probabilities - observed_probs = sampled_counts / sampled_counts.sum() + # Calculate and return observed sampling probabilities + return sampled_counts / sampled_counts.sum() + + def test_inverse_probability_sampling(self) -> None: + """Test that samples are selected based on inverse group size probabilities.""" + x_matrix, y, groups, expected_probs = ( + self._construct_test_data_inverse_probability_sampling() + ) + + n_runs = 1000 # Large number of runs for statistical significance + observed_probs = self._generate_observed_probs_inverse_probability_sampling( + n_runs, + x_matrix, + y, + groups, + ) # Check that observed probabilities are close to expected probabilities # Allow some tolerance for statistical variation From 60c3d4a8337197c6be45c01cc6a76cff77113af1 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Fri, 16 May 2025 15:01:58 +0200 Subject: [PATCH 4/8] finish stochastic sampler --- .../estimators/samplers/stochastic_filter.py | 2 + .../estimators/samplers/stochastic_sampler.py | 176 ++++--- .../test_samplers/test_stochastic_sampler.py | 459 ++++++++++++++++++ 3 files changed, 568 insertions(+), 69 deletions(-) diff --git a/molpipeline/estimators/samplers/stochastic_filter.py b/molpipeline/estimators/samplers/stochastic_filter.py index e0bdfc4a..68a79f5c 100644 --- a/molpipeline/estimators/samplers/stochastic_filter.py +++ b/molpipeline/estimators/samplers/stochastic_filter.py @@ -75,6 +75,8 @@ def calculate_probabilities( ) # calculate inverse class frequencies + # could also use power functions like inv_freqs = (1.0 / counts) ** p to + # control the penalization with the parameter p. inv_freqs = 1.0 / counts # normalize to probability distribution diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py index a93acb4a..2b7661d5 100644 --- a/molpipeline/estimators/samplers/stochastic_sampler.py +++ b/molpipeline/estimators/samplers/stochastic_sampler.py @@ -1,6 +1,7 @@ """Stochastic Sampler.""" from collections.abc import Sequence +from typing import Literal, Self import numpy as np import numpy.typing as npt @@ -11,12 +12,17 @@ class StochasticSampler(BaseEstimator, TransformerMixin): - """Sampler that uses stochastic filters for sampling.""" + """Sampler that uses stochastic filters for sampling. + + Note that combination_method "product" assumes independence of the filters. 
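+
+    As a worked example (values are illustrative): for two filters returning
+    per-sample probabilities [0.2, 0.8] and [0.5, 0.5], "product" combines
+    them to [0.1, 0.4] and "mean" to [0.35, 0.65]; both are then renormalized
+    to sum to 1.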
+ + """ def __init__( self, filters: Sequence[StochasticFilter], - combination_method: str = "product", + n_samples: int, + combination_method: Literal["product", "mean"] = "product", random_state: int | None = None, ): """Create a new StochasticSampler. @@ -25,123 +31,155 @@ def __init__( ---------- filters : list of StochasticFilter List of filter policies to apply. - combination_method : {'product', 'mean', 'min', 'max'}, default='product' + n_samples : int + Number of samples to generate. + combination_method : str, default="product" Method to combine probabilities from multiple filters. + Options are "product" or "mean". random_state : int, RandomState instance or None, default=None Controls the randomization of the algorithm. + Raises + ------ + ValueError + If n_samples is not positive or if combination_method is not recognized. + """ + if n_samples <= 0: + raise ValueError("n_samples must be positive") + if combination_method not in {"product", "mean"}: + raise ValueError( + "combination_method must be either 'product' or 'mean'", + ) self.filters = filters + self.n_samples = n_samples self.combination_method = combination_method - self.random_state = random_state - - # Validate combination method - valid_methods = ["product", "mean", "min", "max"] - if self.combination_method not in valid_methods: - raise ValueError(f"combination_method must be one of {valid_methods}") + self.rng = check_random_state(random_state) def _combine_probabilities( self, - probabilities_list: npt.NDArray[np.float64], + filter_probabilities: npt.NDArray[np.float64], ) -> npt.NDArray[np.float64]: """Combine probabilities from multiple filters. Parameters ---------- - probabilities_list : list of array-like - List of probability arrays from each filter. + filter_probabilities : npt.NDArray[np.float64] + Array of shape (n_samples, n_filters,). Returns ------- combined_probs : array-like Combined probabilities for each sample. - """ - if not probabilities_list: - raise ValueError("No probabilities to combine") - - if len(probabilities_list) == 1: - return probabilities_list[0] + Raises + ------ + ValueError + If the combination method is not recognized. + """ if self.combination_method == "product": - combined = np.prod(probabilities_list, axis=0) - + combined = np.prod(filter_probabilities, axis=1) elif self.combination_method == "mean": - combined = np.mean(probabilities_list, axis=0) - - elif self.combination_method == "min": - combined = np.min(probabilities_list, axis=0) - - elif self.combination_method == "max": - combined = np.max(probabilities_list, axis=0) + combined = np.mean(filter_probabilities, axis=1) else: - raise AssertionError("Invalid combination method") + raise ValueError(f"Invalid combination method: {self.combination_method}") - # # Normalize to ensure probabilities sum to 1 - # if combined.sum() > 0: - # combined = combined / combined.sum() - # else: - # # If all zeros, use uniform distribution - # combined = np.ones_like(combined) / len(combined) + combined_sum = combined.sum() + if combined_sum > 0: + combined /= combined_sum + else: + # fall back to uniform distribution when all probabilities are 0 + combined = np.ones_like(combined) / len(combined) return combined - def transform( + def calculate_probabilities( self, X: npt.NDArray[np.float64], # noqa: N803 # pylint: disable=invalid-name y: npt.NDArray[np.float64], - ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: - """Apply filters and sample from the input data. 
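+
+        A minimal usage sketch (the filter choice, shapes, and values are
+        illustrative):
+
+        >>> import numpy as np
+        >>> from molpipeline.estimators.samplers.stochastic_filter import (
+        ...     GlobalClassBalanceFilter,
+        ... )
+        >>> X = np.arange(10.0).reshape(5, 2)
+        >>> y = np.array([0, 0, 0, 1, 1])
+        >>> sampler = StochasticSampler(
+        ...     filters=[GlobalClassBalanceFilter()], n_samples=4, random_state=0
+        ... )
+        >>> X_s, y_s = sampler.transform(X, y)
+        >>> X_s.shape
+        (4, 2)
+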
+    ) -> npt.NDArray[np.float64]:
+        """Calculate the combined sampling probability for each sample.
 
         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
-            The input samples.
-        y : array-like of shape (n_samples,)
-            The target values.
+        X : npt.NDArray[np.float64]
+            The input samples of shape (n_samples, n_features).
+        y : npt.NDArray[np.float64]
+            The target values of shape (n_samples,).
 
         Returns
         -------
-        X_sampled : array-like of shape (n_samples_new, n_features)
-            The sampled input.
-        y_sampled : array-like of shape (n_samples_new,)
-            The sampled targets.
+        probabilities : npt.NDArray[np.float64]
+            The calculated probabilities of shape (n_samples,).
 
         """
-        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
-        rng = check_random_state(self.random_state)
-
-        # Determine the number of samples to generate
-        n_samples = self.n_samples if self.n_samples is not None else len(X)
+        x_matrix, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
 
-        # Apply each filter to get probabilities
-        filter_probabilities = np.zeros((len(X), len(self.filters)))
+        # get probabilities for each filter
+        filter_probabilities = np.zeros((len(x_matrix), len(self.filters)))
         for filter_idx, filter_policy in enumerate(self.filters):
             filter_probabilities[:, filter_idx] = filter_policy.calculate_probabilities(
-                X,
+                x_matrix,
                 y,
             )
 
-        # Combine probabilities
-        combined_probabilities = self._combine_probabilities(filter_probabilities)
+        return self._combine_probabilities(filter_probabilities)
 
-        # normalize combined probabilities because numpy's choices needs it
-        if combined_probabilities.sum() > 0:
-            combined_probabilities /= combined_probabilities.sum()
-        else:
-            # if all zeros, use uniform distribution
-            combined_probabilities = np.ones_like(combined_probabilities) / len(
-                combined_probabilities,
-            )
+    def transform(
+        self,
+        X: npt.NDArray[np.float64],  # noqa: N803 # pylint: disable=invalid-name
+        y: npt.NDArray[np.float64],
+    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+        """Apply filters and sample from the input data.
+
+        Parameters
+        ----------
+        X : npt.NDArray[np.float64]
+            The input samples of shape (n_samples, n_features).
+        y : npt.NDArray[np.float64]
+            The target values of shape (n_samples,).
+
+        Returns
+        -------
+        X_sampled : npt.NDArray[np.float64]
+            The sampled input of shape (n_samples_new, n_features).
+        y_sampled : npt.NDArray[np.float64]
+            The sampled targets of shape (n_samples_new,).
+
+        """
+        x_matrix, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+        combined_probabilities = self.calculate_probabilities(x_matrix, y)
 
-        sample_probabilities = 1 - combined_probabilities
         # Sample indices based on combined probabilities
-        indices = rng.choice(
-            len(X),
-            size=n_samples,
+        indices = self.rng.choice(
+            len(x_matrix),
+            size=self.n_samples,
             replace=True,
-            p=sample_probabilities,
+            p=combined_probabilities,
         )
 
         # Return sampled data
-        return X[indices], y[indices]
+        return x_matrix[indices], y[indices]
+
+    def fit(
+        self,
+        _X: npt.NDArray[np.float64],  # noqa: N803 # pylint: disable=invalid-name
+        _y: npt.NDArray[np.float64],
+    ) -> Self:
+        """Maintain scikit-learn API compatibility.
+
+        Parameters
+        ----------
+        X : npt.NDArray[np.float64]
+            The input samples of shape (n_samples, n_features).
+        y : npt.NDArray[np.float64]
+            The target values of shape (n_samples,).
+
+        Returns
+        -------
+        self : object
+            Returns self.
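+
+        Notes
+        -----
+        The sampler is stateless, so this method performs no computation and
+        exists only to satisfy the scikit-learn fit/transform protocol.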
+ + """ + return self diff --git a/tests/test_estimators/test_samplers/test_stochastic_sampler.py b/tests/test_estimators/test_samplers/test_stochastic_sampler.py index 2a7c8e19..cd9771e0 100644 --- a/tests/test_estimators/test_samplers/test_stochastic_sampler.py +++ b/tests/test_estimators/test_samplers/test_stochastic_sampler.py @@ -1 +1,460 @@ """Tests for StochasticSampler.""" + +import unittest +from typing import Any +from unittest import mock + +import numpy as np +import numpy.typing as npt + +from molpipeline.estimators.samplers.stochastic_filter import StochasticFilter +from molpipeline.estimators.samplers.stochastic_sampler import StochasticSampler + + +class _NoFilter(StochasticFilter): # pylint: disable=too-few-public-methods + """Mock filter that returns predefined probabilities.""" + + def __init__(self, probabilities: npt.NDArray[np.float64]) -> None: + """Create a new _NoFilter. + + Parameters + ---------- + probabilities : npt.NDArray[np.float64] + Predefined probabilities for the samples. + + """ + self.probabilities = probabilities + + def calculate_probabilities( + self, + _X: npt.NDArray[np.float64], # noqa: N803 # pylint: disable=invalid-name + _y: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: + """Calculate probabilities for the samples. + + Parameters + ---------- + X : npt.NDArray[np.float64] + Training data. Will be ignored, only present because of compatibility. + y : npt.NDArray[np.float64] + Target values. Will be ignored, only present because of compatibility. + + Returns + ------- + npt.NDArray[np.float64] + Predefined probabilities for the samples. + + """ + return self.probabilities + + +class TestStochasticSampler(unittest.TestCase): + """Test the StochasticSampler class.""" + + def setUp(self) -> None: + """Set up test data.""" + self.x_matrix = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + self.y = np.array([0, 1, 0, 1]) + self.simple_filter = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4])) + self.uniform_filter = _NoFilter(np.array([0.25, 0.25, 0.25, 0.25])) + + def test_init_with_valid_parameters(self) -> None: + """Test initialization with valid parameters.""" + sampler = StochasticSampler( + filters=[self.simple_filter], + n_samples=2, + random_state=42, + ) + self.assertEqual(sampler.n_samples, 2) + self.assertEqual(sampler.combination_method, "product") + self.assertEqual(len(sampler.filters), 1) + self.assertIsInstance( + sampler.rng, np.random.RandomState # pylint: disable=no-member + ) + + def test_init_with_invalid_n_samples(self) -> None: + """Test initialization with invalid n_samples.""" + with self.assertRaises(ValueError): + StochasticSampler(filters=[self.simple_filter], n_samples=0) + + with self.assertRaises(ValueError): + StochasticSampler(filters=[self.simple_filter], n_samples=-1) + + def test_init_with_invalid_combination_method(self) -> None: + """Test initialization with invalid combination method.""" + with self.assertRaises(ValueError): + StochasticSampler( + filters=[self.simple_filter], + n_samples=2, + combination_method="invalid", # type: ignore[arg-type] + ) + + def test_fit_returns_self(self) -> None: + """Test that fit returns self.""" + sampler = StochasticSampler(filters=[self.simple_filter], n_samples=2) + result = sampler.fit(self.x_matrix, self.y) + self.assertIs(result, sampler) + + def test_calculate_probabilities_single_filter(self) -> None: + """Test calculate_probabilities with a single filter.""" + sampler = StochasticSampler( + filters=[self.simple_filter], + n_samples=2, + random_state=42, + ) + + probs 
+
+        # Should return normalized probabilities from the filter
+        expected = np.array([0.1, 0.2, 0.3, 0.4]) / np.array([0.1, 0.2, 0.3, 0.4]).sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_calculate_probabilities_multiple_filters_product(self) -> None:
+        """Test calculate_probabilities w/ multiple filters and product combination."""
+        filter1 = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4]))
+        filter2 = _NoFilter(np.array([0.4, 0.3, 0.2, 0.1]))
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=2,
+            combination_method="product",
+            random_state=42,
+        )
+
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Expected: product of probabilities, normalized
+        expected_unnormalized = np.array([0.1 * 0.4, 0.2 * 0.3, 0.3 * 0.2, 0.4 * 0.1])
+        expected = expected_unnormalized / expected_unnormalized.sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_calculate_probabilities_multiple_filters_mean(self) -> None:
+        """Test calculate_probabilities with multiple filters using mean combination."""
+        filter1 = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4]))
+        filter2 = _NoFilter(np.array([0.4, 0.3, 0.2, 0.1]))
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=2,
+            combination_method="mean",
+            random_state=42,
+        )
+
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Expected: mean of probabilities, normalized
+        expected_unnormalized = np.array(
+            [(0.1 + 0.4) / 2, (0.2 + 0.3) / 2, (0.3 + 0.2) / 2, (0.4 + 0.1) / 2],
+        )
+        expected = expected_unnormalized / expected_unnormalized.sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_calculate_probabilities_with_all_zero_probabilities(self) -> None:
+        """Test calculate_probabilities with all zero probabilities."""
+        zero_filter = _NoFilter(np.array([0.0, 0.0, 0.0, 0.0]))
+
+        sampler = StochasticSampler(filters=[zero_filter], n_samples=10)
+
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Should return uniform distribution when all probabilities are zero
+        expected = np.ones(4) / 4
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_calculate_probabilities_with_zero_probabilities_multiple_filters(
+        self,
+    ) -> None:
+        """Test numerical stability in combine_probabilities with zero values."""
+        filter1 = _NoFilter(np.array([0.0, 0.3, 0.5, 0.0]))
+        filter2 = _NoFilter(np.array([0.2, 0.0, 0.6, 0.0]))
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=2,
+            combination_method="product",
+        )
+
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        expected_unnormalized = np.array([0.0, 0.0, 0.5 * 0.6, 0.0])
+        expected = expected_unnormalized / expected_unnormalized.sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+        # filters with all zero probabilities in product combination
+        filter1 = _NoFilter(np.array([0.0, 0.3, 0.0, 0.0]))
+        filter2 = _NoFilter(np.array([0.2, 0.0, 0.6, 0.0]))
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=2,
+            combination_method="product",
+        )
+
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Should return uniform distribution when all probabilities are zero
+        expected = np.ones(4) / 4
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_transform_small_probabilities_product(self) -> None:
+        """Test numerical stability with very small probabilities."""
+        # Create a filter with very small probabilities
+        tiny_probs = np.array([1e-10, 1e-15, 1e-20, 1e-5])
+        filter1 = _NoFilter(tiny_probs)
+        filter2 = _NoFilter(tiny_probs)
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=10,
+            random_state=42,
+            combination_method="product",
+        )
+
+        # This should not raise numerical errors
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Expected: product of probabilities, normalized
+        expected_unnormalized = tiny_probs * tiny_probs
+        expected = expected_unnormalized / expected_unnormalized.sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_transform_small_probabilities_mean(self) -> None:
+        """Test numerical stability with very small probabilities."""
+        # Create a filter with very small probabilities
+        tiny_probs = np.array([1e-10, 1e-15, 1e-20, 1e-5])
+        filter1 = _NoFilter(tiny_probs)
+        filter2 = _NoFilter(tiny_probs)
+
+        sampler = StochasticSampler(
+            filters=[filter1, filter2],
+            n_samples=10,
+            random_state=42,
+            combination_method="mean",
+        )
+
+        # This should not raise numerical errors
+        probs = sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        # Expected: mean of probabilities, normalized
+        expected_unnormalized = (tiny_probs + tiny_probs) / 2
+        expected = expected_unnormalized / expected_unnormalized.sum()
+        self.assertTrue(np.allclose(probs, expected))
+        self.assertAlmostEqual(probs.sum(), 1.0)
+
+    def test_calculate_probabilities_calls_filters(self) -> None:
+        """Test that each filter's calculate_probabilities method is called."""
+        mock_filter1 = mock.MagicMock(spec=StochasticFilter)
+        mock_filter1.calculate_probabilities.return_value = np.array(
+            [0.1, 0.2, 0.3, 0.4],
+        )
+
+        mock_filter2 = mock.MagicMock(spec=StochasticFilter)
+        mock_filter2.calculate_probabilities.return_value = np.array(
+            [0.4, 0.3, 0.2, 0.1],
+        )
+
+        sampler = StochasticSampler(filters=[mock_filter1, mock_filter2], n_samples=2)
+
+        sampler.calculate_probabilities(self.x_matrix, self.y)
+
+        mock_filter1.calculate_probabilities.assert_called_once()
+        mock_filter2.calculate_probabilities.assert_called_once()
+
+        # Check that args passed to filter were correct
+        args1, _ = mock_filter1.calculate_probabilities.call_args
+        args2, _ = mock_filter2.calculate_probabilities.call_args
+        self.assertTrue(np.array_equal(args1[0], self.x_matrix))
+        self.assertTrue(np.array_equal(args1[1], self.y))
+        self.assertTrue(np.array_equal(args2[0], self.x_matrix))
+        self.assertTrue(np.array_equal(args2[1], self.y))
+
+    @staticmethod
+    def _generate_observed_probs(
+        n_runs: int,
+        n_samples: int,
+        sampler_kwargs: dict[str, Any],
+    ) -> npt.NDArray[np.float64]:
+        """Generate observed sampling probabilities over repeated sampler runs.
+
+        Parameters
+        ----------
+        n_runs : int
+            Number of runs to average the probabilities.
+        n_samples : int
+            The number of samples to generate.
+        sampler_kwargs : dict[str, Any]
+            Keyword arguments for the sampler.
+
+        Returns
+        -------
+        npt.NDArray[np.float64]
+            Observed sampling probabilities for each sample.
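+
+        Notes
+        -----
+        With n_runs * n_samples total draws, the observed frequencies approach
+        the underlying sampling probabilities at a rate of roughly
+        1 / sqrt(n_runs * n_samples), which motivates the atol=1e-2 tolerance
+        used by the frequency tests below.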
+ + """ + # x_matrix data will contain 50 samples with identifiers as values + x_matrix: npt.NDArray[np.float64] = np.arange( + n_samples, + dtype=np.float64, + ).reshape(-1, 1) + y = np.ones(len(x_matrix), dtype=np.float64) + + sampled_counts = np.zeros_like(y) + rng = np.random.default_rng(1234) + for _ in range(n_runs): + sampler = StochasticSampler( + random_state=rng.integers(0, np.iinfo(np.int32).max), + **sampler_kwargs, + ) + x_sampled, _ = sampler.transform( + x_matrix, + y, + ) + + # Count newly added samples by group + unique_samples, counts = np.unique(x_sampled, return_counts=True) + + for sample_id, count in zip(unique_samples, counts, strict=True): + sampled_counts[int(sample_id)] += count + + # Calculate and return observed sampling probabilities + return sampled_counts / sampled_counts.sum() + + def test_transform_with_multiple_filters_product(self) -> None: + """Test transform with multiple filters using product combination.""" + filter1 = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4])) + filter2 = _NoFilter(np.array([0.4, 0.3, 0.2, 0.1])) + + sampler = StochasticSampler( + filters=[filter1, filter2], + n_samples=10, + combination_method="product", + random_state=42, + ) + x_sampled, y_sampled = sampler.transform(self.x_matrix, self.y) + + self.assertEqual(x_sampled.shape, (10, 2)) + self.assertEqual(y_sampled.shape, (10,)) + + def test_transform_with_multiple_filters_product_observed_freqs(self) -> None: + """Test observed sample frequencies match expected frequencies.""" + filter1 = _NoFilter(np.array([0.5, 0.25, 0.25, 0.25])) + filter2 = _NoFilter(np.array([0.5, 0.25, 0.25, 0.25])) + + sampler_kwargs = { + "filters": [filter1, filter2], + "n_samples": 100, + "combination_method": "product", + } + + observed_probs = self._generate_observed_probs( + n_runs=1000, + n_samples=4, + sampler_kwargs=sampler_kwargs, + ) + + # the first element of the expected probs is biased by 2 + expected_probs = np.array([0.5 * 0.5, *([0.25 * 0.25] * 3)]) + expected_probs /= expected_probs.sum() + self.assertTrue(np.allclose(expected_probs, observed_probs, rtol=0, atol=1e-2)) + + def test_transform_with_multiple_filters_mean(self) -> None: + """Test transform with multiple filters using mean combination.""" + filter1 = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4])) + filter2 = _NoFilter(np.array([0.4, 0.3, 0.2, 0.1])) + + sampler = StochasticSampler( + filters=[filter1, filter2], + n_samples=10, + combination_method="mean", + random_state=42, + ) + + x_sampled, y_sampled = sampler.transform(self.x_matrix, self.y) + + self.assertEqual(x_sampled.shape, (10, 2)) + self.assertEqual(y_sampled.shape, (10,)) + + def test_transform_with_multiple_filters_mean_observed_freqs(self) -> None: + """Test observed sample frequencies match expected frequencies.""" + filter1 = _NoFilter(np.array([0.5, 0.25, 0.25, 0.25])) + filter2 = _NoFilter(np.array([0.5, 0.25, 0.25, 0.25])) + + sampler_kwargs = { + "filters": [filter1, filter2], + "n_samples": 10, + "combination_method": "mean", + } + + observed_probs = self._generate_observed_probs( + n_runs=1000, + n_samples=4, + sampler_kwargs=sampler_kwargs, + ) + + # the first element of the expected probs is biased by 2 + expected_probs = np.array([(0.5 + 0.5) / 2, *([(0.25 + 0.25) / 2] * 3)]) + expected_probs /= expected_probs.sum() + self.assertTrue(np.allclose(expected_probs, observed_probs, rtol=0, atol=1e-2)) + + def test_transform_with_zero_probabilities(self) -> None: + """Test transform with all zero probabilities uses uniform distribution.""" + zero_filter = 
+
+        sampler = StochasticSampler(
+            filters=[zero_filter],
+            n_samples=100,
+            random_state=42,
+        )
+
+        # Should fall back to the uniform distribution instead of raising
+        x_sampled, y_sampled = sampler.transform(self.x_matrix, self.y)
+
+        self.assertEqual(x_sampled.shape, (100, 2))
+        self.assertEqual(y_sampled.shape, (100,))
+
+    def test_transform_with_zero_probabilities_product_observed_freqs(self) -> None:
+        """Test observed sample frequencies match expected frequencies."""
+        filter1 = _NoFilter(np.array([0.0, 0.0, 0.0, 0.0]))
+        filter2 = _NoFilter(np.array([0.0, 0.0, 0.0, 0.0]))
+
+        sampler_kwargs = {
+            "filters": [filter1, filter2],
+            "n_samples": 10,
+            "combination_method": "product",
+        }
+
+        observed_probs = self._generate_observed_probs(
+            n_runs=1000,
+            n_samples=4,
+            sampler_kwargs=sampler_kwargs,
+        )
+
+        # all-zero probabilities fall back to the uniform distribution
+        expected_probs = np.ones(4) / 4
+        self.assertTrue(np.allclose(expected_probs, observed_probs, rtol=0, atol=1e-2))
+
+    def test_reproducibility_with_fixed_seed(self) -> None:
+        """Test that results are reproducible with fixed random state."""
+        sampler1 = StochasticSampler(
+            filters=[self.simple_filter],
+            n_samples=5,
+            random_state=42,
+        )
+        sampler2 = StochasticSampler(
+            filters=[self.simple_filter],
+            n_samples=5,
+            random_state=42,
+        )
+
+        x_sampled1, y_sampled1 = sampler1.transform(self.x_matrix, self.y)
+        x_sampled2, y_sampled2 = sampler2.transform(self.x_matrix, self.y)
+
+        self.assertTrue(np.array_equal(x_sampled1, x_sampled2))
+        self.assertTrue(np.array_equal(y_sampled1, y_sampled2))

From 7aa9b90ae1e3ad1e958856b123899c1a09ca4fe2 Mon Sep 17 00:00:00 2001
From: Jochen Sieg
Date: Fri, 16 May 2025 15:05:17 +0200
Subject: [PATCH 5/8] ruff format and docsig complaint

---
 molpipeline/estimators/samplers/stochastic_sampler.py    | 4 ++--
 tests/test_estimators/test_samplers/test_oversampling.py | 4 +++-
 .../test_samplers/test_stochastic_sampler.py             | 7 ++++---
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py
index 2b7661d5..335c655c 100644
--- a/molpipeline/estimators/samplers/stochastic_sampler.py
+++ b/molpipeline/estimators/samplers/stochastic_sampler.py
@@ -171,9 +171,9 @@ def fit(
 
         Parameters
         ----------
-        X : npt.NDArray[np.float64]
+        _X : npt.NDArray[np.float64]
             The input samples of shape (n_samples, n_features).
-        y : npt.NDArray[np.float64]
+        _y : npt.NDArray[np.float64]
             The target values of shape (n_samples,).
 
         Returns
diff --git a/tests/test_estimators/test_samplers/test_oversampling.py b/tests/test_estimators/test_samplers/test_oversampling.py
index 1578675c..785767fc 100644
--- a/tests/test_estimators/test_samplers/test_oversampling.py
+++ b/tests/test_estimators/test_samplers/test_oversampling.py
@@ -72,7 +72,9 @@ def test_oversampling(self) -> None:
             x_resampled,
             y_resampled,
         ) = sampler.transform(  # type: ignore[misc]
-            x_matrix, y, groups,
+            x_matrix,
+            y,
+            groups,
         )
 
         self._assert_resampling_results(x_matrix, y, x_resampled, y_resampled)
diff --git a/tests/test_estimators/test_samplers/test_stochastic_sampler.py b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
index cd9771e0..6c9cf0ef 100644
--- a/tests/test_estimators/test_samplers/test_stochastic_sampler.py
+++ b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
@@ -34,9 +34,9 @@ def calculate_probabilities(
 
         Parameters
         ----------
-        X : npt.NDArray[np.float64]
+        _X : npt.NDArray[np.float64]
             Training data. Will be ignored, only present because of compatibility.
-        y : npt.NDArray[np.float64]
+        _y : npt.NDArray[np.float64]
             Target values. Will be ignored, only present because of compatibility.
 
         Returns
@@ -69,7 +69,8 @@ def test_init_with_valid_parameters(self) -> None:
         self.assertEqual(sampler.combination_method, "product")
         self.assertEqual(len(sampler.filters), 1)
         self.assertIsInstance(
-            sampler.rng, np.random.RandomState  # pylint: disable=no-member
+            sampler.rng,
+            np.random.RandomState,  # pylint: disable=no-member
         )
 
     def test_init_with_invalid_n_samples(self) -> None:

From 905819ff4e0d749016fcba1386ed4d81725a4396 Mon Sep 17 00:00:00 2001
From: Jochen Sieg
Date: Fri, 16 May 2025 15:30:27 +0200
Subject: [PATCH 6/8] add sampler classes to __all__ and rename rng to random_state

---
 molpipeline/estimators/samplers/__init__.py   | 16 ++++++++++++++++
 .../estimators/samplers/stochastic_sampler.py |  4 ++--
 .../test_samplers/test_stochastic_sampler.py  |  2 +-
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/molpipeline/estimators/samplers/__init__.py b/molpipeline/estimators/samplers/__init__.py
index 8cd7b5d0..5e265a0b 100644
--- a/molpipeline/estimators/samplers/__init__.py
+++ b/molpipeline/estimators/samplers/__init__.py
@@ -1 +1,17 @@
 """Samplers to sample data with different strategies."""
+
+__all__ = [
+    "GlobalClassBalanceFilter",
+    "GroupSizeFilter",
+    "GroupRandomOversampler",
+    "LocalGroupClassBalanceFilter",
+    "StochasticSampler",
+]
+
+from molpipeline.estimators.samplers.oversampling import GroupRandomOversampler
+from molpipeline.estimators.samplers.stochastic_filter import (
+    GlobalClassBalanceFilter,
+    GroupSizeFilter,
+    LocalGroupClassBalanceFilter,
+)
+from molpipeline.estimators.samplers.stochastic_sampler import StochasticSampler
diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py
index 335c655c..91e15813 100644
--- a/molpipeline/estimators/samplers/stochastic_sampler.py
+++ b/molpipeline/estimators/samplers/stochastic_sampler.py
@@ -54,7 +54,7 @@ def __init__(
         self.filters = filters
         self.n_samples = n_samples
         self.combination_method = combination_method
-        self.rng = check_random_state(random_state)
+        self.random_state = check_random_state(random_state)
 
     def _combine_probabilities(
         self,
@@ -152,7 +152,7 @@ def transform(
         combined_probabilities = self.calculate_probabilities(x_matrix, y)
 
         # Sample indices based on combined probabilities
-        indices = self.rng.choice(
+        indices = self.random_state.choice(
             len(x_matrix),
             size=self.n_samples,
             replace=True,
diff --git a/tests/test_estimators/test_samplers/test_stochastic_sampler.py b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
index 6c9cf0ef..23398028 100644
--- a/tests/test_estimators/test_samplers/test_stochastic_sampler.py
+++ b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
@@ -69,7 +69,7 @@ def test_init_with_valid_parameters(self) -> None:
         self.assertEqual(sampler.combination_method, "product")
         self.assertEqual(len(sampler.filters), 1)
         self.assertIsInstance(
-            sampler.rng,
+            sampler.random_state,
             np.random.RandomState,  # pylint: disable=no-member
         )
 

From 485f37d691a5489061de954436882b701926376e Mon Sep 17 00:00:00 2001
From: Jochen Sieg
Date: Fri, 16 May 2025 15:33:53 +0200
Subject: [PATCH 7/8] fix __all__ sorting

---
 molpipeline/estimators/samplers/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/molpipeline/estimators/samplers/__init__.py b/molpipeline/estimators/samplers/__init__.py
index 5e265a0b..e9975e14 100644
--- a/molpipeline/estimators/samplers/__init__.py
+++ b/molpipeline/estimators/samplers/__init__.py
@@ -2,8 +2,8 @@
 
 __all__ = [
     "GlobalClassBalanceFilter",
-    "GroupSizeFilter",
     "GroupRandomOversampler",
+    "GroupSizeFilter",
     "LocalGroupClassBalanceFilter",
     "StochasticSampler",
 ]

From b7b04dafedd536839b79fa4721374bfea224855d Mon Sep 17 00:00:00 2001
From: Jochen Sieg
Date: Mon, 19 May 2025 11:01:31 +0200
Subject: [PATCH 8/8] user input mitigation and linting fixes

---
 .../estimators/samplers/stochastic_filter.py  |   6 +-
 .../estimators/samplers/stochastic_sampler.py |   2 +-
 .../test_samplers/test_stochastic_sampler.py  | 188 +++++++++++++-----
 3 files changed, 145 insertions(+), 51 deletions(-)

diff --git a/molpipeline/estimators/samplers/stochastic_filter.py b/molpipeline/estimators/samplers/stochastic_filter.py
index 68a79f5c..a2cc7741 100644
--- a/molpipeline/estimators/samplers/stochastic_filter.py
+++ b/molpipeline/estimators/samplers/stochastic_filter.py
@@ -1,7 +1,11 @@
 """Stochastic filters for providing filter probabilities for sampling."""
 
 from abc import ABC, abstractmethod
-from typing import override
+
+try:
+    from typing import override  # type: ignore[attr-defined]
+except ImportError:
+    from typing_extensions import override
 
 import numpy as np
 import numpy.typing as npt
diff --git a/molpipeline/estimators/samplers/stochastic_sampler.py b/molpipeline/estimators/samplers/stochastic_sampler.py
index 91e15813..77c53f33 100644
--- a/molpipeline/estimators/samplers/stochastic_sampler.py
+++ b/molpipeline/estimators/samplers/stochastic_sampler.py
@@ -52,7 +52,7 @@ def __init__(
                 "combination_method must be either 'product' or 'mean'",
             )
         self.filters = filters
-        self.n_samples = n_samples
+        self.n_samples = int(n_samples)
         self.combination_method = combination_method
         self.random_state = check_random_state(random_state)
 
diff --git a/tests/test_estimators/test_samplers/test_stochastic_sampler.py b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
index 23398028..fcd11576 100644
--- a/tests/test_estimators/test_samplers/test_stochastic_sampler.py
+++ b/tests/test_estimators/test_samplers/test_stochastic_sampler.py
@@ -7,7 +7,12 @@
 import numpy as np
 import numpy.typing as npt
 
-from molpipeline.estimators.samplers.stochastic_filter import StochasticFilter
+from molpipeline.estimators.samplers.stochastic_filter import (
+    GlobalClassBalanceFilter,
+    GroupSizeFilter,
+    LocalGroupClassBalanceFilter,
+    StochasticFilter,
+)
 from molpipeline.estimators.samplers.stochastic_sampler import StochasticSampler
 
@@ -48,8 +53,8 @@ def calculate_probabilities(
         return self.probabilities
 
 
-class TestStochasticSampler(unittest.TestCase):
-    """Test the StochasticSampler class."""
+class TestStochasticSamplerGeneral(unittest.TestCase):
+    """Run general tests for the StochasticSampler class."""
 
     def setUp(self) -> None:
         """Set up test data."""
@@ -96,6 +101,99 @@ def test_fit_returns_self(self) -> None:
         result = sampler.fit(self.x_matrix, self.y)
         self.assertIs(result, sampler)
 
+    def test_reproducibility_with_fixed_seed(self) -> None:
+        """Test that results are reproducible with fixed random state."""
+        sampler1 = StochasticSampler(
+            filters=[self.simple_filter],
+            n_samples=5,
+            random_state=42,
+        )
+        sampler2 = StochasticSampler(
+            filters=[self.simple_filter],
+            n_samples=5,
+            random_state=42,
+        )
+
+        x_sampled1, y_sampled1 = sampler1.transform(self.x_matrix, self.y)
+        x_sampled2, y_sampled2 = sampler2.transform(self.x_matrix, self.y)
+
+        self.assertTrue(np.array_equal(x_sampled1, x_sampled2))
+        self.assertTrue(np.array_equal(y_sampled1, y_sampled2))
+
+    def test_n_samples_is_1(self) -> None:
+        """Test correct behaviour when n_samples is 1."""
+        # Create a sampler with n_samples=1
+        sampler = StochasticSampler(
+            filters=[self.simple_filter],
+            n_samples=1,
+            random_state=42,
+        )
+
+        # Transform the data
+        x_sampled, y_sampled = sampler.transform(self.x_matrix, self.y)
+
+        # Check that the output shape is correct (should be a single sample)
+        self.assertEqual(x_sampled.shape, (1, 2))
+        self.assertEqual(y_sampled.shape, (1,))
+
+    def test_n_samples_is_float(self) -> None:
+        """Test correct behaviour when n_samples is passed as a float.
+
+        This test addresses a bug when a user ignores the type hint and passes a
+        float value for n_samples.
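+        The constructor casts n_samples to int, so a float such as 1.0 is
+        normalized before it reaches the random sampling call.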
+ """ + # test when n_samples is a float + sampler = StochasticSampler( + filters=[self.simple_filter], + n_samples=1.0, # type: ignore[arg-type] + random_state=42, + ) + + # Transform the data + x_sampled, y_sampled = sampler.transform(self.x_matrix, self.y) + + # Check that the output shape is correct (should be a single sample) + self.assertEqual(x_sampled.shape, (1, 2)) + self.assertEqual(y_sampled.shape, (1,)) + + def test_integration_with_filters(self) -> None: + """Simple test to check integration with a filters.""" + # global balance: class 0: 7, class 1: 12 + # group sizes: group 1: 3, group 2: 1, group 3: 1, group 4: 14 + y = np.array([0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + x_matrix = np.zeros((len(y), 2)) + groups = np.array([1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + + group_size_filter = GroupSizeFilter(groups=groups) + global_class_balance_filter = GlobalClassBalanceFilter() + local_group_class_balance_filter = LocalGroupClassBalanceFilter(groups=groups) + + sampler = StochasticSampler( + filters=[ + group_size_filter, + global_class_balance_filter, + local_group_class_balance_filter, + ], + n_samples=5, + random_state=42, + ) + + x_sampled, y_sampled = sampler.transform(x_matrix, y) + + self.assertEqual(x_sampled.shape, (5, 2)) + self.assertEqual(y_sampled.shape, (5,)) + + +class TestStochasticSamplerCalculateProbabilities(unittest.TestCase): + """Test the calculate_probabilities method of StochasticSampler.""" + + def setUp(self) -> None: + """Set up test data.""" + self.x_matrix = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + self.y = np.array([0, 1, 0, 1]) + self.simple_filter = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4])) + self.uniform_filter = _NoFilter(np.array([0.25, 0.25, 0.25, 0.25])) + def test_calculate_probabilities_single_filter(self) -> None: """Test calculate_probabilities with a single filter.""" sampler = StochasticSampler( @@ -203,6 +301,44 @@ def test_calculate_probabilities_with_zero_probabilities_multiple_filters( self.assertTrue(np.allclose(probs, expected)) self.assertAlmostEqual(probs.sum(), 1.0) + def test_calculate_probabilities_calls_filters(self) -> None: + """Test that each filter's calculate_probabilities method is called.""" + mock_filter1 = mock.MagicMock(spec=StochasticFilter) + mock_filter1.calculate_probabilities.return_value = np.array( + [0.1, 0.2, 0.3, 0.4], + ) + + mock_filter2 = mock.MagicMock(spec=StochasticFilter) + mock_filter2.calculate_probabilities.return_value = np.array( + [0.4, 0.3, 0.2, 0.1], + ) + + sampler = StochasticSampler(filters=[mock_filter1, mock_filter2], n_samples=2) + + sampler.calculate_probabilities(self.x_matrix, self.y) + + mock_filter1.calculate_probabilities.assert_called_once() + mock_filter2.calculate_probabilities.assert_called_once() + + # Check that args passed to filter were correct + args1, _ = mock_filter1.calculate_probabilities.call_args + args2, _ = mock_filter2.calculate_probabilities.call_args + self.assertTrue(np.array_equal(args1[0], self.x_matrix)) + self.assertTrue(np.array_equal(args1[1], self.y)) + self.assertTrue(np.array_equal(args2[0], self.x_matrix)) + self.assertTrue(np.array_equal(args2[1], self.y)) + + +class TestStochasticSamplerTransform(unittest.TestCase): + """Test the transform method of StochasticSampler.""" + + def setUp(self) -> None: + """Set up test data.""" + self.x_matrix = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + self.y = np.array([0, 1, 0, 1]) + self.simple_filter = _NoFilter(np.array([0.1, 0.2, 0.3, 0.4])) + self.uniform_filter = 
+
     def test_transform_small_probabilities_product(self) -> None:
         """Test numerical stability with very small probabilities."""
         # Create a filter with very small probabilities
@@ -249,33 +385,6 @@ def test_transform_small_probabilities_mean(self) -> None:
         self.assertTrue(np.allclose(probs, expected))
         self.assertAlmostEqual(probs.sum(), 1.0)
 
-    def test_calculate_probabilities_calls_filters(self) -> None:
-        """Test that each filter's calculate_probabilities method is called."""
-        mock_filter1 = mock.MagicMock(spec=StochasticFilter)
-        mock_filter1.calculate_probabilities.return_value = np.array(
-            [0.1, 0.2, 0.3, 0.4],
-        )
-
-        mock_filter2 = mock.MagicMock(spec=StochasticFilter)
-        mock_filter2.calculate_probabilities.return_value = np.array(
-            [0.4, 0.3, 0.2, 0.1],
-        )
-
-        sampler = StochasticSampler(filters=[mock_filter1, mock_filter2], n_samples=2)
-
-        sampler.calculate_probabilities(self.x_matrix, self.y)
-
-        mock_filter1.calculate_probabilities.assert_called_once()
-        mock_filter2.calculate_probabilities.assert_called_once()
-
-        # Check that args passed to filter were correct
-        args1, _ = mock_filter1.calculate_probabilities.call_args
-        args2, _ = mock_filter2.calculate_probabilities.call_args
-        self.assertTrue(np.array_equal(args1[0], self.x_matrix))
-        self.assertTrue(np.array_equal(args1[1], self.y))
-        self.assertTrue(np.array_equal(args2[0], self.x_matrix))
-        self.assertTrue(np.array_equal(args2[1], self.y))
-
     @staticmethod
     def _generate_observed_probs(
         n_runs: int,
@@ -440,22 +549,3 @@ def test_transform_with_zero_probabilities_product_observed_freqs(self) -> None:
         # all-zero probabilities fall back to the uniform distribution
         expected_probs = np.ones(4) / 4
         self.assertTrue(np.allclose(expected_probs, observed_probs, rtol=0, atol=1e-2))
-
-    def test_reproducibility_with_fixed_seed(self) -> None:
-        """Test that results are reproducible with fixed random state."""
-        sampler1 = StochasticSampler(
-            filters=[self.simple_filter],
-            n_samples=5,
-            random_state=42,
-        )
-        sampler2 = StochasticSampler(
-            filters=[self.simple_filter],
-            n_samples=5,
-            random_state=42,
-        )
-
-        x_sampled1, y_sampled1 = sampler1.transform(self.x_matrix, self.y)
-        x_sampled2, y_sampled2 = sampler2.transform(self.x_matrix, self.y)
-
-        self.assertTrue(np.array_equal(x_sampled1, x_sampled2))
-        self.assertTrue(np.array_equal(y_sampled1, y_sampled2))
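
A minimal usage sketch of the sampler API introduced in this series (the toy
data and printed shapes are illustrative; module paths, class names, and
signatures follow the patches above):

    import numpy as np

    from molpipeline.estimators.samplers import GroupSizeFilter, StochasticSampler

    # Toy data: six samples in two groups of unequal size (4 vs. 2).
    x_matrix = np.arange(12, dtype=np.float64).reshape(6, 2)
    y = np.array([0, 1, 0, 1, 0, 1], dtype=np.float64)
    groups = np.array([1, 1, 1, 1, 2, 2])

    # GroupSizeFilter weights samples inversely to the size of their group,
    # so rows from the smaller group 2 are drawn more often on average.
    sampler = StochasticSampler(
        filters=[GroupSizeFilter(groups=groups)],
        n_samples=4,
        combination_method="product",
        random_state=42,
    )
    x_sampled, y_sampled = sampler.fit(x_matrix, y).transform(x_matrix, y)
    print(x_sampled.shape, y_sampled.shape)  # (4, 2) (4,)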