import torch
from torch.autograd import Function


class VectorQuantization(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        with torch.no_grad():
            embedding_size = codebook.size(1)
            inputs_size = inputs.size()
            inputs_flatten = inputs.view(-1, embedding_size)

            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)

            # Compute the distances to the codebook
            distances = torch.addmm(codebook_sqr + inputs_sqr,
                inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
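            # `addmm` evaluates beta * (codebook_sqr + inputs_sqr)
            # + alpha * (inputs_flatten @ codebook.t()), so each entry is
            #   distances[i, j] = ||inputs_flatten[i]||^2 + ||codebook[j]||^2
            #                     - 2 * <inputs_flatten[i], codebook[j]>
            #                   = ||inputs_flatten[i] - codebook[j]||^2,
            # i.e. the squared Euclidean distance from every input vector to
            # every codebook entry, computed with a single matrix multiply.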

            _, indices_flatten = torch.min(distances, dim=1)
            indices = indices_flatten.view(*inputs_size[:-1])
            ctx.mark_non_differentiable(indices)

            return indices

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError('Trying to call `.grad()` on graph containing '
            '`VectorQuantization`. The function `VectorQuantization` '
            'is not differentiable. Use `VectorQuantizationStraightThrough` '
            'if you want a straight-through estimator of the gradient.')

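
# Module-level alias used by `VectorQuantizationStraightThrough.forward` below;
# custom autograd Functions are invoked through their `.apply` method.
vq = VectorQuantization.apply
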

class VectorQuantizationStraightThrough(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook)
        indices_flatten = indices.view(-1)
        ctx.save_for_backward(indices_flatten, codebook)

        codes_flatten = torch.index_select(codebook, dim=0,
            index=indices_flatten)
        codes = codes_flatten.view_as(inputs)

        return codes

    @staticmethod
    def backward(ctx, grad_output):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator: pass the gradient to the encoder
            # output unchanged, as if quantization were the identity
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook: accumulate the gradient of every
            # position into the codebook entry it was quantized to
            indices, codebook = ctx.saved_tensors
            embedding_size = codebook.size(1)

            # `grad_output` may be non-contiguous, so make it contiguous
            # before reshaping it to (num_positions, embedding_size)
            grad_output_flatten = (grad_output.contiguous()
                                              .view(-1, embedding_size))
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output_flatten)

        return (grad_inputs, grad_codebook)
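

vq_st = VectorQuantizationStraightThrough.apply


# Minimal usage sketch (illustrative only; the shapes, sizes and variable names
# below are arbitrary assumptions): quantize a batch of encoder outputs against
# a codebook and check that gradients reach both the inputs and the codebook
# through the straight-through estimator.
if __name__ == '__main__':
    inputs = torch.randn(2, 4, 4, 8, requires_grad=True)  # (batch, H, W, D)
    codebook = torch.randn(16, 8, requires_grad=True)     # K=16 codes of size D=8

    codes = vq_st(inputs, codebook)  # nearest codebook vectors, same shape as inputs
    codes.sum().backward()

    print(inputs.grad.shape)    # torch.Size([2, 4, 4, 8])
    print(codebook.grad.shape)  # torch.Size([16, 8])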