Hi, there seems to be a bug in the calculation of `final_window_start`:
udify/udify/modules/bert_pretrained.py, lines 488 to 509 in cbabef6:
```python
# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
stride_offset = stride // 2 + self.start_tokens

first_window = list(range(stride_offset))

max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]

final_window_start = full_seq_len - (full_seq_len % self.max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))

select_indices = first_window + max_context_windows + final_window
```
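Stated generally, this is a sketch that assumes, as in the comment's example, that the flattened sequence is a back-to-back concatenation of windows of length $M$. Write $L$ = `full_seq_len`, $M$ = `max_pieces`, $s$ = `stride`, $o$ = `stride_offset`, and $L = qM + r$ with $0 \le r < M$. Then

```math
\begin{aligned}
\mathrm{final\_window\_start} &= L - (L \bmod M) + o + s = qM + o + s, \\
r = 0 \implies \mathrm{final\_window\_start} &= L + o + s \ge L,
\end{aligned}
```

so the final window is empty whenever $L$ is an exact multiple of $M$. Meanwhile the max-context indices of the last window, which starts at $(q-1)M$, stop at $(q-1)M + o + s - 1$. As a result, positions $(q-1)M + o + s$ through $L - 1$ are never selected.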
For the test case in your comment, `final_window_start` ends up greater than `full_seq_len`:

```python
# Values matching the example in the code comment
full_seq_len = 16
max_pieces = 8
start_tokens = 1
end_tokens = 1

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (max_pieces - start_tokens - end_tokens) // 2  # 3
stride_offset = stride // 2 + start_tokens  # 2
first_window = list(range(stride_offset))
max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % max_pieces < stride_offset + stride]
# 16 - (16 % 8) + 2 + 3 = 21, which is past the end of the sequence, so final_window is empty
final_window_start = full_seq_len - (full_seq_len % max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))
select_indices = first_window + max_context_windows + final_window
print(select_indices)
```

The output is `[0, 1, 2, 3, 4, 10, 11, 12]`, so the final window `[14, 15]` is missing.
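One possible fix, sketched here as a hypothetical standalone function `select_indices` (the name and signature are illustrative, not the repository's code): when `full_seq_len` is an exact multiple of `max_pieces`, look back a whole window instead of zero positions. Lengths that end in a partial window produce the same indices as before.

```python
from typing import List


def select_indices(full_seq_len: int, max_pieces: int,
                   start_tokens: int = 1, end_tokens: int = 1) -> List[int]:
    """Hypothetical standalone version of the index selection with a patched final window."""
    stride = (max_pieces - start_tokens - end_tokens) // 2
    stride_offset = stride // 2 + start_tokens

    first_window = list(range(stride_offset))
    max_context_windows = [i for i in range(full_seq_len)
                           if stride_offset - 1 < i % max_pieces < stride_offset + stride]

    # Start of the last window: the leftover partial window if there is one,
    # otherwise a full max_pieces window (the case the original formula gets wrong).
    lookback = full_seq_len % max_pieces or max_pieces
    final_window_start = full_seq_len - lookback + stride_offset + stride
    final_window = list(range(final_window_start, full_seq_len))
    return first_window + max_context_windows + final_window


print(select_indices(16, 8))  # prints [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15]
```

With this change the final window for the test case is `[13, 14, 15]` rather than the `[14, 15]` given in the code comment. Index 13 is the "to" in "store to eat" and lies outside the max-context range (`13 % 8 == 5`), so it seems the comment's expected value is itself off by one; including 13 makes the selected indices spell out the full sentence plus `[CLS]` and `[SEP]`.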