diff --git a/cdt/causality/pairwise/ANM.py b/cdt/causality/pairwise/ANM.py index 4b1903b..5b8e768 100644 --- a/cdt/causality/pairwise/ANM.py +++ b/cdt/causality/pairwise/ANM.py @@ -186,7 +186,7 @@ def anm_score(self, x, y): float: ANM fit score """ gp = GaussianProcessRegressor().fit(x, y) - y_predict = gp.predict(x) + y_predict = gp.predict(x).reshape(-1, 1) indepscore = normalized_hsic(y_predict - y, x) return indepscore