[Transform] QuIP Modifier #1648

Draft · wants to merge 2 commits into base: bdellabe/transform-modifier
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/transform/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa

from .spinquant import SpinQuantModifier
from .quip import QuIPModifier
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/quip/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
Contributor (medium):

According to PEP 8, wildcard imports (from ... import *) should be avoided as they make it unclear which names are present in the namespace. While base.py defines __all__, making the import safer, it's more explicit and maintainable to import the required names directly.

Suggested change
from .base import *
from .base import QuIPModifier
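
For readers skimming the review: a tiny standalone sketch (not from this PR; module contents are hypothetical) of how `__all__` gates a wildcard import, and why the explicit form is still easier to trace:

# base.py (hypothetical module)
__all__ = ["QuIPModifier"]      # only this name is exported by `import *`

class QuIPModifier:
    pass

class _InternalHelper:          # not listed in __all__, so `import *` skips it
    pass

# __init__.py
from .base import QuIPModifier  # explicit: readers see the exported name here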

131 changes: 131 additions & 0 deletions src/llmcompressor/modifiers/transform/quip/base.py
@@ -0,0 +1,131 @@
from typing import Iterable, List, Literal, Optional, Union

from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from pydantic import Field, ValidationInfo, field_validator

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier

__all__ = ["QuIPModifier"]


class QuIPModifier(Modifier):
"""
Implements the transforms according to
[QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501
[QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501

Transforms (rotations) are extra layers added to a model which reduce the accuracy
loss induced by quantization. This is achived through "rotating" weights and
Contributor (medium):

There is a typo in the docstring. "achived" should be "achieved".

Suggested change
    loss induced by quantization. This is achived through "rotating" weights and
    loss induced by quantization. This is achieved through "rotating" weights and

    activations into a space with a smaller dynamic range of values, thus decreasing
    the range of scales required for quantization.

    QuIP and QuIP# apply transforms to every linear layer, two of which are fused into
    the model weights and two of which remain as online rotations computed at runtime.

    :param transform_type: The type of transform to apply to the model.
        `"hadamard"` has the least performance cost but only supports sizes which are
        powers of two.
        `"random-matrix"` has more performance cost, but supports a much larger set of
        sizes.
        `"random-matrix"` has the greatest performance cost, but supports any size
Comment on lines +34 to +36
Contributor (medium):

The documentation for transform_type appears to have a copy-paste error. It lists "random-matrix" twice and omits "random-hadamard", which is a valid option for transform_type. This can be confusing for users. The docstring should be updated to correctly describe all available transform types.

Suggested change
`"random-matrix"` has more performance cost, but supports a much larger set of
sizes.
`"random-matrix"` has the greatest performance cost, but supports any size
`"random-hadamard"` has more performance cost, but supports a much larger set of
sizes.
`"random-matrix"` has the greatest performance cost, but supports any size

    :param randomize: If true, create distinct transforms for each application
    :param learnable: If true, attach gradients to transform weights for training
    :param ignore: Modules to ignore when attaching transforms
    :param transform_config: Optional transform config for overriding provided arguments
    """

    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="hadamard", exclude=True
    )
    randomize: bool = Field(default=False, exclude=True)
    learnable: bool = Field(default=False, exclude=True)
    ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True)

    # optional override for more fine-grained control
    # also included in recipe serialization
    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

    @field_validator("randomize", "learnable", mode="before")
    def validate_not_implemented(cls, value, info: ValidationInfo):
        raise NotImplementedError(f"{info.field_name} is not supported right now")
Comment on lines +55 to +56
Contributor (critical):

The validate_not_implemented validator has two critical issues:

  1. It raises NotImplementedError even if randomize or learnable are explicitly set to False in a recipe. The check should only trigger if they are set to True.
  2. Pydantic validators must return the validated value. This implementation implicitly returns None, which will incorrectly overwrite the field's value.

These issues make the modifier unusable if these fields are present in the recipe.

Suggested change
    def validate_not_implemented(cls, value, info: ValidationInfo):
        raise NotImplementedError(f"{info.field_name} is not supported right now")
    def validate_not_implemented(cls, value, info: ValidationInfo):
        if value:
            raise NotImplementedError(f"{info.field_name} is not supported right now")
        return value
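
For anyone unfamiliar with the second pitfall, a self-contained demo (not from this PR) of a before-mode validator correctly returning its value:

from pydantic import BaseModel, field_validator

class Demo(BaseModel):
    flag: bool = False

    @field_validator("flag", mode="before")
    def passthrough(cls, value):
        # Omitting this return would hand pydantic `None` in place of the
        # user's input, corrupting the field exactly as described above.
        return value

assert Demo(flag=True).flag is True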


    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            return True

        self.transform_config = self._create_config()
        return True
Comment on lines +58 to +63
Contributor (high):

The ignore field is a Union[str, List[str]], but the TransformArgs it's passed to likely expects a list of strings. If a single string (like the default "lm_head") is passed, it could be iterated over character by character, leading to incorrect behavior. It's safer to ensure self.ignore is always a list. This can be handled at the beginning of on_initialize.

Suggested change
    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            return True
        self.transform_config = self._create_config()
        return True
    def on_initialize(self, state: State, **kwargs) -> bool:
        if isinstance(self.ignore, str):
            self.ignore = [self.ignore]
        if self.transform_config is not None:
            return True
        self.transform_config = self._create_config()
        return True
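
The character-iteration pitfall described above is easy to reproduce in isolation (illustrative snippet, not from the PR):

ignore = "lm_head"
print(list(ignore))      # ['l', 'm', '_', 'h', 'e', 'a', 'd'] -- each character
ignore = ["lm_head"]
print(list(ignore))      # ['lm_head'] -- the intended single module name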


    def on_start(self, state: State, event: Event, **kwargs):
Contributor (high):

The on_start method is called with None for the event parameter from within on_event, but its signature event: Event does not allow None. This violates the type hint and could cause issues with static analysis tools. The signature should be updated to event: Optional[Event] = None to reflect that the event is optional.

Suggested change
    def on_start(self, state: State, event: Event, **kwargs):
    def on_start(self, state: State, event: Optional[Event] = None, **kwargs):

        self.started_ = True

        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
Contributor (high):

Similar to on_start, the on_end method is called with None for the event parameter, but its signature event: Event does not allow for None. This should be corrected to event: Optional[Event] = None to align with its usage.

Suggested change
    def on_end(self, state: State, event: Event, **kwargs):
    def on_end(self, state: State, event: Optional[Event] = None, **kwargs):

        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _create_config(self) -> TransformConfig:
        return TransformConfig(
            config_groups={
                "v": TransformScheme(
                    type=self.transform_type,
                    apply=[
                        TransformArgs(
                            targets=["Linear"],
                            location="input",  # non-mergable
                            ignore=self.ignore,
                        ),
                        TransformArgs(
                            targets=["Linear"],
                            location="weight_input",
                            inverse=True,
                            ignore=self.ignore,
                        ),
                    ],
                    randomize=self.randomize,
                    requires_grad=self.learnable,
                ),
                "u": TransformScheme(
                    type=self.transform_type,
                    apply=[
                        TransformArgs(
                            targets=["Linear"],
                            location="weight_output",
                            ignore=self.ignore,
                        ),
                        TransformArgs(
                            targets=["Linear"],
                            location="output",  # non-mergable
                            inverse=True,
                            ignore=self.ignore,
                        ),
                    ],
                    randomize=self.randomize,
                    requires_grad=self.learnable,
                ),
            }
        )
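
For context, end-to-end usage of this modifier might look like the sketch below. The `oneshot` entrypoint, model id, and argument names are assumptions based on llm-compressor's existing recipe-style API, not confirmed by this diff:

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import QuIPModifier

# Apply QuIP rotations ahead of a later quantization step: the two
# weight-side transforms fuse into the checkpoint, while the input/output
# rotations run online at inference time.
oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",   # placeholder model id
    recipe=[QuIPModifier(transform_type="hadamard")],
)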