Skip to content

Commit 8eccee5

Browse files
authored
Feature/cold warm for vector models (#122)
- Added `dataset` argument to `recommend_for_cold` method - Supported warm and cold reco in `LightFMWrapperModel` - Supported warm reco in `DSSMModel` and in corresponding pytorch datasets - [Breaking] Replaced `include_warm` parameter in `Dataset.get_user_item_matrix` with the pair `include_warm_users` and `include_warm_items` - [Breaking] Added `n_factors` and `deterministic` params to `DSSMModel` - [Breaking] Added default value for train dataset type and params for user and item dataset types in `DSSMModel` - [Breaking] Renamed torch datasets, and renamed the `dataset_type` param of `DSSMModel` to `train_dataset_type`
1 parent 5c3476d commit 8eccee5

20 files changed

+495
-229
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ valid-metaclass-classmethod-first-arg=cls
441441
[DESIGN]
442442

443443
# Maximum number of arguments for function / method.
444-
max-args=15
444+
max-args=20
445445

446446
# Maximum number of attributes for a class (see R0902).
447447
max-attributes=12

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
### Added
1212
- Warm users/items support in `Dataset` ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77))
13-
- Warm and cold users/items support in `ModelBase` and non-personalized models ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77), [#120](https://github.com/MobileTeleSystems/RecTools/pull/120))
13+
- Warm and cold users/items support in `ModelBase` and all possible models ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77), [#120](https://github.com/MobileTeleSystems/RecTools/pull/120), [#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
1414
- Warm and cold users/items support in `cross_validate` ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77))
15+
- [Breaking] Default value for train dataset type and params for user and item dataset types in `DSSMModel` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
16+
- [Breaking] `n_factors` and `deterministic` params to `DSSMModel` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
1517

