Skip to content

Commit 7ee699c

Browse files
committed
Return indices in vq_st
1 parent d4fbc92 commit 7ee699c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def forward(ctx, inputs, codebook):
4040
index=indices_flatten)
4141
codes = codes_flatten.view_as(inputs)
4242

43-
return codes
43+
return (codes, indices_flatten)
4444

4545
@staticmethod
46-
def backward(ctx, grad_output):
46+
def backward(ctx, grad_output, grad_indices):
4747
grad_inputs, grad_codebook = None, None
4848

4949
if ctx.needs_input_grad[0]:

modules.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,15 @@ def forward(self, z_e_x):
8080

8181
def straight_through(self, z_e_x):
8282
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
83-
z_q_x_ = vq_st(z_e_x_, self.embedding.weight)
83+
z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight)
8484
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()
85-
return z_q_x
85+
86+
z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
87+
dim=0, index=indices)
88+
z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
89+
z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()
90+
91+
return z_q_x, z_q_x_bar
8692

8793

8894
class ResBlock(nn.Module):
@@ -140,8 +146,8 @@ def decode(self, latents):
140146

141147
def forward(self, x):
142148
z_e_x = self.encoder(x)
143-
z_q_x = self.codebook.straight_through(z_e_x)
144-
x_tilde = self.decoder(z_q_x)
149+
z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
150+
x_tilde = self.decoder(z_q_x_st)
145151
return x_tilde, z_e_x, z_q_x
146152

147153

0 commit comments

Comments
 (0)