@@ -27,17 +27,28 @@ def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int
27
27
:param location: location on module
28
28
:return: size of matrix
29
29
"""
30
- assert isinstance (module , torch .nn .Linear )
31
- if location in ("input" , TransformLocation .WEIGHT_INPUT ):
32
- return module .in_features
33
- else :
34
- return module .out_features
30
+ if isinstance (module , torch .nn .Linear ):
31
+ if location in ("input" , TransformLocation .WEIGHT_INPUT ):
32
+ return module .in_features
33
+ else :
34
+ return module .out_features
35
+ elif isinstance (module , torch .nn .Embedding ):
36
+ if location in ("input" , TransformLocation .WEIGHT_INPUT ):
37
+ return module .num_embeddings
38
+ else :
39
+ return module .embedding_dim
40
+
41
+ raise ValueError (
42
+ f"Unsupported module type { type (module )} , "
43
+ "should be either Linear or Embedding."
44
+ )
35
45
36
46
37
47
def apply_transform_weight (
38
- weight : torch .Tensor ,
48
+ transform_weight : torch .Tensor ,
39
49
value : torch .Tensor ,
40
50
location : TransformLocation ,
51
+ is_linear : bool = True ,
41
52
) -> torch .Tensor :
42
53
"""
43
54
Using the transform location, determine how to apply the transform weight to the
@@ -69,23 +80,36 @@ def apply_transform_weight(
69
80
= y U
70
81
= yh
71
82
72
- :param weight: transform weight to apply
73
- :param value: value to apply weight to
74
- :param location: determines how weight should be applied
75
- :return: value after transform weight has been applied
83
+ :param transform_weight: transform weight to apply
84
+ :param value: value to apply transform_weight to
85
+ :param location: determines how transform_weight should be applied
86
+ :param is_linear: if value belongs to the weights of a Linear module
87
+ This is needed because torch uses convention:
88
+ Linear(in_features,out_features) has weight shape (out_features, in_features)
89
+ But other modules (e.g. torch.nn.Embedding) don't:
90
+ Embedding(num_embeddings, embedding_dim) has weight shape
91
+ (num_embeddings, embedding_dim)
92
+ :return: value after transform_weight has been applied
76
93
"""
77
94
78
95
if location == TransformLocation .INPUT :
79
- return value @ weight
96
+ return value @ transform_weight
80
97
81
98
elif location == TransformLocation .WEIGHT_INPUT :
82
- return value @ weight .T
99
+ if is_linear :
100
+ return value @ transform_weight .T
101
+ else :
102
+ # TODO is this ever needed?
103
+ raise NotImplementedError ()
83
104
84
105
elif location == TransformLocation .WEIGHT_OUTPUT :
85
- return weight .T @ value
106
+ if is_linear :
107
+ return transform_weight .T @ value
108
+ else :
109
+ return value @ transform_weight
86
110
87
111
elif location == TransformLocation .OUTPUT :
88
- return value @ weight
112
+ return value @ transform_weight
89
113
90
114
else :
91
115
raise NotImplementedError (f"{ location } has not been implemented yet" )
0 commit comments