We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cb4702a commit 8d123c0Copy full SHA for 8d123c0
modules.py
@@ -80,7 +80,7 @@ def forward(self, z_e_x):
80
81
def straight_through(self, z_e_x):
82
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
83
- z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight)
+ z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
84
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()
85
86
z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
0 commit comments