Skip to content

Commit 859d429

Browse files
committed
Add tests for vq functions
1 parent 4c87b15 commit 859d429

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ nosetests.xml
4545
coverage.xml
4646
*.cover
4747
.hypothesis/
48+
.pytest_cache/
4849

4950
# Translations
5051
*.mo

functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,8 @@ def backward(ctx, grad_output):
5858
grad_codebook = torch.zeros_like(codebook)
5959
grad_codebook.index_add_(0, indices, grad_output_flatten)
6060

61-
return (grad_inputs, grad_codebook)
61+
return (grad_inputs, grad_codebook)
62+
63+
vq = VectorQuantization.apply
64+
vq_st = VectorQuantizationStraightThrough.apply
65+
__all__ = [vq, vq_st]

test_functions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
3+
import numpy as np
4+
import torch
5+
6+
from functions import vq, vq_st
7+
8+
def test_vq_shape():
9+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
10+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
11+
indices = vq(inputs, codebook)
12+
13+
assert indices.size() == (2, 3, 5)
14+
assert not indices.requires_grad
15+
assert indices.dtype == torch.int64
16+
17+
def test_vq():
18+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
19+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
20+
indices = vq(inputs, codebook)
21+
22+
differences = inputs.unsqueeze(3) - codebook
23+
distances = torch.norm(differences, p=2, dim=4)
24+
25+
_, indices_torch = torch.min(distances, dim=3)
26+
27+
assert np.allclose(indices.numpy(), indices_torch.numpy())
28+
29+
def test_vq_st_shape():
30+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
31+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
32+
codes = vq_st(inputs, codebook)
33+
34+
assert codes.size() == (2, 3, 5, 7)
35+
assert codes.requires_grad
36+
assert codes.dtype == torch.float32
37+
38+
def test_vq_st_gradient1():
39+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
40+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
41+
codes = vq_st(inputs, codebook)
42+
43+
grad_output = torch.rand((2, 3, 5, 7))
44+
grad_inputs, = torch.autograd.grad(codes, inputs,
45+
grad_outputs=[grad_output])
46+
47+
# Straight-through estimator
48+
assert grad_inputs.size() == (2, 3, 5, 7)
49+
assert np.allclose(grad_output.numpy(), grad_inputs.numpy())
50+
51+
def test_vq_st_gradient2():
52+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
53+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
54+
codes = vq_st(inputs, codebook)
55+
56+
indices = vq(inputs, codebook)
57+
codes_torch = torch.embedding(codebook, indices, padding_idx=-1,
58+
scale_grad_by_freq=False, sparse=False)
59+
60+
grad_output = torch.rand((2, 3, 5, 7), dtype=torch.float32)
61+
grad_codebook, = torch.autograd.grad(codes, codebook,
62+
grad_outputs=[grad_output])
63+
grad_codebook_torch, = torch.autograd.grad(codes_torch, codebook,
64+
grad_outputs=[grad_output])
65+
66+
# Gradient is the same as torch.embedding function
67+
assert grad_codebook.size() == (11, 7)
68+
assert np.allclose(grad_codebook.numpy(), grad_codebook_torch.numpy())

0 commit comments

Comments
 (0)