Skip to content

Commit f7aee48

Browse files
authored
Feature/ease (#107)
Added [`EASE`](https://arxiv.org/pdf/1905.03375.pdf) model (#91)
1 parent 2676798 commit f7aee48

File tree

12 files changed

+747
-302
lines changed

12 files changed

+747
-302
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Methods for conversion `Interactions` to raw form and for getting raw interactions from `Dataset` ([#69](https://github.com/MobileTeleSystems/RecTools/pull/69))
1414
- `AvgRecPopularity (Average Recommendation Popularity)` to `metrics` ([#81](https://github.com/MobileTeleSystems/RecTools/pull/81))
1515
- Added `normalized` parameter to `AvgRecPopularity` metric ([#89](https://github.com/MobileTeleSystems/RecTools/pull/89))
16+
- Added `EASE` model ([#107](https://github.com/MobileTeleSystems/RecTools/pull/107))
1617

1718
### Changed
1819
- Loosened `pandas`, `torch` and `torch-light` versions for `python >= 3.8` ([#58](https://github.com/MobileTeleSystems/RecTools/pull/58))

rectools/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Models
2727
------
2828
`models.DSSMModel`
29+
`models.EASEModel`
2930
`models.ImplicitALSWrapperModel`
3031
`models.ImplicitItemKNNWrapperModel`
3132
`models.LightFMWrapperModel`
@@ -35,6 +36,7 @@
3536
`models.RandomModel`
3637
"""
3738

39+
from .ease import EASEModel
3840
from .implicit_als import ImplicitALSWrapperModel
3941
from .implicit_knn import ImplicitItemKNNWrapperModel
4042
from .popular import PopularModel
@@ -54,6 +56,7 @@
5456

5557

5658
__all__ = (
59+
"EASEModel",
5760
"ImplicitALSWrapperModel",
5861
"ImplicitItemKNNWrapperModel",
5962
"LightFMWrapperModel",

rectools/models/dssm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
from ..dataset.dataset import Dataset
3939
from ..dataset.torch_datasets import ItemFeaturesDataset, UserFeaturesDataset
4040
from ..exceptions import NotFittedError
41-
from .vector import Distance, Factors, VectorModel
41+
from .rank import Distance
42+
from .vector import Factors, VectorModel
4243

4344

4445
class ItemNet(nn.Module):

rectools/models/ease.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2024 MTS (Mobile Telesystems)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""EASE model."""
16+
17+
import typing as tp
18+
19+
import numpy as np
20+
from scipy import sparse
21+
22+
from rectools import InternalIds
23+
from rectools.dataset import Dataset
24+
25+
from .base import ModelBase, Scores
26+
from .rank import Distance, ImplicitRanker
27+
28+
29+
class EASEModel(ModelBase):
30+
"""
31+
Embarrassingly Shallow Autoencoders for Sparse Data model.
32+
33+
See https://arxiv.org/abs/1905.03375.
34+
35+
Please note that this algorithm requires a lot of RAM during `fit` method.
36+
Out-of-memory issues are possible for big datasets.
37+
Reasonable catalog size for local development is about 30k items.
38+
Reasonable amount of interactions is about 20m.
39+
40+
Parameters
41+
----------
42+
regularization : float
43+
The regularization factor of the weights.
44+
verbose : int, default 0
45+
Degree of verbose output. If 0, no output will be provided.
46+
num_threads: int, default 1
47+
Number of threads used for `recommend` method.
48+
"""
49+
50+
u2i_dist = Distance.DOT
51+
52+
def __init__(
53+
self,
54+
regularization: float = 500.0,
55+
num_threads: int = 1,
56+
verbose: int = 0,
57+
):
58+
super().__init__(verbose=verbose)
59+
self.weight: np.ndarray
60+
self.regularization = regularization
61+
self.num_threads = num_threads
62+
63+
def _fit(self, dataset: Dataset) -> None: # type: ignore
64+
ui_csr = dataset.get_user_item_matrix(include_weights=True)
65+
66+
gram_matrix = ui_csr.T @ ui_csr
67+
gram_matrix += self.regularization * sparse.identity(gram_matrix.shape[0]).astype(np.float32)
68+
gram_matrix = gram_matrix.todense()
69+
70+
gram_matrix_inv = np.linalg.inv(gram_matrix)
71+
72+
self.weight = np.array(gram_matrix_inv / (-np.diag(gram_matrix_inv)))
73+
np.fill_diagonal(self.weight, 0.0)
74+
75+
def _recommend_u2i(
76+
self,
77+
user_ids: np.ndarray,
78+
dataset: Dataset,
79+
k: int,
80+
filter_viewed: bool,
81+
sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
82+
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
83+
user_items = dataset.get_user_item_matrix(include_weights=True)
84+
85+
ranker = ImplicitRanker(
86+
distance=self.u2i_dist,
87+
subjects_factors=user_items,
88+
objects_factors=self.weight,
89+
)
90+
ui_csr_for_filter = user_items[user_ids] if filter_viewed else None
91+
92+
all_user_ids, all_reco_ids, all_scores = ranker.rank(
93+
subject_ids=user_ids,
94+
k=k,
95+
filter_pairs_csr=ui_csr_for_filter,
96+
sorted_object_whitelist=sorted_item_ids_to_recommend,
97+
num_threads=self.num_threads,
98+
)
99+
100+
return all_user_ids, all_reco_ids, all_scores
101+
102+
def _recommend_i2i(
103+
self,
104+
target_ids: np.ndarray,
105+
dataset: Dataset,
106+
k: int,
107+
sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
108+
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
109+
similarity = self.weight[target_ids]
110+
if sorted_item_ids_to_recommend is not None:
111+
similarity = similarity[:, sorted_item_ids_to_recommend]
112+
113+
n_reco = min(k, similarity.shape[1])
114+
unsorted_reco_positions = similarity.argpartition(-n_reco, axis=1)[:, -n_reco:]
115+
unsorted_reco_scores = np.take_along_axis(similarity, unsorted_reco_positions, axis=1)
116+
117+
sorted_reco_positions = unsorted_reco_scores.argsort()[:, ::-1]
118+
119+
all_reco_scores = np.take_along_axis(unsorted_reco_scores, sorted_reco_positions, axis=1)
120+
all_reco_ids = np.take_along_axis(unsorted_reco_positions, sorted_reco_positions, axis=1)
121+
122+
all_target_ids = np.repeat(target_ids, n_reco)
123+
124+
if sorted_item_ids_to_recommend is not None:
125+
all_reco_ids = sorted_item_ids_to_recommend[all_reco_ids]
126+
127+
return all_target_ids, all_reco_ids.flatten(), all_reco_scores.flatten()

rectools/models/implicit_als.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from rectools.dataset import Dataset, Features
2727
from rectools.exceptions import NotFittedError
2828

29-
from .vector import Distance, Factors, VectorModel
29+
from .rank import Distance
30+
from .vector import Factors, VectorModel
3031

3132
AVAILABLE_RECOMMEND_METHODS = ("loop",)
3233
AnyAlternatingLeastSquares = tp.Union[CPUAlternatingLeastSquares, GPUAlternatingLeastSquares]

rectools/models/lightfm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from rectools.dataset import Dataset, Features
2323
from rectools.exceptions import NotFittedError
2424

25-
from .vector import Distance, Factors, VectorModel
25+
from .rank import Distance
26+
from .vector import Factors, VectorModel
2627

2728

2829
class LightFMWrapperModel(VectorModel):

rectools/models/pure_svd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from rectools.dataset import Dataset
2323
from rectools.exceptions import NotFittedError
24-
from rectools.models.vector import Distance, Factors, VectorModel
24+
from rectools.models.rank import Distance
25+
from rectools.models.vector import Factors, VectorModel
2526

2627

2728
class PureSVDModel(VectorModel):

0 commit comments

Comments
 (0)