
Commit 1e9844b

Added HitRate metric (#124)
1 parent 8eccee5 commit 1e9844b

File tree

5 files changed: +54 -2 lines changed

CHANGELOG.md
rectools/metrics/__init__.py
rectools/metrics/classification.py
tests/metrics/test_classification.py
tests/metrics/test_scoring.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Warm and cold users/items support in `cross_validate` ([#77](https://github.com/MobileTeleSystems/RecTools/pull/77))
 - [Breaking] Default value for train dataset type and params for user and item dataset types in `DSSMModel` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
 - [Breaking] `n_factors` and `deterministic` params to `DSSMModel` ([#122](https://github.com/MobileTeleSystems/RecTools/pull/122))
+- Hit Rate metric ([#124](https://github.com/MobileTeleSystems/RecTools/pull/124))

 ### Changed
 - Changed the logic of choosing random sampler for `RandomModel` and increased the sampling speed ([#120](https://github.com/MobileTeleSystems/RecTools/pull/120))

rectools/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -33,6 +33,7 @@
 `metrics.IntraListDiversity`
 `metrics.AvgRecPopularity`
 `metrics.Serendipity`
+`metrics.HitRate`

 Tools
 -----
@@ -42,7 +43,7 @@
 `metrics.SparsePairwiseHammingDistanceCalculator`
 """

-from .classification import MCC, Accuracy, F1Beta, Precision, Recall
+from .classification import MCC, Accuracy, F1Beta, HitRate, Precision, Recall
 from .distances import (
     PairwiseDistanceCalculator,
     PairwiseHammingDistanceCalculator,
@@ -61,6 +62,7 @@
     "F1Beta",
     "Accuracy",
     "MCC",
+    "HitRate",
     "MAP",
     "NDCG",
     "MRR",

rectools/metrics/classification.py

Lines changed: 20 additions & 0 deletions
@@ -371,6 +371,26 @@ def _calc_per_user_from_confusion_df(self, confusion_df: pd.DataFrame, catalog:
         return mcc


+@attr.s
+class HitRate(SimpleClassificationMetric):
+    """
+    HitRate calculates the fraction of users for which the correct answer is included in the recommendation list.
+
+    HitRate equals ``1 if tp > 0, otherwise 0``, where
+    - ``tp`` is the number of relevant recommendations
+      among the first ``k`` items in the recommendation list.
+
+    Parameters
+    ----------
+    k : int
+        Number of items at the top of the recommendations list that will be used to calculate the metric.
+    """
+
+    def _calc_per_user_from_confusion_df(self, confusion_df: pd.DataFrame) -> pd.Series:
+        hit_rate = (confusion_df[TP] > 0).astype(float)
+        return hit_rate
+
+
 def calc_classification_metrics(
     metrics: tp.Dict[str, tp.Union[ClassificationMetric, SimpleClassificationMetric]],
     merged: pd.DataFrame,
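
For context, a minimal usage sketch of the new metric, not part of the commit. Only `HitRate(k=...)`, `calc_per_user` and `calc` are confirmed by this diff and the tests below; the reco/interactions DataFrame layout with `Columns.User`, `Columns.Item` and `Columns.Rank` is assumed from the rest of the library:

# Hypothetical sketch: HitRate@2 on a toy recommendation table.
# Assumed: reco carries Columns.User / Columns.Item / Columns.Rank,
# interactions carries Columns.User / Columns.Item (ground truth).
import pandas as pd

from rectools import Columns
from rectools.metrics import HitRate

reco = pd.DataFrame(
    {
        Columns.User: [1, 1, 2, 2],
        Columns.Item: [10, 11, 20, 21],
        Columns.Rank: [1, 2, 1, 2],
    }
)
interactions = pd.DataFrame(
    {
        Columns.User: [1, 2],
        Columns.Item: [11, 99],  # user 1 has a relevant item in its top-2, user 2 does not
    }
)

metric = HitRate(k=2)
per_user = metric.calc_per_user(reco, interactions)  # expected: user 1 -> 1.0, user 2 -> 0.0
total = metric.calc(reco, interactions)              # expected: mean over users = 0.5

Per the docstring above, a user counts as a hit as soon as at least one relevant item appears in its top-``k`` recommendations, so the aggregate value is simply the share of such users.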

tests/metrics/test_classification.py

Lines changed: 27 additions & 1 deletion
@@ -19,7 +19,7 @@
 import pytest

 from rectools import Columns
-from rectools.metrics import MCC, Accuracy, F1Beta, Precision, Recall
+from rectools.metrics import MCC, Accuracy, F1Beta, HitRate, Precision, Recall
 from rectools.metrics.base import MetricAtK
 from rectools.metrics.classification import ClassificationMetric, calc_classification_metrics

@@ -155,3 +155,29 @@ def test_when_no_interactions(self) -> None:
             expected_metric_per_user,
         )
         assert np.isnan(self.metric.calc(RECO, EMPTY_INTERACTIONS, CATALOG))
+
+
+class TestHitRate:
+    def setup(self) -> None:
+        self.metric = HitRate(k=2)
+
+    def test_calc(self) -> None:
+
+        # tp = pd.Series([1, 1, 0, 0])
+        # tn = pd.Series([6, 8, 7, 7])
+        # fp = pd.Series([1, 1, 2, 2])
+        # fn = pd.Series([2, 0, 1, 1])
+
+        expected_metric_per_user = pd.Series(
+            [1, 1, 0, 0], index=pd.Series([1, 3, 4, 5], name=Columns.User, dtype=int), dtype=float
+        )
+        pd.testing.assert_series_equal(self.metric.calc_per_user(RECO, INTERACTIONS), expected_metric_per_user)
+        assert self.metric.calc(RECO, INTERACTIONS) == expected_metric_per_user.mean()
+
+    def test_when_no_interactions(self) -> None:
+        expected_metric_per_user = pd.Series(index=pd.Series(name=Columns.User, dtype=int), dtype=np.float64)
+        pd.testing.assert_series_equal(
+            self.metric.calc_per_user(RECO, EMPTY_INTERACTIONS),
+            expected_metric_per_user,
+        )
+        assert np.isnan(self.metric.calc(RECO, EMPTY_INTERACTIONS))

tests/metrics/test_scoring.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@
     NDCG,
     Accuracy,
     AvgRecPopularity,
+    HitRate,
     IntraListDiversity,
     MeanInvUserFreq,
     PairwiseHammingDistanceCalculator,
@@ -72,6 +73,7 @@ def test_success(self) -> None:
             "prec@2": Precision(k=2),
             "recall@1": Recall(k=1),
             "accuracy@1": Accuracy(k=1),
+            "hitrate@1": HitRate(k=1),
             "map@1": MAP(k=1),
             "map@2": MAP(k=2),
             "ndcg@1": NDCG(k=1, log_base=3),
@@ -89,6 +91,7 @@ def test_success(self) -> None:
             "prec@2": 0.375,
             "recall@1": 0.125,
             "accuracy@1": 0.825,
+            "hitrate@1": 0.25,
             "map@1": 0.125,
             "map@2": 0.375,
             "ndcg@1": 0.25,
