Skip to content

Commit 34fd9b5

Browse files
committed
Fix tests
1 parent 7ee699c commit 34fd9b5

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def forward(ctx, inputs, codebook):
3535
indices = vq(inputs, codebook)
3636
indices_flatten = indices.view(-1)
3737
ctx.save_for_backward(indices_flatten, codebook)
38+
ctx.mark_non_differentiable(indices_flatten)
3839

3940
codes_flatten = torch.index_select(codebook, dim=0,
4041
index=indices_flatten)

test_functions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,20 @@ def test_vq():
2929
def test_vq_st_shape():
3030
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
3131
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
32-
codes = vq_st(inputs, codebook)
32+
codes, indices = vq_st(inputs, codebook)
3333

3434
assert codes.size() == (2, 3, 5, 7)
3535
assert codes.requires_grad
3636
assert codes.dtype == torch.float32
3737

38+
assert indices.size() == (2 * 3 * 5,)
39+
assert not indices.requires_grad
40+
assert indices.dtype == torch.int64
41+
3842
def test_vq_st_gradient1():
3943
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
4044
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
41-
codes = vq_st(inputs, codebook)
45+
codes, _ = vq_st(inputs, codebook)
4246

4347
grad_output = torch.rand((2, 3, 5, 7))
4448
grad_inputs, = torch.autograd.grad(codes, inputs,
@@ -51,7 +55,7 @@ def test_vq_st_gradient1():
5155
def test_vq_st_gradient2():
5256
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
5357
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
54-
codes = vq_st(inputs, codebook)
58+
codes, _ = vq_st(inputs, codebook)
5559

5660
indices = vq(inputs, codebook)
5761
codes_torch = torch.embedding(codebook, indices, padding_idx=-1,

0 commit comments

Comments
 (0)