Commit bd8a6bb

Andrey-Matyashov authored and monkey0head committed
fixes #82
1 parent ddd2e2f commit bd8a6bb

File tree

3 files changed: +31 -18 lines changed


replay/models/nn/sequential/sasrec/model.py

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ def _layers_stacker(self, num_blocks: int, layer_class: Any, *args, **kwargs) ->
 
 class SasRecNormalizer(torch.nn.Module):
     """
-    SasRec notmilization layers
+    SasRec normalization layers
 
     Link: https://arxiv.org/pdf/1808.09781.pdf
     """

replay/splitters/cold_user_random_splitter.py

Lines changed: 13 additions & 7 deletions
@@ -38,12 +38,16 @@ def __init__(
         item_column: Optional[str] = "item_id",
     ):
         """
-        :param test_size: fraction of users to be in test
-        :param drop_cold_items: flag to drop cold items from test
-        :param drop_cold_users: flag to drop cold users from test
-        :param seed: random seed
-        :param query_column: query id column name
-        :param item_column: item id column name
+        :param test_size: The proportion of users to allocate to the test set.
+            Must be a float between 0.0 and 1.0.
+        :param drop_cold_items: Drop items from test DataFrame
+            which are not in train DataFrame, default: False.
+        :param seed: Seed for the random number generator to ensure
+            reproducibility of the split, default: None.
+        :param query_column: Name of query interaction column.
+            default: ``query_id``.
+        :param item_column: Name of item interaction column.
+            default: ``item_id``.
         """
         super().__init__(
             drop_cold_items=drop_cold_items,
@@ -81,7 +85,9 @@ def _core_split_spark(
             seed=self.seed,
         )
         interactions = interactions.join(
-            train_users.withColumn("is_test", sf.lit(False)), on=self.query_column, how="left"
+            train_users.withColumn("is_test", sf.lit(False)),
+            on=self.query_column,
+            how="left",
         ).na.fill({"is_test": True})
 
         train = interactions.filter(~sf.col("is_test")).drop("is_test")
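
For context, a minimal usage sketch of the splitter whose docstring is updated above. It assumes ColdUserRandomSplitter is importable from replay.splitters and that, like other RePlay splitters, it exposes a split() method returning a (train, test) pair; only parameters documented in the new docstring are used.

import pandas as pd

from replay.splitters import ColdUserRandomSplitter  # assumed import path

# Toy interaction log using the default column names from the docstring.
interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 2, 2, 3, 3],
        "item_id": [10, 11, 10, 12, 11, 13],
    }
)

splitter = ColdUserRandomSplitter(
    test_size=0.5,          # proportion of users assigned to the test set (0.0-1.0)
    drop_cold_items=False,  # keep test items that never appear in train, default: False
    seed=42,                # fix the random user assignment for reproducibility
    query_column="query_id",
    item_column="item_id",
)
train, test = splitter.split(interactions)  # assumed base Splitter API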

replay/splitters/last_n_splitter.py

Lines changed: 17 additions & 10 deletions
@@ -4,7 +4,13 @@
 import pandas as pd
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+from replay.utils import (
+    PYSPARK_AVAILABLE,
+    DataFrameLike,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
 
 from .base_splitter import Splitter
 
@@ -118,14 +124,12 @@ def __init__(
         session_id_processing_strategy: str = "test",
     ):
         """
-        :param N: Array of interactions/timedelta to split.
+        :param N: Number of last interactions or size of the time window in seconds
         :param divide_column: Name of column for dividing
             in dataframe, default: ``query_id``.
-        :param time_column_format: Format of time_column,
-            needs for convert time_column into unix_timestamp type.
-            If strategy is set to 'interactions', then you can omit this parameter.
-            If time_column has already transformed into unix_timestamp type,
-            then you can omit this parameter.
+        :param time_column_format: Format of the timestamp column,
+            used for converting string dates to a numerical timestamp when strategy is 'timedelta'.
+            If the column is already a datetime object or a numerical timestamp, this parameter is ignored.
             default: ``yyyy-MM-dd HH:mm:ss``
         :param strategy: Defines the type of data splitting.
             Must be ``interactions`` or ``timedelta``.
@@ -223,7 +227,8 @@ def _to_unix_timestamp_spark(self, interactions: SparkDataFrame) -> SparkDataFra
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
         if time_column_type == "date":
             interactions = interactions.withColumn(
-                self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.time_column_format)
+                self.timestamp_column,
+                sf.unix_timestamp(self.timestamp_column, self.time_column_format),
             )
 
         return interactions
@@ -260,7 +265,8 @@ def _partial_split_interactions_spark(
         self, interactions: SparkDataFrame, n: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
-            "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "count",
+            sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         # float(n) - because DataFrame.filter is changing order
         # of sorted DataFrame to descending
@@ -317,7 +323,8 @@ def _partial_split_timedelta_spark(
         self, interactions: SparkDataFrame, timedelta: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         inter_with_max_time = interactions.withColumn(
-            "max_timestamp", sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "max_timestamp",
+            sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         inter_with_diff = inter_with_max_time.withColumn(
             "diff_timestamp", sf.col("max_timestamp") - sf.col(self.timestamp_column)
