
Commit ab469d6

[Fix] Saving all past behaviors during merging (#320)
* save consumed using pickle
* fix saving all past behaviors during merging
* ignore ruff `RUF012`
1 parent 1de7c29 commit ab469d6

File tree

7 files changed (+260, -59 lines)
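For context before the diffs: `user_consumed` and `item_consumed` are plain dicts mapping an inner user (item) id to the ordered list of item (user) ids it interacted with. A minimal sketch of how such dicts are built from index arrays, in the spirit of `interaction_consumed` (the function body and data here are illustrative, not the repo's implementation):

```python
import numpy as np

def interaction_consumed_sketch(user_indices, item_indices):
    # Map each user to the items it consumed, and each item to its users.
    user_consumed, item_consumed = {}, {}
    for u, i in zip(user_indices.tolist(), item_indices.tolist()):
        user_consumed.setdefault(u, []).append(i)
        item_consumed.setdefault(i, []).append(u)
    return user_consumed, item_consumed

users, items = np.array([0, 0, 1]), np.array([5, 3, 5])
assert interaction_consumed_sketch(users, items) == (
    {0: [5, 3], 1: [5]},
    {5: [0, 1], 3: [0]},
)
```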

libreco/data/consumed.py

Lines changed: 9 additions & 14 deletions

```diff
@@ -29,22 +29,17 @@ def _remove_duplicates(user_consumed, item_consumed):
     return user_dedup, item_dedup
 
 
-def update_consumed(new_data_info, data_info, merge_behavior):
+def update_consumed(
+    user_indices, item_indices, n_users, n_items, old_info, merge_behavior
+):
+    user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
     if merge_behavior:
-        new_data_info.user_consumed = _merge_dedup(
-            new_data_info.user_consumed, new_data_info.n_users, data_info.user_consumed
-        )
-        new_data_info.item_consumed = _merge_dedup(
-            new_data_info.item_consumed, new_data_info.n_items, data_info.item_consumed
-        )
+        user_consumed = _merge_dedup(user_consumed, n_users, old_info.user_consumed)
+        item_consumed = _merge_dedup(item_consumed, n_items, old_info.item_consumed)
     else:
-        new_data_info.user_consumed = _fill_empty(
-            new_data_info.user_consumed, new_data_info.n_users, data_info.user_consumed
-        )
-        new_data_info.item_consumed = _fill_empty(
-            new_data_info.item_consumed, new_data_info.n_items, data_info.item_consumed
-        )
-    return new_data_info
+        user_consumed = _fill_empty(user_consumed, n_users, old_info.user_consumed)
+        item_consumed = _fill_empty(item_consumed, n_items, old_info.item_consumed)
+    return user_consumed, item_consumed
 
 
 def _merge_dedup(new_consumed, num, old_consumed):
```
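Note the refactor: `update_consumed` no longer mutates a freshly built `DataInfo`; it takes raw index arrays plus the previous info object and returns the consumed dicts directly. The `_merge_dedup` and `_fill_empty` bodies are not shown in this diff; as a rough mental model only, they behave roughly like the following hypothetical re-implementations (not the repo's actual code):

```python
def merge_dedup_sketch(new_consumed, num, old_consumed):
    # merge_behavior=True: append the new interactions after the old ones,
    # dropping duplicates while keeping first-seen order.
    return {
        i: list(dict.fromkeys(list(old_consumed.get(i, [])) + list(new_consumed.get(i, []))))
        for i in range(num)
    }

def fill_empty_sketch(new_consumed, num, old_consumed):
    # merge_behavior=False: keep only the new interactions, falling back to
    # the old ones for ids with no interactions in the new data.
    return {i: list(new_consumed.get(i) or old_consumed.get(i, [])) for i in range(num)}
```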

libreco/data/data_info.py

Lines changed: 37 additions & 27 deletions

```diff
@@ -1,15 +1,15 @@
 """Classes for Storing Various Data Information."""
 import inspect
 import json
-import os
+import pickle
 from collections import namedtuple
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Dict, Iterable, List
 
 import numpy as np
 import pandas as pd
 
-from .consumed import interaction_consumed
 from ..feature.update import (
     get_row_id_masks,
     update_new_dense_feats,
@@ -69,10 +69,10 @@ class DataInfo:
         Unique sparse features for all items in train data.
     item_dense_unique : numpy.ndarray or None, default: None
         Unique dense features for all items in train data.
-    user_indices : numpy.ndarray or None, default: None
-        Mapped inner user indices from train data.
-    item_indices : numpy.ndarray or None, default: None
-        Mapped inner item indices from train data.
+    user_consumed : dict of {int : list} or None, default: None
+        All consumed items by each user.
+    item_consumed : dict of {int : list} or None, default: None
+        All consumed users by each item.
     user_unique_vals : numpy.ndarray or None, default: None
         All the unique users in train data.
     item_unique_vals : numpy.ndarray or None, default: None
@@ -110,8 +110,8 @@ def __init__(
         user_dense_unique=None,
         item_sparse_unique=None,
         item_dense_unique=None,
-        user_indices=None,
-        item_indices=None,
+        user_consumed=None,
+        item_consumed=None,
         user_unique_vals=None,
         item_unique_vals=None,
         sparse_unique_vals=None,
@@ -126,9 +126,8 @@ def __init__(
         self.user_dense_unique = user_dense_unique
         self.item_sparse_unique = item_sparse_unique
         self.item_dense_unique = item_dense_unique
-        self.user_consumed, self.item_consumed = interaction_consumed(
-            user_indices, item_indices
-        )
+        self.user_consumed = user_consumed
+        self.item_consumed = item_consumed
         self.user_unique_vals = user_unique_vals
         self.item_unique_vals = item_unique_vals
         self.sparse_unique_vals = sparse_unique_vals
@@ -440,27 +439,30 @@ def save(self, path, model_name):
         model_name : str
             Name of the saved file.
         """
-        if not os.path.isdir(path):
+        path = Path(path)
+        if not path.is_dir():
             print(f"file folder {path} doesn't exists, creating a new one...")
-            os.makedirs(path)
+            path.mkdir()
         if self.col_name_mapping is not None:
-            name_mapping_path = os.path.join(
-                path, f"{model_name}_data_info_name_mapping.json"
-            )
-            with open(name_mapping_path, "w") as f:
+            with open(path / f"{model_name}_data_info_name_mapping.json", "w") as f:
                 json.dump(
                     self.all_args["col_name_mapping"],
                     f,
                     separators=(",", ":"),
                     indent=4,
                 )
+        if self.user_consumed is not None:
+            with open(path / f"{model_name}_user_consumed.pkl", "wb") as f:
+                pickle.dump(self.user_consumed, f, protocol=pickle.HIGHEST_PROTOCOL)
+        if self.item_consumed is not None:
+            with open(path / f"{model_name}_item_consumed.pkl", "wb") as f:
+                pickle.dump(self.item_consumed, f, protocol=pickle.HIGHEST_PROTOCOL)
 
-        other_path = os.path.join(path, f"{model_name}_data_info")
         hparams = dict()
         arg_names = inspect.signature(self.__init__).parameters.keys()
         for arg in arg_names:
             if (
-                arg == "col_name_mapping"
+                arg in ("col_name_mapping", "user_consumed", "item_consumed")
                 or arg not in self.all_args
                 or self.all_args[arg] is None
             ):
@@ -478,7 +480,7 @@ def save(self, path, model_name):
             else:
                 hparams[arg] = self.all_args[arg]
 
-        np.savez_compressed(other_path, **hparams)
+        np.savez_compressed(path / f"{model_name}_data_info", **hparams)
 
     @classmethod
     def load(cls, path, model_name):
@@ -491,19 +493,26 @@ def load(cls, path, model_name):
         model_name : str
             Name of the saved file.
         """
-        if not os.path.exists(path):
+        path = Path(path)
+        if not path.exists():
             raise OSError(f"file folder {path} doesn't exists...")
 
         hparams = dict()
-        name_mapping_path = os.path.join(
-            path, f"{model_name}_data_info_name_mapping.json"
-        )
-        if os.path.exists(name_mapping_path):
+        name_mapping_path = path / f"{model_name}_data_info_name_mapping.json"
+        if name_mapping_path.exists():
             with open(name_mapping_path, "r") as f:
                 hparams["col_name_mapping"] = json.load(f)
 
-        other_path = os.path.join(path, f"{model_name}_data_info.npz")
-        info = np.load(other_path, allow_pickle=True)
+        user_consumed_path = path / f"{model_name}_user_consumed.pkl"
+        if user_consumed_path.exists():
+            with open(user_consumed_path, "rb") as f:
+                hparams["user_consumed"] = pickle.load(f)
+        item_consumed_path = path / f"{model_name}_item_consumed.pkl"
+        if item_consumed_path.exists():
+            with open(item_consumed_path, "rb") as f:
+                hparams["item_consumed"] = pickle.load(f)
+
+        info = np.load(path / f"{model_name}_data_info.npz", allow_pickle=True)
         info = dict(info.items())
         for arg in info:
             if arg == "interaction_data":
@@ -556,6 +565,7 @@ def store_old_info(data_info):
         # multi_sparse case, second to last cols are redundant.
        # Used in `rebuild_tf_model`, `rebuild_torch_model`
         sparse_len.append(-1)
+
     return OldInfo(
         data_info.n_users,
         data_info.n_items,
```
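The consumed dicts are now pickled to their own files rather than packed into the `np.savez_compressed` archive, so the merged history survives a save/load cycle instead of being recomputed from only the most recent indices. A standalone round trip mirroring the file naming above (path and data are made up):

```python
import pickle
from pathlib import Path

path, model_name = Path("model_dir"), "din_model"  # hypothetical save location
path.mkdir(exist_ok=True)

user_consumed = {0: [5, 3], 1: [5]}  # toy consumed dict
with open(path / f"{model_name}_user_consumed.pkl", "wb") as f:
    pickle.dump(user_consumed, f, protocol=pickle.HIGHEST_PROTOCOL)
with open(path / f"{model_name}_user_consumed.pkl", "rb") as f:
    assert pickle.load(f) == user_consumed  # the dict round-trips exactly
```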

libreco/data/dataset.py

Lines changed: 29 additions & 11 deletions

```diff
@@ -4,7 +4,7 @@
 
 import numpy as np
 
-from .consumed import update_consumed
+from .consumed import interaction_consumed, update_consumed
 from .data_info import DataInfo, store_old_info
 from .transformed import TransformedSet
 from ..feature.column_mapping import col_name2index
@@ -246,10 +246,11 @@ def build_trainset(cls, train_data, shuffle=False, seed=42):
             is_train=True,
             is_ordered=True,
         )
+        user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
         data_info = DataInfo(
             interaction_data=train_data[["user", "item", "label"]],
-            user_indices=user_indices,
-            item_indices=item_indices,
+            user_consumed=user_consumed,
+            item_consumed=item_consumed,
             user_unique_vals=cls.user_unique_vals,
             item_unique_vals=cls.item_unique_vals,
         )
@@ -303,14 +304,22 @@ def merge_trainset(
             is_train=True,
             is_ordered=False,
         )
+        user_consumed, item_consumed = update_consumed(
+            user_indices,
+            item_indices,
+            len(cls.user_unique_vals),
+            len(cls.item_unique_vals),
+            data_info,
+            merge_behavior,
+        )
+
         new_data_info = DataInfo(
             interaction_data=train_data[["user", "item", "label"]],
-            user_indices=user_indices,
-            item_indices=item_indices,
+            user_consumed=user_consumed,
+            item_consumed=item_consumed,
             user_unique_vals=cls.user_unique_vals,
             item_unique_vals=cls.item_unique_vals,
         )
-        new_data_info = update_consumed(new_data_info, data_info, merge_behavior)
         new_data_info.old_info = store_old_info(data_info)
         cls.train_called = True
         return merge_transformed, new_data_info
@@ -511,15 +520,16 @@ def build_trainset(
             col_name_mapping["multi_sparse"] = multi_sparse_col_map(multi_sparse_col)
 
         interaction_data = train_data[["user", "item", "label"]]
+        user_consumed, item_consumed = interaction_consumed(user_indices, item_indices)
         data_info = DataInfo(
             col_name_mapping,
             interaction_data,
             user_sparse_unique,
             user_dense_unique,
             item_sparse_unique,
             item_dense_unique,
-            user_indices,
-            item_indices,
+            user_consumed,
+            item_consumed,
             cls.user_unique_vals,
             cls.item_unique_vals,
             cls.sparse_unique_vals,
@@ -632,15 +642,24 @@ def merge_trainset(
         )
 
         interaction_data = train_data[["user", "item", "label"]]
+        user_consumed, item_consumed = update_consumed(
+            user_indices,
+            item_indices,
+            len(cls.user_unique_vals),
+            len(cls.item_unique_vals),
+            data_info,
+            merge_behavior,
+        )
+
         new_data_info = DataInfo(
             data_info.col_name_mapping,
             interaction_data,
             user_sparse_unique,
             user_dense_unique,
             item_sparse_unique,
             item_dense_unique,
-            user_indices,
-            item_indices,
+            user_consumed,
+            item_consumed,
             cls.user_unique_vals,
             cls.item_unique_vals,
             cls.sparse_unique_vals,
@@ -649,7 +668,6 @@
             cls.multi_sparse_unique_vals,
             multi_sparse_info,
         )
-        new_data_info = update_consumed(new_data_info, data_info, merge_behavior)
         new_data_info.old_info = store_old_info(data_info)
         cls.train_called = True
         return merge_transformed, new_data_info
```
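A tiny illustration of what the `merge_behavior` flag passed through `merge_trainset` -> `update_consumed` means for the result (toy data; the dict comprehensions stand in for `_merge_dedup` / `_fill_empty`):

```python
old = {0: [1, 2], 1: [3]}  # consumed dict carried over from the previous DataInfo
new = {0: [2, 4]}          # user 1 has no interactions in the new batch

# merge_behavior=True: keep all past behaviors, deduplicated
merged = {u: list(dict.fromkeys(old.get(u, []) + new.get(u, []))) for u in (0, 1)}
assert merged == {0: [1, 2, 4], 1: [3]}

# merge_behavior=False: keep only new behaviors, filling ids absent from the new data
filled = {u: new.get(u, old.get(u, [])) for u in (0, 1)}
assert filled == {0: [2, 4], 1: [3]}
```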

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -90,7 +90,7 @@ filterwarnings = [
 line-length = 88
 target-version = "py38"
 show-source = true
-ignore = ["E501"]
+ignore = ["E501", "RUF012"]
 select = [
     # pyflakes
     "F",
```

tests/retrain/test_tfmodel_retrain_feat.py

Lines changed: 58 additions & 2 deletions

```diff
@@ -95,8 +95,8 @@ def test_tfmodel_retrain_feat():
     tf.compat.v1.reset_default_graph()
     new_data_info = DataInfo.load(SAVE_PATH, model_name="din_model")
 
-    # use second half data as second training part
-    second_half_data = all_data[(len(all_data) // 2) :]
+    # use first half of second half data as second training part
+    second_half_data = all_data[(len(all_data) // 2) : (len(all_data) * 3 // 4)]
     train_data_orig, eval_data_orig = split_by_ratio_chrono(
         second_half_data, test_size=0.2
     )
@@ -158,4 +158,60 @@ def test_tfmodel_retrain_feat():
 
     assert new_eval_result["roc_auc"] != eval_result["roc_auc"]
 
+    new_data_info.save(path=SAVE_PATH, model_name="din_model")
+    new_model.save(
+        path=SAVE_PATH, model_name="din_model", manual=True, inference_only=False
+    )
+
+    # ========================== load and retrain 2 =============================
+    tf.compat.v1.reset_default_graph()
+    new_data_info = DataInfo.load(SAVE_PATH, model_name="din_model")
+
+    # use second half of second half data as third training part
+    third_half_data = all_data[(len(all_data) * 3 // 4) :]
+    train_data_orig, eval_data_orig = split_by_ratio_chrono(
+        third_half_data, test_size=0.2
+    )
+    train_data, new_data_info = DatasetFeat.merge_trainset(
+        train_data_orig, new_data_info, merge_behavior=True
+    )
+    eval_data = DatasetFeat.merge_evalset(eval_data_orig, new_data_info)
+    print(new_data_info)
+
+    new_model = DIN(
+        "ranking",
+        new_data_info,
+        loss_type="focal",  # change loss
+        embed_size=16,
+        n_epochs=1,
+        lr=1e-4,
+        lr_decay=False,
+        reg=None,
+        batch_size=2048,
+        hidden_units=(32, 16),
+        recent_num=10,
+        use_tf_attention=True,
+    )
+    new_model.rebuild_model(path=SAVE_PATH, model_name="din_model", full_assign=True)
+    new_model.fit(
+        train_data,
+        neg_sampling=True,
+        verbose=2,
+        shuffle=True,
+        eval_data=eval_data,
+        metrics=[
+            "loss",
+            "balanced_accuracy",
+            "roc_auc",
+            "pr_auc",
+            "precision",
+            "recall",
+            "map",
+            "ndcg",
+        ],
+        eval_user_num=20,
+    )
+    ptest_preds(new_model, "ranking", second_half_data, with_feats=False)
+    ptest_recommends(new_model, new_data_info, second_half_data, with_feats=False)
+
     remove_path(SAVE_PATH)
```
