diff --git a/fastFM/base.py b/fastFM/base.py index ae7604c..6173c4e 100644 --- a/fastFM/base.py +++ b/fastFM/base.py @@ -100,6 +100,27 @@ def predict(self, X_test): assert sp.isspmatrix_csc(X_test) assert X_test.shape[1] == len(self.w_) return ffm.ffm_predict(self.w0_, self.w_, self.V_, X_test) + + def fit_predict(self, X_train, y_train, X_test, n_more_iter=0): + """Return predictions after calling fit method + + Parameters + ---------- + X_train : scipy.sparse.csc_matrix, (n_samples, n_features) + + y_train : array, shape (n_samples) + + X_test : scipy.sparse.csc_matrix, (n_test_samples, n_features) + + n_more_iter : int + Number of iterations to continue from the current Coefficients. + + Returns + ------- + T : array, shape (n_test_samples) + """ + self.fit(X_train, y_train, n_more_iter=n_more_iter) + return self.predict(X_test) class BaseFMClassifier(FactorizationMachine, ClassifierMixin):