
[Transform] Support applying transforms to embeddings #385

Open
wants to merge 1 commit into base: kylesayrs/transform-attention-head

4 changes: 1 addition & 3 deletions src/compressed_tensors/transform/factory/base.py
@@ -117,10 +117,8 @@ def input_hook(_, args):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))

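The relaxed check means any module that exposes a `weight` parameter can now have a transform fused into it, not only `torch.nn.Linear`. A minimal sketch of the fusion for an Embedding, illustrative only; the PR does this via `transform(module.weight)` and `update_offload_parameter` rather than the plain matmul below:

import torch

# illustrative stand-ins: eye() in place of a real Hadamard matrix
embedding = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)
transform_weight = torch.eye(4)

assert hasattr(embedding, "weight")  # the only requirement after this change
with torch.no_grad():
    # fuse offline: W <- W @ T, i.e. a weight_output-style fusion
    embedding.weight.copy_(embedding.weight @ transform_weight)
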
2 changes: 1 addition & 1 deletion src/compressed_tensors/transform/factory/hadamard.py
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)

@@ -50,7 +50,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
+        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
30 changes: 29 additions & 1 deletion src/compressed_tensors/transform/utils/matrix.py
@@ -39,6 +39,11 @@ def get_transform_size(
             size = module.in_features
         else:
             size = module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.num_embeddings
+        else:
+            size = module.embedding_dim
     else:
         raise NotImplementedError(f"Transforms on {type(module)} are not supported")
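A quick sketch of what the new branch returns, assuming head_dim=None leaves the size untouched (as in the Linear branch above); the import paths are assumed from this repo's layout:

import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.matrix import get_transform_size

emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)
# vocabulary axis for the "input" side, embedding dimension for the "output" side
assert get_transform_size(emb, TransformLocation.WEIGHT_INPUT, None) == 1000
assert get_transform_size(emb, TransformLocation.WEIGHT_OUTPUT, None) == 64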

@@ -64,7 +69,12 @@ def apply_transform_weight(
     :param value: value to apply weight to
     :param location: determines how weight should be applied
     :param module_type: result of type(module), passed in to determine application of
-        weight transform
+        weight transform. This is needed because torch uses the convention:
+        - torch.nn.Linear(in_features, out_features) has weight shape
+          (out_features, in_features)
+        - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
+          (num_embeddings, embedding_dim)
+        The transform must account for Linear's transposed weight.
     :return: value after weight has been applied
     """
     fn, axis = _get_transform_method(module_type, location)
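A short sketch of the shape convention the expanded docstring describes: Linear stores its weight transposed (y = x @ W.T), while Embedding stores it un-transposed, so a lookup behaves like y = x @ W with a one-hot x. Illustrative sizes only:

import torch

linear = torch.nn.Linear(in_features=4, out_features=8)
embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)

assert linear.weight.shape == (8, 4)      # (out_features, in_features)
assert embedding.weight.shape == (10, 4)  # (num_embeddings, embedding_dim)

# an embedding lookup is a matmul with a one-hot input
idx = torch.tensor([3])
one_hot = torch.nn.functional.one_hot(idx, num_classes=10).float()
assert torch.allclose(one_hot @ embedding.weight, embedding(idx))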
@@ -139,6 +149,24 @@ def _get_transform_method(
             fn = lambda weight, value: value @ weight
             axis = -1
 
+    # similar derivation to torch.nn.Linear, but `y = (x W)`
+    if module_type == torch.nn.Embedding:
+        if location == TransformLocation.INPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            fn = lambda weight, value: weight @ value
+            axis = -1
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
+        elif location == TransformLocation.OUTPUT:
+            fn = lambda weight, value: value @ weight
+            axis = -1
+
     if fn is None:
         raise NotImplementedError(
             f"Applying transforms to {module_type} {location} is not supported"
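All four new branches reduce to left- or right-multiplying the un-transposed embedding table. A small numeric check of the equivalence behind the `y = (x W)` comment, using a hand-built 4x4 Hadamard rather than this repo's utilities:

import torch

h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
hadamard = torch.kron(h2, h2) / 2.0  # orthogonal: hadamard @ hadamard.T == I

embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=4)
idx = torch.randint(0, 6, (3,))

# fusing at weight_output (value @ weight applied to the table) ...
fused = embedding.weight.detach() @ hadamard

# ... matches applying the same transform at output (value @ weight on activations)
assert torch.allclose(fused[idx], embedding(idx) @ hadamard, atol=1e-6)

# and an inverse transform on the next module's input undoes it, as the new test exercises
assert torch.allclose(fused[idx] @ hadamard.T, embedding(idx), atol=1e-6)
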
34 changes: 33 additions & 1 deletion tests/test_transform/factory/test_correctness.py
@@ -31,7 +31,7 @@
 @pytest.mark.parametrize("head_dim", (None, 2, 4))
 def test_correctness_linear(type, randomized, head_dim):
     size = (4, 8)
-    module = torch.nn.Linear(*size, bias=True)
+    module = torch.nn.Linear(*size, bias=False)
     scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
@@ -56,6 +56,38 @@ def test_correctness_linear(type, randomized, head_dim):
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomized", (True, False))
+@pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
+@pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
+def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
+    model = torch.nn.Sequential(
+        torch.nn.Embedding(2, 4),
+        torch.nn.Linear(4, 8, bias=False),
+    )
+
+    input = torch.randint(high=1, low=0, size=(17, 5, 2))
+    true_output = model(input)
+
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                type=type,
+                randomized=randomized,
+                apply=[
+                    TransformArgs(targets="Embedding", location=embed_loc),
+                    TransformArgs(targets="Linear", location=linear_loc, inverse=True),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
+
+    # compare outputs
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomized", (True, False))
 def test_correctness_model(type, randomized, model_apply, offload=False):
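If useful while reviewing, the new parametrizations can be run on their own with pytest's -k filter (assuming the repository's usual test environment):

pytest tests/test_transform/factory/test_correctness.py -k test_correctness_embedding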
Expand Down