@@ -4,7 +4,13 @@
 import pandas as pd
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+from replay.utils import (
+    PYSPARK_AVAILABLE,
+    DataFrameLike,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
 
 from .base_splitter import Splitter
 
@@ -118,14 +124,12 @@ def __init__(
         session_id_processing_strategy: str = "test",
     ):
         """
-        :param N: Array of interactions/timedelta to split.
+        :param N: Number of last interactions or size of the time window in seconds
         :param divide_column: Name of column for dividing
             in dataframe, default: ``query_id``.
-        :param time_column_format: Format of time_column,
-            needs for convert time_column into unix_timestamp type.
-            If strategy is set to 'interactions', then you can omit this parameter.
-            If time_column has already transformed into unix_timestamp type,
-            then you can omit this parameter.
+        :param time_column_format: Format of the timestamp column,
+            used for converting string dates to a numerical timestamp when strategy is 'timedelta'.
+            If the column is already a datetime object or a numerical timestamp, this parameter is ignored.
             default: ``yyyy-MM-dd HH:mm:ss``
         :param strategy: Defines the type of data splitting.
             Must be ``interactions`` or ``timedelta``.
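
For context, here is a minimal usage sketch of the parameters documented above. It assumes this file implements replay.splitters.LastNSplitter and that the default column names ("query_id", "item_id", "timestamp") apply; the example is illustrative, not taken from the diff.

import pandas as pd

from replay.splitters import LastNSplitter  # assumed import path

interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 1, 2, 2],
        "item_id": [10, 11, 12, 10, 13],
        "timestamp": [1, 2, 3, 1, 2],
    }
)

# Hold out the last 2 interactions of each query as the test part.
splitter = LastNSplitter(N=2, divide_column="query_id", strategy="interactions")
train, test = splitter.split(interactions)
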
@@ -223,7 +227,8 @@ def _to_unix_timestamp_spark(self, interactions: SparkDataFrame) -> SparkDataFrame:
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
         if time_column_type == "date":
             interactions = interactions.withColumn(
-                self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.time_column_format)
+                self.timestamp_column,
+                sf.unix_timestamp(self.timestamp_column, self.time_column_format),
             )
 
         return interactions
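
For readers unfamiliar with the call being reformatted here, this is a standalone sketch of the same conversion in plain PySpark (assuming pyspark is installed and a local session can be started); the column name and pattern mirror the docstring defaults.

import pyspark.sql.functions as sf
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [("2020-01-01 00:00:00",), ("2020-01-02 12:30:00",)],
    ["timestamp"],
)
# Parse the string column with the given pattern and replace it with
# seconds since the Unix epoch.
df = df.withColumn("timestamp", sf.unix_timestamp("timestamp", "yyyy-MM-dd HH:mm:ss"))
df.show()
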
@@ -260,7 +265,8 @@ def _partial_split_interactions_spark(
         self, interactions: SparkDataFrame, n: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
-            "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "count",
+            sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         # float(n) - because DataFrame.filter is changing order
         # of sorted DataFrame to descending
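
The hunk only shows the per-query count column being added. The sketch below illustrates one way such a count can drive a "last n interactions" split; it is an assumption about the approach, not the library's exact filtering code, and the column names are illustrative.

import pyspark.sql.functions as sf
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(1, 10, 1), (1, 11, 2), (1, 12, 3), (2, 20, 1), (2, 21, 2)],
    ["query_id", "item_id", "timestamp"],
)
n = 1

# Per-query interaction count (as in the hunk above) plus a per-query rank by time.
by_query = Window.partitionBy("query_id")
df = df.withColumn("count", sf.count("timestamp").over(by_query))
df = df.withColumn("row_num", sf.row_number().over(by_query.orderBy("timestamp")))

# The last n rows of each query form the test part, everything earlier the train part.
train = df.filter(sf.col("row_num") <= sf.col("count") - n).drop("count", "row_num")
test = df.filter(sf.col("row_num") > sf.col("count") - n).drop("count", "row_num")
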
@@ -317,7 +323,8 @@ def _partial_split_timedelta_spark(
         self, interactions: SparkDataFrame, timedelta: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         inter_with_max_time = interactions.withColumn(
-            "max_timestamp", sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "max_timestamp",
+            sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         inter_with_diff = inter_with_max_time.withColumn(
             "diff_timestamp", sf.col("max_timestamp") - sf.col(self.timestamp_column)
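
Likewise, this hunk shows only the max-timestamp and diff columns being built. The sketch below illustrates how such a diff can carve out a trailing time window as the test part; the boundary condition and column names are assumptions, not the library's code.

import pyspark.sql.functions as sf
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(1, 10, 100), (1, 11, 160), (1, 12, 200), (2, 20, 50), (2, 21, 300)],
    ["query_id", "item_id", "timestamp"],
)
timedelta = 60  # window size in seconds, counted back from each query's latest event

df = df.withColumn("max_timestamp", sf.max("timestamp").over(Window.partitionBy("query_id")))
df = df.withColumn("diff_timestamp", sf.col("max_timestamp") - sf.col("timestamp"))

# Events within the last `timedelta` seconds of a query's history become the test part.
test = df.filter(sf.col("diff_timestamp") < timedelta).drop("max_timestamp", "diff_timestamp")
train = df.filter(sf.col("diff_timestamp") >= timedelta).drop("max_timestamp", "diff_timestamp")
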