From 0b4fdb37d5fce5a13d14ef465ad0646215577346 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 9 Jul 2025 17:18:00 -0400
Subject: [PATCH] support embeddings

Signed-off-by: Kyle Sayers
---
 .../transform/factory/base.py            |  4 +--
 .../transform/factory/hadamard.py        |  2 +-
 .../transform/factory/matrix_multiply.py |  2 +-
 .../transform/utils/matrix.py            | 30 +++++++++++++++-
 .../factory/test_correctness.py          | 34 ++++++++++++++++++-
 5 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index e5a1e05c..de46385c 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -117,10 +117,8 @@ def input_hook(_, args):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
 
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index f55089b2..b000619e 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index 7279befc..8b829451 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -50,7 +50,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py
index d3f95fd0..37419843 100644
--- a/src/compressed_tensors/transform/utils/matrix.py
+++ b/src/compressed_tensors/transform/utils/matrix.py
@@ -39,6 +39,11 @@ def get_transform_size(
             size = module.in_features
         else:
             size = module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.num_embeddings
+        else:
+            size = module.embedding_dim
     else:
         raise NotImplementedError(f"Transforms on {type(module)} are not supported")
 
@@ -64,7 +69,12 @@ def apply_transform_weight(
     :param value: value to apply weight to
     :param location: determines how weight should be applied
     :param model_type: result of type(module), passed in to determine application of
-        weight transform
+        weight transform. This is needed because torch uses the convention:
+        - torch.nn.Linear(in_features, out_features) has weight shape
+            (out_features, in_features)
+        - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
+            (num_embeddings, embedding_dim)
+        The transform application has to account for Linear's transposed weights.
     :return: value after weight has been applied
     """
     fn, axis = _get_transform_method(module_type, location)
@@ -139,6 +149,24 @@ def _get_transform_method(
             fn = lambda weight, value: value @ weight
             axis = -1
 
+    # similar derivation to torch.nn.Linear, but `y = (x W)`
+    if module_type == torch.nn.Embedding:
+        if location == TransformLocation.INPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            fn = lambda weight, value: weight @ value
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
     if fn is None:
         raise NotImplementedError(
             f"Applying transforms to {module_type} {location} is not supported"
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index b9f18f0c..acf5ba47 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -31,7 +31,7 @@
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
 def test_correctness_linear(type, randomized, head_dim):
     size = (4, 8)
-    module = torch.nn.Linear(*size, bias=True)
+    module = torch.nn.Linear(*size, bias=False)
     scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
@@ -56,6 +56,38 @@
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
+@pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
+def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
+    model = torch.nn.Sequential(
+        torch.nn.Embedding(2, 4),
+        torch.nn.Linear(4, 8, bias=False),
+    )
+
+    input = torch.randint(low=0, high=2, size=(17, 5, 2))
+    true_output = model(input)
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomized=randomized,
+                apply=[
+                    TransformArgs(targets="Embedding", location=embed_loc),
+                    TransformArgs(targets="Linear", location=linear_loc, inverse=True),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
+
+    # compare outputs
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 def test_correctness_model(type, randomized, model_apply, offload=False):
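The cancellation exercised by `test_correctness_embedding` can also be checked by hand, outside the factory machinery. The sketch below is a standalone illustration, not part of the patch: a random invertible matrix (`U` and `U_inv` are names invented here) stands in for the Hadamard transforms the factories actually build. Fusing `U` into `Embedding.weight` at the `weight_output` location and its inverse into the downstream `Linear` weight at `weight_input` leaves the model output unchanged:

    import torch

    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)  # double precision keeps the final check tight

    embed = torch.nn.Embedding(2, 4)            # weight shape (num_embeddings, embedding_dim)
    linear = torch.nn.Linear(4, 8, bias=False)  # weight shape (out_features, in_features)
    ids = torch.randint(low=0, high=2, size=(17, 5, 2))

    true_output = linear(embed(ids))

    # stand-in for a Hadamard: any invertible embedding_dim x embedding_dim matrix
    U = torch.randn(4, 4)
    U_inv = torch.linalg.inv(U)

    with torch.no_grad():
        # weight_output on Embedding: weight is (num_embeddings, embedding_dim),
        # so the transform right-multiplies with no transpose (unlike Linear)
        embed.weight.copy_(embed.weight @ U)
        # inverse weight_input on Linear: y = x @ W.T, and x now arrives as x @ U,
        # so W @ U_inv.T cancels it: (x @ U) @ (W @ U_inv.T).T == x @ W.T
        linear.weight.copy_(linear.weight @ U_inv.T)

    assert torch.allclose(true_output, linear(embed(ids)))

This is the same layout distinction the new Embedding branch in `_get_transform_method` encodes: because `Embedding.weight` is stored as (num_embeddings, embedding_dim) rather than Linear's transposed (out_features, in_features), the output-side transform right-multiplies the weight directly.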