1618
### Changed
1719
- Changed the logic of choosing random sampler for `RandomModel` and increased the sampling speed ([#120](https://github.com/MobileTeleSystems/RecTools/pull/120))
18-
- Changed the logic of `RandomModel`: now the recommendations are different for repeated calls of recommend methods ([#120](https://github.com/MobileTeleSystems/RecTools/pull/120))
20+
- [Breaking] Changed the logic of `RandomModel`: now the recommendations are different for repeated calls of recommend methods ([#120](https://github.com/MobileTeleSystems/RecTools/pull/120))
21+
- Torch datasets to support warm recommendations ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
22+
- [Breaking] Replaced `include_warm` parameter in `Dataset.get_user_item_matrix` with the pair `include_warm_users` and `include_warm_items` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
23+
- [Breaking] Renamed torch datasets, and renamed the `dataset_type` param of `DSSMModel` to `train_dataset_type` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
1924

2025
### Removed
2126
- `return_external_ids` parameter in `recommend` and `recommend_to_items` model methods ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77))

rectools/dataset/dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,12 @@ def _make_features(
194194
except Exception as e: # pragma: no cover
195195
raise RuntimeError(f"An error has occurred while constructing {feature_type} features: {e!r}")
196196

197-
def get_user_item_matrix(self, include_weights: bool = True, include_warm: bool = False) -> sparse.csr_matrix:
197+
def get_user_item_matrix(
198+
self,
199+
include_weights: bool = True,
200+
include_warm_users: bool = False,
201+
include_warm_items: bool = False,
202+
) -> sparse.csr_matrix:
198203
"""
199204
Construct user-item CSR matrix based on `interactions` attribute.
200205
@@ -219,8 +224,9 @@ def get_user_item_matrix(self, include_weights: bool = True, include_warm: bool
219224
Resized user-item CSR matrix
220225
"""
221226
matrix = self.interactions.get_user_item_matrix(include_weights)
222-
if include_warm:
223-
matrix.resize(self.user_id_map.size, self.item_id_map.size)
227+
n_rows = self.user_id_map.size if include_warm_users else matrix.shape[0]
228+
n_columns = self.item_id_map.size if include_warm_items else matrix.shape[1]
229+
matrix.resize(n_rows, n_columns)
224230
return matrix
225231

226232
def get_raw_interactions(self, include_weight: bool = True, include_datetime: bool = True) -> pd.DataFrame:

rectools/dataset/torch_datasets.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,23 @@
2525

2626
from .dataset import Dataset
2727

28-
DD = tp.TypeVar("DD", bound="DSSMDataset")
29-
ID = tp.TypeVar("ID", bound="ItemFeaturesDataset")
30-
UD = tp.TypeVar("UD", bound="UserFeaturesDataset")
28+
DSSMTrainDatasetT = tp.TypeVar("DSSMTrainDatasetT", bound="DSSMTrainDatasetBase")
29+
DSSMItemDatasetT = tp.TypeVar("DSSMItemDatasetT", bound="DSSMItemDatasetBase")
30+
DSSMUserDatasetT = tp.TypeVar("DSSMUserDatasetT", bound="DSSMUserDatasetBase")
3131

3232

33-
class DSSMDataset(TorchDataset[tp.Any]):
33+
class DSSMTrainDatasetBase(TorchDataset[tp.Any]):
34+
"""Base class for DSSM training datasets. Used only for type hinting."""
35+
36+
def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
37+
raise NotImplementedError()
38+
39+
@classmethod
40+
def from_dataset(cls: tp.Type[DSSMTrainDatasetT], dataset: Dataset) -> DSSMTrainDatasetT:
41+
raise NotImplementedError()
42+
43+
44+
class DSSMTrainDataset(DSSMTrainDatasetBase):
3445
"""
3546
Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.
3647
Implements `torch.utils.data.Dataset` for subsequent usage with
@@ -68,7 +79,7 @@ def __init__(
6879
)
6980

7081
@classmethod
71-
def from_dataset(cls: tp.Type[DD], dataset: Dataset) -> DD:
82+
def from_dataset(cls: tp.Type[DSSMTrainDatasetT], dataset: Dataset) -> DSSMTrainDatasetT:
7283
ui_matrix = dataset.get_user_item_matrix()
7384

7485
# We take hot here since this dataset is used for fit only
@@ -99,7 +110,18 @@ def __getitem__(
99110
return user_features, interactions, pos, neg
100111

101112

102-
class ItemFeaturesDataset(TorchDataset[tp.Any]):
113+
class DSSMItemDatasetBase(TorchDataset[tp.Any]):
114+
"""Base class for DSSM item datasets. Used only for type hinting."""
115+
116+
def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
117+
raise NotImplementedError()
118+
119+
@classmethod
120+
def from_dataset(cls: tp.Type[DSSMItemDatasetT], dataset: Dataset) -> DSSMItemDatasetT:
121+
raise NotImplementedError()
122+
123+
124+
class DSSMItemDataset(DSSMItemDatasetBase):
103125
"""
104126
Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.
105127
Implements `torch.utils.data.Dataset` for subsequent usage with
@@ -113,7 +135,7 @@ def __init__(self, items: sparse.csr_matrix):
113135
self.items = items
114136

115137
@classmethod
116-
def from_dataset(cls: tp.Type[ID], dataset: Dataset) -> ID:
138+
def from_dataset(cls: tp.Type[DSSMItemDatasetT], dataset: Dataset) -> DSSMItemDatasetT:
117139
# We take all features here since this dataset is used for recommend only, not for fit
118140
if dataset.item_features is not None:
119141
return cls(dataset.item_features.get_sparse())
@@ -126,7 +148,22 @@ def __getitem__(self, idx: int) -> torch.FloatTensor:
126148
return torch.FloatTensor(self.items[idx].toarray().flatten())
127149

128150

129-
class UserFeaturesDataset(TorchDataset[tp.Any]):
151+
class DSSMUserDatasetBase(TorchDataset[tp.Any]):
152+
"""Base class for DSSM user datasets. Used only for type hinting."""
153+
154+
def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
155+
raise NotImplementedError()
156+
157+
@classmethod
158+
def from_dataset(
159+
cls: tp.Type[DSSMUserDatasetT],
160+
dataset: Dataset,
161+
keep_users: tp.Optional[tp.Sequence[int]] = None,
162+
) -> DSSMUserDatasetT:
163+
raise NotImplementedError()
164+
165+
166+
class DSSMUserDataset(DSSMUserDatasetBase):
130167
"""
131168
Torch dataset wrapper for `rectools.dataset.dataset.Dataset`.
132169
Implements `torch.utils.data.Dataset` for subsequent usage with
@@ -143,6 +180,8 @@ def __init__(
143180
interactions: sparse.csr_matrix,
144181
keep_users: tp.Optional[tp.Sequence[int]] = None,
145182
):
183+
if users.shape[0] != interactions.shape[0]:
184+
raise ValueError("Number of rows in user features matrix and in interactions matrix must be the same")
146185
if keep_users is not None:
147186
self.users = users[keep_users]
148187
self.interactions = interactions[keep_users]
@@ -152,15 +191,15 @@ def __init__(
152191

153192
@classmethod
154193
def from_dataset(
155-
cls: tp.Type[UD],
194+
cls: tp.Type[DSSMUserDatasetT],
156195
dataset: Dataset,
157196
keep_users: tp.Optional[tp.Sequence[int]] = None,
158-
) -> UD:
197+
) -> DSSMUserDatasetT:
159198
# We take all features here since this dataset is used for recommend only, not for fit
160199
if dataset.user_features is not None:
161200
return cls(
162201
dataset.user_features.get_sparse(),
163-
dataset.get_user_item_matrix(),
202+
dataset.get_user_item_matrix(include_warm_users=True),
164203
keep_users,
165204
)
166205
raise AttributeError("User features attribute of dataset could not be None")

rectools/models/base.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,11 @@ def recommend(
166166
reco_warm = self._recommend_u2i_warm(warm_user_ids, dataset, k, sorted_item_ids_to_recommend)
167167
else:
168168
# TODO: use correct types for numpy arrays and stop ignoring
169-
reco_warm = self._recommend_cold(warm_user_ids, k, sorted_item_ids_to_recommend) # type: ignore
169+
reco_warm = self._recommend_cold(
170+
warm_user_ids, dataset, k, sorted_item_ids_to_recommend
171+
) # type: ignore
170172
if cold_user_ids.size > 0:
171-
reco_cold = self._recommend_cold(cold_user_ids, k, sorted_item_ids_to_recommend)
173+
reco_cold = self._recommend_cold(cold_user_ids, dataset, k, sorted_item_ids_to_recommend)
172174

173175
reco_hot = self._adjust_reco_types(reco_hot)
174176
reco_warm = self._adjust_reco_types(reco_warm)
@@ -285,11 +287,11 @@ def recommend_to_items( # pylint: disable=too-many-branches
285287
else:
286288
# TODO: use correct types for numpy arrays and stop ignoring
287289
reco_warm = self._recommend_cold(
288-
warm_target_ids, requested_k, sorted_item_ids_to_recommend
290+
warm_target_ids, dataset, requested_k, sorted_item_ids_to_recommend
289291
) # type: ignore
290292
if cold_target_ids.size > 0:
291293
# We intentionally request `k` and not `requested_k` here since we're not going to filter cold reco later
292-
reco_cold = self._recommend_cold(cold_target_ids, k, sorted_item_ids_to_recommend)
294+
reco_cold = self._recommend_cold(cold_target_ids, dataset, k, sorted_item_ids_to_recommend)
293295

294296
reco_hot = self._adjust_reco_types(reco_hot)
295297
reco_warm = self._adjust_reco_types(reco_warm)
@@ -459,7 +461,11 @@ def _make_reco_table(cls, reco: RecoTriplet, target_col: str, add_rank_col: bool
459461
return df
460462

461463
def _recommend_cold(
462-
self, target_ids: AnyIdsArray, k: int, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]
464+
self,
465+
target_ids: AnyIdsArray,
466+
dataset: Dataset,
467+
k: int,
468+
sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
463469
) -> SemiInternalRecoTriplet:
464470
raise NotImplementedError()
465471

@@ -509,16 +515,20 @@ class FixedColdRecoModelMixin:
509515
"""
510516

511517
def _recommend_cold(
512-
self, target_ids: AnyIdsArray, k: int, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]
518+
self,
519+
target_ids: AnyIdsArray,
520+
dataset: Dataset,
521+
k: int,
522+
sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
513523
) -> SemiInternalRecoTriplet:
514-
item_ids, scores = self._get_cold_reco(k, sorted_item_ids_to_recommend)
524+
item_ids, scores = self._get_cold_reco(dataset, k, sorted_item_ids_to_recommend)
515525
reco_target_ids = np.repeat(target_ids, len(item_ids))
516526
reco_item_ids = np.tile(item_ids, len(target_ids))
517527
reco_scores = np.tile(scores, len(target_ids))
518528

519529
return reco_target_ids, reco_item_ids, reco_scores
520530

521531
def _get_cold_reco(
522-
self, k: int, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]
532+
self, dataset: Dataset, k: int, sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray]
523533
) -> tp.Tuple[InternalIds, Scores]:
524534
raise NotImplementedError()

0 commit comments

Comments
 (0)