Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
201 changes: 107 additions & 94 deletions README.md

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions build/lib/imputegap/algorithms/bayotide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import time

from imputegap.wrapper.AlgoPython.BayOTIDE.BayOTIDE import recoveryBayOTIDE


def bay_otide(incomp_data, K_trend=20, K_season=2, n_season=5, K_bias=1, time_scale=1, a0=0.6, b0=2.5, v=0.5, logs=True):
    """
    Impute missing values with BayOTIDE (Bayesian Online Multivariate Time
    series Imputation with functional decomposition).

    This function is a timing wrapper: every hyperparameter is forwarded
    unchanged to ``recoveryBayOTIDE`` and the elapsed wall-clock time is
    printed when ``logs`` is truthy.

    Parameters
    ----------
    incomp_data : numpy.ndarray
        The input matrix with contamination (missing values represented as NaNs).
    K_trend : int, optional
        Number of trend factors (default: 20).
    K_season : int, optional
        Number of seasonal factors (default: 2).
    n_season : int, optional
        Number of seasonal components per factor (default: 5).
    K_bias : int, optional
        Number of bias factors (default: 1).
    time_scale : float, optional
        Time scaling factor (default: 1).
    a0 : float, optional
        Hyperparameter for the prior distribution (default: 0.6).
    b0 : float, optional
        Hyperparameter for the prior distribution (default: 2.5).
    v : float, optional
        Variance parameter (default: 0.5).
    logs : bool, optional
        Whether to print the execution time (default: True).

    Returns
    -------
    numpy.ndarray
        The imputed matrix with missing values recovered, exactly as returned
        by ``recoveryBayOTIDE``.

    Example
    -------
        >>> recov_data = bay_otide(incomp_data, K_trend=20, K_season=2, n_season=5, K_bias=1, time_scale=1, a0=0.6, b0=2.5, v=0.5)
        >>> print(recov_data)

    References
    ----------
    S. Fang, Q. Wen, Y. Luo, S. Zhe, and L. Sun, "BayOTIDE: Bayesian Online Multivariate Time Series Imputation with Functional Decomposition," CoRR, vol. abs/2308.14906, 2024. [Online]. Available: https://arxiv.org/abs/2308.14906.
    https://github.com/xuangu-fang/BayOTIDE
    """
    # Collect the model hyperparameters once so the backend call stays short.
    hyper_params = {
        "K_trend": K_trend,
        "K_season": K_season,
        "n_season": n_season,
        "K_bias": K_bias,
        "time_scale": time_scale,
        "a0": a0,
        "b0": b0,
        "v": v,
    }

    began = time.time()
    recov_data = recoveryBayOTIDE(data=incomp_data, **hyper_params)
    elapsed = time.time() - began

    if logs:
        print(f"\n\t\t> logs, imputation bay_otide - Execution Time: {elapsed:.4f} seconds\n")

    return recov_data
69 changes: 69 additions & 0 deletions build/lib/imputegap/algorithms/bit_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import time

from imputegap.wrapper.AlgoPython.BiTGraph.main import recoveryBitGRAPH


