Skip to content

Commit e7766ef

Browse files
authored
[Fix] Fix mask order in sequence (#310)
1 parent 33736a9 commit e7766ef

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

libreco/batch/sequence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_interacted_seq(
5757
# first item has no historical interaction, fill in with pad_index
5858
batch_interacted_len.append(1.0)
5959
elif position < num:
60-
batch_interacted[j, -position:] = consumed_items[:position]
60+
batch_interacted[j, :position] = consumed_items[:position]
6161
batch_interacted_len.append(float(position))
6262
else:
6363
if mode == "recent":
@@ -79,7 +79,7 @@ def get_recent_seqs(n_users, user_consumed, pad_index, max_seq_len, dtype):
7979
u_consumed_items = user_consumed[u]
8080
u_items_len = len(u_consumed_items)
8181
if u_items_len < max_seq_len:
82-
recent_seqs[u, -u_items_len:] = u_consumed_items
82+
recent_seqs[u, :u_items_len] = u_consumed_items
8383
recent_seq_lens.append(float(u_items_len))
8484
else:
8585
recent_seqs[u] = u_consumed_items[-max_seq_len:]

libreco/recommendation/preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def process_embed_seq(model, user_id, seq, inner_id):
3838
def build_rec_seq(seq, model, inner_id, repeat=False):
3939
seq, seq_len = _extract_seq(seq, model, inner_id)
4040
recent_seq = np.full((1, model.max_seq_len), model.n_items, dtype=np.int32)
41-
recent_seq[0, -seq_len:] = seq[-seq_len:]
41+
recent_seq[0, :seq_len] = seq[-seq_len:]
4242
seq_len = np.array([seq_len], dtype=np.float32)
4343
if repeat:
4444
recent_seq = np.repeat(recent_seq, model.n_items, axis=0)

libreco/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class SageModels(StrEnum):
8484
@unique
8585
class UserEmbedModels(StrEnum):
8686
"""Models can only generate user embeddings dynamically."""
87+
8788
YOUTUBERETRIEVAL = "YouTubeRetrieval"
8889
RNN4REC = "RNN4Rec"
8990
CASER = "Caser"

0 commit comments

Comments
 (0)