Skip to content

Commit eabfad1

Browse files
authored
Fix Implicit ALS matrix zero assignment size on GPU (#228)
Fixed ValueError for `ImplicitALSWrapperModel` with GPU
1 parent 75a30c1 commit eabfad1

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- Support `fit_partial()` for LightFM ([#223](https://github.com/MobileTeleSystems/RecTools/pull/223))
2323
- LightFM Python 3.12+ support ([#224](https://github.com/MobileTeleSystems/RecTools/pull/224))
2424

25+
### Fixed
26+
- Fix Implicit ALS matrix zero assignment size ([#228](https://github.com/MobileTeleSystems/RecTools/pull/228))
27+
2528
### Removed
2629
- Python 3.8 support ([#222](https://github.com/MobileTeleSystems/RecTools/pull/222))
2730

rectools/models/implicit_als.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,8 @@ def _fit_combined_factors_on_gpu_inplace(
610610
model._item_norms_host = model._user_norms_host = None # pylint: disable=protected-access
611611
model._YtY = model._XtX = None # pylint: disable=protected-access
612612

613-
_YtY = implicit.gpu.Matrix.zeros(model.factors, model.factors)
614-
_XtX = implicit.gpu.Matrix.zeros(model.factors, model.factors)
613+
_YtY = implicit.gpu.Matrix.zeros(*item_factors.shape)
614+
_XtX = implicit.gpu.Matrix.zeros(*user_factors.shape)
615615

616616
for _ in tqdm(range(iterations), disable=verbose == 0):
617617

0 commit comments

Comments
 (0)