Commit 9c9f4aa

utils to apply transforms to torch.nn.Embedding modules
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 85f40b5 commit 9c9f4aa

4 files changed: 53 additions & 21 deletions

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 2 deletions
@@ -115,8 +115,8 @@ def input_hook(_, args):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
+            assert isinstance(module, (torch.nn.Linear, torch.nn.Embedding))
+            assert not hasattr(module, "bias") or module.bias is None

             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
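For illustration (not part of this commit): the relaxed bias assertion above matters because torch.nn.Embedding registers no bias attribute at all, so the old `module.bias` access would raise AttributeError. A minimal standalone sketch of the new guard:

```python
import torch

linear = torch.nn.Linear(4, 8, bias=False)  # registers bias as None
embedding = torch.nn.Embedding(10, 4)       # registers no bias attribute at all

for module in (linear, embedding):
    # old check (`module.bias is None`) raises AttributeError for Embedding;
    # the new guard passes for a bias-free Linear and for Embedding alike
    assert not hasattr(module, "bias") or module.bias is None
```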

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 10 additions & 4 deletions
@@ -51,15 +51,16 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        is_linear = isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

         weight = self.weights.get(size, dtype, device, construct_device=exec_device)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, is_linear)

     def _create_weight(
         self,
@@ -80,12 +81,17 @@ def _create_permutation(self, weight: Parameter) -> Parameter:

 class HadamardTransform(TransformBase):
     def __init__(
-        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+        self,
+        weight: Parameter,
+        perm: Union[Parameter, None],
+        args: TransformArgs,
+        is_linear: bool,
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.is_linear = is_linear

     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -96,4 +102,4 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T

-        return apply_transform_weight(weight, value, self.args.location)
+        return apply_transform_weight(weight, value, self.args.location, self.is_linear)
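For illustration (not part of this commit): forward() above realizes `args.inverse` by transposing the stored weight, which works because a normalized Hadamard matrix is orthogonal. A standalone sketch, assuming the weight is normalized by sqrt(n) (whether the library normalizes at construction or at application time is not shown in this hunk):

```python
import torch

# Sylvester construction of a 4x4 Hadamard matrix, normalized so that H @ H.T == I
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H4 = torch.kron(H2, H2) / 2.0  # divide by sqrt(4)

assert torch.allclose(H4 @ H4.T, torch.eye(4), atol=1e-6)

x = torch.randn(3, 4)
# applying the transform and then its transpose (the "inverse") recovers x
assert torch.allclose((x @ H4) @ H4.T, x, atol=1e-5)
```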

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 3 additions & 1 deletion
@@ -51,7 +51,9 @@ def deterministic_hadamard_matrix(

     log2 = int(math.log2(size))
     if size != 2**log2:
-        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+        raise ValueError(
+            f"Cannot construct deterministic hadamard of size {size} != 2^n"
+        )

     H = torch.tensor([[1]], dtype=dtype, device=device)

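For illustration (not part of this commit): the check this hunk touches only accepts power-of-two sizes; the change simply adds the offending size to the error message. A small sketch of the same check, using a hypothetical helper name:

```python
import math

def is_power_of_two(size: int) -> bool:
    # mirrors the check in deterministic_hadamard_matrix: size must equal 2**int(log2(size))
    return size > 0 and size == 2 ** int(math.log2(size))

assert is_power_of_two(64)
assert not is_power_of_two(48)  # would now raise "Cannot construct deterministic hadamard of size 48 != 2^n"
```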

src/compressed_tensors/transform/utils/utils.py

Lines changed: 38 additions & 14 deletions
@@ -27,17 +27,28 @@ def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int
     :param location: location on module
     :return: size of matrix
     """
-    assert isinstance(module, torch.nn.Linear)
-    if location in ("input", TransformLocation.WEIGHT_INPUT):
-        return module.in_features
-    else:
-        return module.out_features
+    if isinstance(module, torch.nn.Linear):
+        if location in ("input", TransformLocation.WEIGHT_INPUT):
+            return module.in_features
+        else:
+            return module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in ("input", TransformLocation.WEIGHT_INPUT):
+            return module.num_embeddings
+        else:
+            return module.embedding_dim
+
+    raise ValueError(
+        f"Unsupported module type {type(module)}, "
+        "should be either Linear or Embedding."
+    )


 def apply_transform_weight(
-    weight: torch.Tensor,
+    transform_weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
+    is_linear: bool = True,
 ) -> torch.Tensor:
     """
     Using the transform location, determine how to apply the transform weight to the
@@ -69,23 +80,36 @@ def apply_transform_weight(
         = y U
         = yh

-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
+    :param location: determines how transform_weight should be applied
+    :param is_linear: if value belongs to the weights of a Linear module
+        This is needed because torch uses convention:
+        Linear(in_features,out_features) has weight shape (out_features, in_features)
+        But other modules (e.g. torch.nn.Embedding) don't:
+        Embedding(num_embeddings, embedding_dim) has weight shape
+        (num_embeddings, embedding_dim)
+    :return: value after transform_weight has been applied
     """

     if location == TransformLocation.INPUT:
-        return value @ weight
+        return value @ transform_weight

     elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
+        if is_linear:
+            return value @ transform_weight.T
+        else:
+            # TODO is this ever needed?
+            raise NotImplementedError()

     elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
+        if is_linear:
+            return transform_weight.T @ value
+        else:
+            return value @ transform_weight

     elif location == TransformLocation.OUTPUT:
-        return value @ weight
+        return value @ transform_weight

     else:
         raise NotImplementedError(f"{location} has not been implemented yet")
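For illustration (not part of this commit): a standalone shape sketch of the convention the new is_linear flag encodes for TransformLocation.WEIGHT_OUTPUT, using hypothetical sizes. Linear stores its weight as (out_features, in_features), so an output-side transform multiplies on the left; Embedding stores (num_embeddings, embedding_dim), so the same location multiplies on the right:

```python
import torch

out_features, in_features = 16, 8
num_embeddings, embedding_dim = 32, 8

# Linear: weight is (out_features, in_features); WEIGHT_OUTPUT applies U.T @ W
linear_weight = torch.nn.Linear(in_features, out_features).weight  # (16, 8)
u = torch.randn(out_features, out_features)                        # (16, 16)
assert (u.T @ linear_weight).shape == (out_features, in_features)

# Embedding: weight is (num_embeddings, embedding_dim); the output dimension is
# the trailing embedding_dim, so WEIGHT_OUTPUT applies W @ U instead
embedding_weight = torch.nn.Embedding(num_embeddings, embedding_dim).weight  # (32, 8)
u_emb = torch.randn(embedding_dim, embedding_dim)                            # (8, 8)
assert (embedding_weight @ u_emb).shape == (num_embeddings, embedding_dim)
```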