def bit_graph(incomp_data, node_number=-1, kernel_set=None, dropout=0.1, subgraph_size=5, node_dim=3, seq_len=1, lr=0.001, epoch=10, seed=42, logs=True):
    """
    Perform imputation using BiTGraph (Biased Temporal Convolution Graph Network
    for time series with missing values).

    This function is a timing wrapper: all arguments are forwarded to
    ``recoveryBitGRAPH`` and the elapsed wall-clock time is printed when
    ``logs`` is truthy.

    Parameters
    ----------
    incomp_data : numpy.ndarray
        The input matrix with contamination (missing values represented as NaNs).

    node_number : int, optional
        The number of nodes (time series variables) in the dataset. If -1
        (default), it is set automatically from the length of the values.

    kernel_set : list, optional
        Set of kernel sizes used in the model for graph convolution operations.
        ``None`` (default) is replaced by a fresh ``[1]`` on every call, so the
        default list is never shared between calls.

    dropout : float, optional
        Dropout rate applied during training to prevent overfitting (default: 0.1).

    subgraph_size : int, optional
        The size of each subgraph used in message passing within the graph network (default: 5).

    node_dim : int, optional
        Dimensionality of the node embeddings in the graph convolution layers (default: 3).

    seq_len : int, optional
        Length of the input sequence for temporal modeling (default: 1).

    lr : float, optional
        Learning rate for model optimization (default: 0.001).

    epoch : int, optional
        Number of training epochs (default: 10).

    seed : int, optional
        Random seed for reproducibility (default: 42).

    logs : bool, optional
        Whether to print the execution time (default: True).

    Returns
    -------
    numpy.ndarray
        The imputed matrix with missing values recovered.

    Example
    -------
        >>> recov_data = bit_graph(incomp_data, node_number=-1, kernel_set=[1], epoch=10)
        >>> print(recov_data)

    References
    ----------
    X. Chen, X. Li, T. Wu, B. Liu and Z. Li, BIASED TEMPORAL CONVOLUTION GRAPH NETWORK FOR TIME SERIES FORECASTING WITH MISSING VALUES
    https://github.com/chenxiaodanhit/BiTGraph
    """
    # A list literal in the signature would be one object shared by every call;
    # build the default here so downstream mutation cannot leak between calls.
    if kernel_set is None:
        kernel_set = [1]

    start_time = time.time()  # Record start time

    recov_data = recoveryBitGRAPH(
        input=incomp_data,
        node_number=node_number,
        kernel_set=kernel_set,
        dropout=dropout,
        subgraph_size=subgraph_size,
        node_dim=node_dim,
        seq_len=seq_len,
        lr=lr,
        epoch=epoch,
        seed=seed,
    )

    end_time = time.time()
    if logs:
        print(f"\n\t\t> logs, imputation bit graph - Execution Time: {(end_time - start_time):.4f} seconds\n")

    return recov_data
54 changes: 54 additions & 0 deletions build/lib/imputegap/algorithms/brits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import time
from imputegap.wrapper.AlgoPython.BRITS.runnerBRITS import brits_recovery


def brits(incomp_data, model="brits", epoch=10, batch_size=7, nbr_features=1, hidden_layers=64, seq_length=32, logs=True):
    """
    Perform imputation using the BRITS algorithm.

    This function is a timing wrapper: all model arguments are forwarded to
    ``brits_recovery`` and the elapsed wall-clock time is printed when
    ``logs`` is truthy.

    Parameters
    ----------
    incomp_data : numpy.ndarray
        The input matrix with contamination (missing values represented as NaNs).
    model : str
        Type of model to use for the imputation, e.g. 'brits', 'brits-i' or 'brits_i_univ'.
    epoch : int
        Number of epochs for training the model.
    batch_size : int
        Size of the batches used during training.
    nbr_features : int
        Number of features (dimension) in the time series.
    hidden_layers : int
        Number of units in the hidden layer of the model.
    seq_length : int
        Length of the input sequence used by the model (time steps processed at once).
    logs : bool, optional
        Whether to print the execution time (default: True).

    Returns
    -------
    numpy.ndarray
        The imputed matrix with missing values recovered.

    Notes
    -----
    The BRITS algorithm is a machine learning-based approach for time series imputation,
    where missing values are recovered using a recurrent neural network structure.

    Example
    -------
        >>> recov_data = brits(incomp_data=incomp_data, model="brits", epoch=10, batch_size=7, nbr_features=1, hidden_layers=64, seq_length=32, logs=True)
        >>> print(recov_data)

    References
    ----------
    Cao, W., Wang, D., Li, J., Zhou, H., Li, L. & Li, Y. BRITS: Bidirectional Recurrent Imputation for Time Series. Advances in Neural Information Processing Systems, 31 (2018). https://proceedings.neurips.cc/paper_files/paper/2018/file/734e6bfcd358e25ac1db0a4241b95651-Paper.pdf
    """
    # Gather the backend settings in one place before timing the call.
    settings = {
        "model": model,
        "epoch": epoch,
        "batch_size": batch_size,
        "nbr_features": nbr_features,
        "hidden_layers": hidden_layers,
        "seq_length": seq_length,
    }

    began = time.time()
    recov_data = brits_recovery(incomp_data=incomp_data, **settings)
    elapsed = time.time() - began

    if logs:
        print(f"\n\t\t> logs, imputation brits - Execution Time: {elapsed:.4f} seconds\n")

    return recov_data
