Sliding window bug #20

@hankcs


Hi, there seems to be a bug in the calculation of final_window_start:

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.
# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
stride_offset = stride // 2 + self.start_tokens
first_window = list(range(stride_offset))
max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]
final_window_start = full_seq_len - (full_seq_len % self.max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))
select_indices = first_window + max_context_windows + final_window

On the test case from your comment, final_window_start is greater than full_seq_len:

full_seq_len = 16
max_pieces = 8
start_tokens = 1
end_tokens = 1

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (max_pieces - start_tokens - end_tokens) // 2
stride_offset = stride // 2 + start_tokens

first_window = list(range(stride_offset))

max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % max_pieces < stride_offset + stride]

final_window_start = full_seq_len - (full_seq_len % max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))

select_indices = first_window + max_context_windows + final_window
print(select_indices)

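The output is [0, 1, 2, 3, 4, 10, 11, 12]; the final window [14, 15] is missing. Because full_seq_len % max_pieces is 0 here, final_window_start evaluates to 16 - 0 + 2 + 3 = 21, so range(21, 16) is empty.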
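One possible fix, as a rough sketch rather than a confirmed patch: when the sequence fills its last window exactly, treat the remainder as a full window instead of 0. The function name select_window_indices and the lookback variable are hypothetical, introduced only for this illustration. Under that assumption the final window starts at index 13, so it also picks up index 13, which the comment's expected [14, 15] does not list.

```python
def select_window_indices(full_seq_len, max_pieces, start_tokens=1, end_tokens=1):
    """Hypothetical helper mirroring the snippet above, with an adjusted final window."""
    # Stride is half the window, ignoring the special start and end tokens.
    stride = (max_pieces - start_tokens - end_tokens) // 2
    # Offset of the centermost embeddings inside each window.
    stride_offset = stride // 2 + start_tokens

    first_window = list(range(stride_offset))
    max_context_windows = [i for i in range(full_seq_len)
                           if stride_offset - 1 < i % max_pieces < stride_offset + stride]

    # Assumption: if the last window is completely full, the remainder is 0,
    # so look back a whole window instead of zero pieces.
    lookback = full_seq_len % max_pieces or max_pieces
    final_window_start = full_seq_len - lookback + stride_offset + stride
    final_window = list(range(final_window_start, full_seq_len))

    return first_window + max_context_windows + final_window


print(select_window_indices(16, 8))
# prints [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15]
```

When full_seq_len is not a multiple of max_pieces, lookback equals the original remainder, so the sketch behaves the same as the current code in that case.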
