@@ -29,16 +29,20 @@ def test_vq():
29
29
def test_vq_st_shape ():
30
30
inputs = torch .rand ((2 , 3 , 5 , 7 ), dtype = torch .float32 , requires_grad = True )
31
31
codebook = torch .rand ((11 , 7 ), dtype = torch .float32 , requires_grad = True )
32
- codes = vq_st (inputs , codebook )
32
+ codes , indices = vq_st (inputs , codebook )
33
33
34
34
assert codes .size () == (2 , 3 , 5 , 7 )
35
35
assert codes .requires_grad
36
36
assert codes .dtype == torch .float32
37
37
38
+ assert indices .size () == (2 * 3 * 5 ,)
39
+ assert not indices .requires_grad
40
+ assert indices .dtype == torch .int64
41
+
38
42
def test_vq_st_gradient1 ():
39
43
inputs = torch .rand ((2 , 3 , 5 , 7 ), dtype = torch .float32 , requires_grad = True )
40
44
codebook = torch .rand ((11 , 7 ), dtype = torch .float32 , requires_grad = True )
41
- codes = vq_st (inputs , codebook )
45
+ codes , _ = vq_st (inputs , codebook )
42
46
43
47
grad_output = torch .rand ((2 , 3 , 5 , 7 ))
44
48
grad_inputs , = torch .autograd .grad (codes , inputs ,
@@ -51,7 +55,7 @@ def test_vq_st_gradient1():
51
55
def test_vq_st_gradient2 ():
52
56
inputs = torch .rand ((2 , 3 , 5 , 7 ), dtype = torch .float32 , requires_grad = True )
53
57
codebook = torch .rand ((11 , 7 ), dtype = torch .float32 , requires_grad = True )
54
- codes = vq_st (inputs , codebook )
58
+ codes , _ = vq_st (inputs , codebook )
55
59
56
60
indices = vq (inputs , codebook )
57
61
codes_torch = torch .embedding (codebook , indices , padding_idx = - 1 ,
0 commit comments