53 changes: 6 additions & 47 deletions build/lib/imputegap/algorithms/cdrec.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,9 @@
import ctypes
import time
import ctypes as __native_c_types_import;
import numpy as __numpy_import;

from imputegap.tools import utils


def __marshal_as_numpy_column(__ctype_container, __py_sizen, __py_sizem):
    """
    Convert a flat ctypes container back into a 2D numpy array.

    The container is read as ``__py_sizem`` consecutive runs of ``__py_sizen``
    values (one run per column) and the result is transposed so that the
    returned array has shape ``(__py_sizen, __py_sizem)``.

    Parameters
    ----------
    __ctype_container : ctypes.Array
        The input ctypes container (flattened matrix).
    __py_sizen : int
        The number of rows in the returned numpy array.
    __py_sizem : int
        The number of columns in the returned numpy array.

    Returns
    -------
    numpy.ndarray
        A numpy array of shape ``(__py_sizen, __py_sizem)``.
    """
    # Copy the native buffer into numpy, one row per stored column.
    stacked_columns = __numpy_import.array(__ctype_container).reshape(__py_sizem, __py_sizen)

    # Transpose so rows and columns match the original matrix orientation.
    return stacked_columns.T


def __marshal_as_native_column(__py_matrix):
    """
    Flatten a numpy matrix into a ctypes container for native code.

    The matrix is transposed before flattening, so the values of each column
    end up stored consecutively in the resulting flat container.

    Parameters
    ----------
    __py_matrix : numpy.ndarray
        The input numpy matrix (2D array).

    Returns
    -------
    ctypes.Array
        A ctypes array containing the flattened, transposed matrix.
    """
    transposed = __py_matrix.T

    # flatten() yields a fresh contiguous copy, which as_ctypes then wraps.
    flat_values = __numpy_import.ndarray.flatten(transposed)

    return __numpy_import.ctypeslib.as_ctypes(flat_values)


def native_cdrec(__py_matrix, __py_rank, __py_epsilon, __py_iterations):
"""
Perform matrix imputation using the CDRec algorithm with native C++ support.
Expand Down Expand Up @@ -92,11 +47,11 @@ def native_cdrec(__py_matrix, __py_rank, __py_epsilon, __py_iterations):
__ctype_iterations = __native_c_types_import.c_ulonglong(__py_iterations);

# Native code uses linear matrix layout, and also it's easier to pass it in like this
__ctype_matrix = __marshal_as_native_column(__py_matrix);
__ctype_matrix = utils.__marshal_as_native_column(__py_matrix);

shared_lib.cdrec_imputation_parametrized(__ctype_matrix, __ctype_size_n, __ctype_size_m, __ctype_rank, __ctype_epsilon, __ctype_iterations);

__py_imputed_matrix = __marshal_as_numpy_column(__ctype_matrix, __py_n, __py_m);
__py_imputed_matrix = utils.__marshal_as_numpy_column(__ctype_matrix, __py_n, __py_m);

return __py_imputed_matrix;

Expand Down Expand Up @@ -131,6 +86,10 @@ def cdrec(incomp_data, truncation_rank, iterations, epsilon, logs=True, lib_path
>>> print(recov_data)

"""

print(f"\t\t\t\t(PYTHON) CDRec: ({incomp_data.shape[0]},{incomp_data.shape[1]}) for rank {truncation_rank}, "
f"epsilon {epsilon}, and iterations {iterations}...")

start_time = time.time() # Record start time

# Call the C++ function to perform recovery
Expand Down
49 changes: 2 additions & 47 deletions build/lib/imputegap/algorithms/cpp_integration.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,12 @@
import ctypes
import time
import ctypes as __native_c_types_import;
import numpy as __numpy_import;

from imputegap.tools import utils

# =========================================================== #
# IN CASE OF NEED, YOU CAN ADAPT AND TAKE CDREC.PY AS A MODEL #
# =========================================================== #

