Skip to content

Commit d4fbc92

Browse files
committed
Fix contiguous in VQEmbedding straight-through
1 parent fa4504c commit d4fbc92

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def backward(ctx, grad_output):
5454
indices, codebook = ctx.saved_tensors
5555
embedding_size = codebook.size(1)
5656

57-
grad_output_flatten = grad_output.view(-1, embedding_size)
57+
grad_output_flatten = (grad_output.contiguous()
58+
.view(-1, embedding_size))
5859
grad_codebook = torch.zeros_like(codebook)
5960
grad_codebook.index_add_(0, indices, grad_output_flatten)
6061

modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def forward(self, z_e_x):
8181
def straight_through(self, z_e_x):
8282
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
8383
z_q_x_ = vq_st(z_e_x_, self.embedding.weight)
84-
z_q_x = z_q_x_.permute(0, 3, 1, 2)
84+
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()
8585
return z_q_x
8686

8787

0 commit comments

Comments
 (0)