# Copyright 2024 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""EASE model."""

import typing as tp

import numpy as np
from scipy import sparse

from rectools import InternalIds
from rectools.dataset import Dataset

from .base import ModelBase, Scores
from .rank import Distance, ImplicitRanker


class EASEModel(ModelBase):
    """
    Embarrassingly Shallow Autoencoders for Sparse Data model.

    See https://arxiv.org/abs/1905.03375.

    Please note that this algorithm requires a lot of RAM during the `fit` method.
    Out-of-memory issues are possible for big datasets.
    A reasonable catalog size for local development is about 30k items.
    A reasonable number of interactions is about 20m.

    Parameters
    ----------
    regularization : float, default 500.0
        The regularization factor of the weights.
    verbose : int, default 0
        Degree of verbose output. If 0, no output will be provided.
    num_threads : int, default 1
        Number of threads used for the `recommend` method.
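
    Examples
    --------
    A minimal usage sketch. The tiny interactions dataframe below is purely
    illustrative; any dataframe with the standard rectools columns works:

    >>> import pandas as pd
    >>> from rectools import Columns
    >>> from rectools.dataset import Dataset
    >>> interactions = pd.DataFrame(
    ...     {
    ...         Columns.User: [10, 10, 20],
    ...         Columns.Item: [1, 2, 1],
    ...         Columns.Weight: [1.0, 1.0, 1.0],
    ...         Columns.Datetime: ["2024-01-01", "2024-01-02", "2024-01-03"],
    ...     }
    ... )
    >>> dataset = Dataset.construct(interactions)
    >>> model = EASEModel(regularization=100.0).fit(dataset)
    >>> reco = model.recommend(users=[10], dataset=dataset, k=2, filter_viewed=False)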
    """

    u2i_dist = Distance.DOT

    def __init__(
        self,
        regularization: float = 500.0,
        num_threads: int = 1,
        verbose: int = 0,
    ):
        super().__init__(verbose=verbose)
        self.weight: np.ndarray
        self.regularization = regularization
        self.num_threads = num_threads

    def _fit(self, dataset: Dataset) -> None:  # type: ignore
        ui_csr = dataset.get_user_item_matrix(include_weights=True)

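        # Closed-form EASE solution from the paper: with Gram matrix
        # G = X^T X + regularization * I and P = G^{-1}, the item-item
        # weights are B = -P / diag(P), with the diagonal zeroed out
        # so that items do not recommend themselves.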
        gram_matrix = ui_csr.T @ ui_csr
        gram_matrix += self.regularization * sparse.identity(gram_matrix.shape[0]).astype(np.float32)
        gram_matrix = gram_matrix.todense()

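        # Inverting the densified n_items x n_items Gram matrix is the memory
        # bottleneck mentioned in the class docstring.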
        gram_matrix_inv = np.linalg.inv(gram_matrix)

        self.weight = np.array(gram_matrix_inv / (-np.diag(gram_matrix_inv)))
        np.fill_diagonal(self.weight, 0.0)

    def _recommend_u2i(
        self,
        user_ids: np.ndarray,
        dataset: Dataset,
        k: int,
        filter_viewed: bool,
        sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
    ) -> tp.Tuple[InternalIds, InternalIds, Scores]:
        user_items = dataset.get_user_item_matrix(include_weights=True)
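
        # Score each user's interaction row against the learned item-item
        # weights with a dot-product ranker.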
        ranker = ImplicitRanker(
            distance=self.u2i_dist,
            subjects_factors=user_items,
            objects_factors=self.weight,
        )
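
        # When `filter_viewed` is set, pass each user's own interactions
        # so that already seen items are excluded from the output.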
        ui_csr_for_filter = user_items[user_ids] if filter_viewed else None

        all_user_ids, all_reco_ids, all_scores = ranker.rank(
            subject_ids=user_ids,
            k=k,
            filter_pairs_csr=ui_csr_for_filter,
            sorted_object_whitelist=sorted_item_ids_to_recommend,
            num_threads=self.num_threads,
        )

        return all_user_ids, all_reco_ids, all_scores

    def _recommend_i2i(
        self,
        target_ids: np.ndarray,
        dataset: Dataset,
        k: int,
        sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
    ) -> tp.Tuple[InternalIds, InternalIds, Scores]:
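        # Rows of the learned weight matrix serve directly as item-to-item
        # similarities; optionally restrict columns to the whitelist.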
        similarity = self.weight[target_ids]
        if sorted_item_ids_to_recommend is not None:
            similarity = similarity[:, sorted_item_ids_to_recommend]
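
        # Two-stage top-k: `argpartition` selects the n_reco best scores in each
        # row in (average) linear time, then only those candidates are fully sorted.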
        n_reco = min(k, similarity.shape[1])
        unsorted_reco_positions = similarity.argpartition(-n_reco, axis=1)[:, -n_reco:]
        unsorted_reco_scores = np.take_along_axis(similarity, unsorted_reco_positions, axis=1)

        sorted_reco_positions = unsorted_reco_scores.argsort()[:, ::-1]

        all_reco_scores = np.take_along_axis(unsorted_reco_scores, sorted_reco_positions, axis=1)
        all_reco_ids = np.take_along_axis(unsorted_reco_positions, sorted_reco_positions, axis=1)

        all_target_ids = np.repeat(target_ids, n_reco)
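
        # Recommendation positions were taken inside the whitelist slice,
        # so map them back to the original internal item ids.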
        if sorted_item_ids_to_recommend is not None:
            all_reco_ids = sorted_item_ids_to_recommend[all_reco_ids]

        return all_target_ids, all_reco_ids.flatten(), all_reco_scores.flatten()