def __marshal_as_numpy_column(__ctype_container, __py_sizen, __py_sizem):
    """
    Convert a flat ctypes container back into a 2D numpy array.

    The container is read as ``__py_sizem`` consecutive runs of ``__py_sizen``
    values (one run per column) and the result is transposed so that the
    returned array has shape ``(__py_sizen, __py_sizem)``.

    Parameters
    ----------
    __ctype_container : ctypes.Array
        The input ctypes container (flattened matrix).
    __py_sizen : int
        The number of rows in the returned numpy array.
    __py_sizem : int
        The number of columns in the returned numpy array.

    Returns
    -------
    numpy.ndarray
        A numpy array of shape ``(__py_sizen, __py_sizem)``.
    """
    # Copy the native buffer into numpy, one row per stored column.
    stacked_columns = __numpy_import.array(__ctype_container).reshape(__py_sizem, __py_sizen)

    # Transpose so rows and columns match the original matrix orientation.
    return stacked_columns.T


def __marshal_as_native_column(__py_matrix):
    """
    Flatten a numpy matrix into a ctypes container for native code.

    The matrix is transposed before flattening, so the values of each column
    end up stored consecutively in the resulting flat container.

    Parameters
    ----------
    __py_matrix : numpy.ndarray
        The input numpy matrix (2D array).

    Returns
    -------
    ctypes.Array
        A ctypes array containing the flattened, transposed matrix.
    """
    transposed = __py_matrix.T

    # flatten() yields a fresh contiguous copy, which as_ctypes then wraps.
    flat_values = __numpy_import.ndarray.flatten(transposed)

    return __numpy_import.ctypeslib.as_ctypes(flat_values)


def native_algo(__py_matrix, __py_param):
"""
Perform matrix imputation using the CDRec algorithm with native C++ support.
Expand Down Expand Up @@ -85,13 +40,13 @@ def native_algo(__py_matrix, __py_param):
__py_param = __native_c_types_import.c_ulonglong(__py_param);

# Native code uses linear matrix layout, and also it's easier to pass it in like this
__ctype_matrix = __marshal_as_native_column(__py_matrix);
__ctype_matrix = utils.__marshal_as_native_column(__py_matrix);

# call your algorithm
shared_lib.your_algo_name(__ctype_matrix, __ctype_size_n, __ctype_size_m, __py_param);

# convert back to numpy
__py_imputed_matrix = __marshal_as_numpy_column(__ctype_matrix, __py_n, __py_m);
__py_imputed_matrix = utils.__marshal_as_numpy_column(__ctype_matrix, __py_n, __py_m);

return __py_imputed_matrix;

Expand Down
46 changes: 46 additions & 0 deletions build/lib/imputegap/algorithms/deep_mvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import time

from imputegap.wrapper.AlgoPython.DeepMVI.recoveryDeepMVI import deep_mvi_recovery


def deep_mvi(incomp_data, max_epoch=1000, patience=2, lr=0.001, logs=True):
    """
    Perform imputation using the DEEP MVI (Deep Multivariate Imputation) algorithm.

    This function is a timing wrapper: the training arguments are forwarded to
    ``deep_mvi_recovery`` and the elapsed wall-clock time is printed when
    ``logs`` is truthy.

    Parameters
    ----------
    incomp_data : numpy.ndarray
        The input matrix with contamination (missing values represented as NaNs).
    max_epoch : int, optional
        Upper limit on the number of training epochs (default is 1000).
    patience : int, optional
        Number of times the error threshold may be crossed during training (default is 2).
    lr : float, optional
        Learning rate of the training (default is 0.001).
    logs : bool, optional
        Whether to print the execution time (default is True).

    Returns
    -------
    numpy.ndarray
        The imputed matrix with missing values recovered.

    Example
    -------
        >>> recov_data = deep_mvi(incomp_data, 1000, 2, 0.001)
        >>> print(recov_data)

    References
    ----------
    P. Bansal, P. Deshpande, and S. Sarawagi. Missing value imputation on multidimensional time series. arXiv preprint arXiv:2103.01600, 2023
    https://github.com/pbansal5/DeepMVI
    """
    # Gather the training settings in one place before timing the call.
    training_args = {"max_epoch": max_epoch, "patience": patience, "lr": lr}

    began = time.time()
    recov_data = deep_mvi_recovery(input=incomp_data, **training_args)
    elapsed = time.time() - began

    if logs:
        print(f"\n\t\t> logs, imputation deep mvi - Execution Time: {elapsed:.4f} seconds\n")

    return recov_data
Loading