diff --git a/compress_model.py b/compress_model.py
new file mode 100644
index 000000000..fa67bead0
--- /dev/null
+++ b/compress_model.py
@@ -0,0 +1,60 @@
+# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard
+import argparse
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
+from llmcompressor.utils import dispatch_for_generation
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str, help="Model stub to compress")
+    parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier")
+    parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # Select model and load it.
+    MODEL_ID = args.model_id
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+    # Select number of samples. 512 samples is a good place to start.
+    # Increasing the number of samples can improve accuracy.
+    NUM_CALIBRATION_SAMPLES = 512
+    MAX_SEQUENCE_LENGTH = 2048
+
+    # Configure the quantization algorithm to run.
+    recipe = []
+    if args.transform_type:
+        recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type))
+
+    if args.scheme:
+        recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"]))
+
+    # Apply algorithms.
+    oneshot(
+        model=model,
+        recipe=recipe,
+        dataset="ultrachat_200k",
+        splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
+        max_seq_length=MAX_SEQUENCE_LENGTH,
+        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    )
+
+    # Confirm generations of the quantized model look sane.
+    print("\n\n")
+    print("========== SAMPLE GENERATION ==============")
+    dispatch_for_generation(model)
+    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+    output = model.generate(input_ids, max_new_tokens=100)
+    print(tokenizer.decode(output[0]))
+    print("==========================================\n\n")
+
+    # Save to disk compressed.
+    SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
+    model.save_pretrained(SAVE_DIR, save_compressed=True)
+    tokenizer.save_pretrained(SAVE_DIR)
diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py
new file mode 100644
index 000000000..876db7138
--- /dev/null
+++ b/examples/transform/spinquant_example.py
@@ -0,0 +1,86 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * apply spinquant transforms to the model in order to make quantization easier
+# * quantize the weights to 4 bits with a group size of 128
+recipe = [
+    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
+    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
+]
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    recipe=recipe,
+    dataset=ds,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py
index e2c22ed1f..76b6b0391 100644
--- a/src/llmcompressor/modeling/__init__.py
+++ b/src/llmcompressor/modeling/__init__.py
@@ -1,3 +1,4 @@
 # flake8: noqa
 
+from .fuse import *
 from .prepare import *
diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py
index a8168a31f..e59be596c 100644
--- a/src/llmcompressor/modeling/fuse.py
+++ b/src/llmcompressor/modeling/fuse.py
@@ -6,55 +6,58 @@
     get_execution_device,
     update_offload_parameter,
 )
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
-__all__ = ["center_embeddings", "fuse_norm_linears"]
+__all__ = ["normalize_embedding", "fuse_norm_linears"]
 
 PRECISION = torch.float64
 
 
-def center_embeddings(embedding: torch.nn.Module):
+def normalize_embedding(embedding: torch.nn.Module):
     """
-    Shift each embedding to have a mean of zero
+    Normalize each embedding to have a mean of zero
 
     :param embedding: embedding module containing embeddings to center
     """
-    if not hasattr(embedding, "weight"):
-        raise ValueError(f"Cannot fuse norm of type {type(embedding)}")
+    if isinstance(embedding, (torch.nn.Embedding)):
+        with align_module_device(embedding):
+            weight_dtype = embedding.weight.dtype
+            weight = embedding.weight.to(PRECISION)
+            new_weight = weight - weight.mean(dim=-1, keepdim=True)
+            new_weight = new_weight.to(weight_dtype)
 
-    with align_module_device(embedding):
-        weight_dtype = embedding.weight.dtype
-        weight = embedding.weight.to(PRECISION)
-        new_weight = weight - weight.mean(dim=-1, keepdim=True)
-        new_weight = new_weight.to(weight_dtype)
+            update_offload_parameter(embedding, "weight", new_weight)
 
-    update_offload_parameter(embedding, "weight", new_weight)
+    else:
+        raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")
 
 
 def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
     """
-    Fuse the scaling operation of norm layer into subsequent linear layers.
-    This useful for ensuring transform invariance between norm and linear layers.
+    Fuse a norm layer into subsequent linear layers. This is useful for ensuring
+    transform invariance between norm and linear layers.
 
-    Note that unitary transforms (rotation) commute with normalization, but not scaling
+    Note that a model cannot be properly trained after its norms have been fused
 
     :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
     """
-    if not hasattr(norm, "weight"):
+    if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)):
+        for linear in linears:
+            # NOTE: spinquant does this op in float64
+            exec_device = get_execution_device(norm)
+            with align_module_device(norm, exec_device), align_module_device(
+                linear, exec_device
+            ):
+                weight_dtype = linear.weight.dtype
+                new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
+                new_weight = new_weight.to(weight_dtype)
+
+            update_offload_parameter(linear, "weight", new_weight)
+
+        new_norm_weight = torch.ones_like(norm.weight, device="cpu")
+        update_offload_parameter(norm, "weight", new_norm_weight)
+
+    else:
         raise ValueError(f"Cannot fuse norm of type {type(norm)}")
-
-    for linear in linears:
-        # NOTE: spinquant does this op in float64
-        exec_device = get_execution_device(norm)
-        with align_module_device(norm, exec_device), align_module_device(
-            linear, exec_device
-        ):
-            weight_dtype = linear.weight.dtype
-            new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
-            new_weight = new_weight.to(weight_dtype)
-
-        update_offload_parameter(linear, "weight", new_weight)
-
-    new_norm_weight = torch.ones_like(norm.weight, device="cpu")
-    update_offload_parameter(norm, "weight", new_norm_weight)
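The rewritten fuse_norm_linears above folds the norm's scale into the weight of each following linear and then resets the norm weight to ones. It relies on the identity that scaling the input columns of a linear by the norm weight gives the same result as applying the scaled norm first. A minimal standalone sketch that checks this numerically (not part of this diff; all names are illustrative):

import torch

hidden = 16
x = torch.randn(2, hidden, dtype=torch.float64)
gamma = torch.rand(hidden, dtype=torch.float64) + 0.5  # RMSNorm scale weight
w = torch.randn(8, hidden, dtype=torch.float64)        # weight of the following linear


def rms_norm(x, weight, eps=1e-6):
    # y = weight * x / sqrt(mean(x^2) + eps)
    return weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


# before fusing: Linear(RMSNorm_gamma(x))
reference = rms_norm(x, gamma) @ w.T

# after fusing: scale the linear's input columns by gamma and run the norm
# with a weight of all ones, mirroring what fuse_norm_linears does
fused = rms_norm(x, torch.ones_like(gamma)) @ (w * gamma).T

assert torch.allclose(reference, fused)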
{type(embedding)}") def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): """ - Fuse the scaling operation of norm layer into subsequent linear layers. - This useful for ensuring transform invariance between norm and linear layers. + Fuse a norm layer into subsequent linear layers. This useful for ensuring transform + invariance between norm and linear layers. - Note that unitary transforms (rotation) commute with normalization, but not scaling + Note that a model cannot be properly trained after its norms have been fused :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if not hasattr(norm, "weight"): + if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)): + for linear in linears: + # NOTE: spinquant does this op in float64 + exec_device = get_execution_device(norm) + with align_module_device(norm, exec_device), align_module_device( + linear, exec_device + ): + weight_dtype = linear.weight.dtype + new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) + new_weight = new_weight.to(weight_dtype) + + update_offload_parameter(linear, "weight", new_weight) + + new_norm_weight = torch.ones_like(norm.weight, device="cpu") + update_offload_parameter(norm, "weight", new_norm_weight) + + else: raise ValueError(f"Cannot fuse norm of type {type(norm)}") - - for linear in linears: - # NOTE: spinquant does this op in float64 - exec_device = get_execution_device(norm) - with align_module_device(norm, exec_device), align_module_device( - linear, exec_device - ): - weight_dtype = linear.weight.dtype - new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) - new_weight = new_weight.to(weight_dtype) - - update_offload_parameter(linear, "weight", new_weight) - - new_norm_weight = torch.ones_like(norm.weight, device="cpu") - update_offload_parameter(norm, "weight", new_norm_weight) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py new file mode 100644 index 000000000..9956d0340 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py new file mode 100644 index 000000000..5978b93ea --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -0,0 +1,221 @@ +from enum import Enum +from typing import Iterable, List, Literal, Optional + +from compressed_tensors import match_modules_set, match_named_modules +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from pydantic import Field, ValidationInfo, field_validator +from transformers import PreTrainedModel + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modeling import fuse_norm_linears, normalize_embedding +from llmcompressor.modifiers import Modifier + +from .mappings import SpinQuantMapping, infer_mapping_from_model +from .norm_mappings import NormMapping, 
+
+
+class SpinquantRotation(str, Enum):
+    R1 = "R1"
+    R2 = "R2"
+    R3 = "R3"
+    R4 = "R4"
+
+
+class SpinQuantModifier(Modifier, use_enum_values=True):
+    """
+    Implements the transforms according to
+    [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406)  # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    The SpinQuant authors describe four different rotations which can be applied to a
+    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
+    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
+    rotations, meaning that they require additional computation at runtime.
+
+    :param rotations: A list containing the names of rotations to apply to the model.
+        Possible rotations include R1, R2, R3, and R4
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a much larger set
+        of sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size
+    :param randomize: if True, create distinct transforms for each application
+    :param learnable: if True, attach gradients to transform weights for training
+    :param mappings: Specifies layers within a model to target for transforms.
+        A mapping will be inferred if None is provided
+    :param norm_mappings: Specifies layers within a model to target for norm fusing.
+        A mapping will be inferred if None is provided
+    :param transform_config: Optional transform config for overriding provided arguments
+    """
+
+    rotations: List[SpinquantRotation] = Field(
+        default_factory=lambda: ["R1", "R2"], exclude=True
+    )
+    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
+        default="hadamard", exclude=True
+    )
+    randomize: bool = Field(default=False, exclude=True)
+    learnable: bool = Field(default=False, exclude=True)
+
+    # norm mappings separate from spinquant mappings to allow users to
+    # override spinquant mappings with transform_config without overriding norms
+    mappings: Optional[SpinQuantMapping] = Field(
+        default=None,
+        repr=False,
+        exclude=True,
+    )
+    norm_mappings: Optional[List[NormMapping]] = Field(
+        default=None,
+        repr=False,
+        exclude=True,
+    )
+
+    # optional override for more fine-grained control
+    # also included in recipe serialization
+    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)
+
+    @field_validator("randomize", "learnable", mode="before")
+    def validate_not_implemented(cls, value, info: ValidationInfo):
+        raise NotImplementedError(f"{info.field_name} is not supported right now")
+
+    @field_validator("rotations", mode="before")
+    def validate_rotations(cls, value):
+        if isinstance(value, Iterable):
+            return tuple(v.upper() for v in value)
+        return value
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        if self.transform_config is not None:
+            return True
+
+        self.mappings = infer_mapping_from_model(state.model)
+        self.norm_mappings = infer_norm_mapping_from_model(state.model)
+
+        config_groups = {}
+        if SpinquantRotation.R1 in self.rotations:
+            config_groups["R1"] = self._create_r1_scheme()
+
+        if SpinquantRotation.R2 in self.rotations:
+            config_groups["R2"] = self._create_r2_scheme(state.model)
+
+        if SpinquantRotation.R3 in self.rotations:
+            config_groups["R3"] = self._create_r3_scheme()
+
+        if SpinquantRotation.R4 in self.rotations:
+            config_groups["R4"] = self._create_r4_scheme()
+
+        self.transform_config = TransformConfig(config_groups=config_groups)
+
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
+        # needs to happen after the model has been hooked to execute on the GPU
+        # otherwise we're applying weight transforms on CPU
+        self._prenormalize_embeddings(state.model)
+        self._fuse_norms(state.model)
+        apply_transform_config(state.model, self.transform_config)
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            pass
+
+        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        if not self.ended_:
+            self.on_end(state, None)
+
+        return True
+
+    def _prenormalize_embeddings(self, model: PreTrainedModel):
+        for _, embedding in match_named_modules(
+            model, [self.mappings.embedding], warn_on_fail=True
+        ):
+            normalize_embedding(embedding)
+
+    def _fuse_norms(self, model: PreTrainedModel):
+        for mapping in self.norm_mappings:
+            for norm, *linears in match_modules_set(
+                model, (mapping.norm, *mapping.linears)
+            ):
+                fuse_norm_linears(norm, linears)
+
+    def _create_r1_scheme(self) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            apply=[
+                TransformArgs(
+                    targets=[
+                        self.mappings.embedding,
+                        self.mappings.attn_o,
+                        *self.mappings.mlp_out,
+                    ],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=[
+                        self.mappings.attn_q,
+                        self.mappings.attn_k,
+                        self.mappings.attn_v,
+                        *self.mappings.mlp_in,
+                        self.mappings.lm_head,
+                    ],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+        )
+
+    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            head_dim = config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
+
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
+                TransformArgs(
+                    targets=[self.mappings.attn_o],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+        )
+
+    def _create_r3_scheme(self) -> TransformScheme:
+        raise NotImplementedError()
+
+    def _create_r4_scheme(self) -> TransformScheme:
+        raise NotImplementedError()
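The R1 and R2 schemes built above attach a rotation to one module's weight_output and the inverse rotation to the following module's weight_input, so the composed computation is unchanged while individual weights and activations become easier to quantize. A minimal numerical sketch of that "offline" invariance, independent of the compressed-tensors API (all names are illustrative; a Hadamard matrix is one particular choice of orthogonal rotation):

import torch

dim = 64
x = torch.randn(4, dim, dtype=torch.float64)
w1 = torch.randn(dim, dim, dtype=torch.float64)  # producer weight (e.g. o_proj or down_proj)
w2 = torch.randn(dim, dim, dtype=torch.float64)  # consumer weight (e.g. q_proj or up_proj)

# a random orthogonal rotation R
r, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))

original = (x @ w1.T) @ w2.T

# fuse R into the producer's output weight and R^-1 (= R^T) into the consumer's input weight
w1_fused = r.T @ w1  # hidden states become h @ R
w2_fused = w2 @ r    # the consumer undoes the rotation
rotated = (x @ w1_fused.T) @ w2_fused.T

assert torch.allclose(original, rotated)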
diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
new file mode 100644
index 000000000..7dc327b78
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -0,0 +1,57 @@
+from typing import Dict, List, Optional
+
+from loguru import logger
+from pydantic import BaseModel, Field, field_validator
+from transformers import PreTrainedModel
+
+__all__ = ["SpinQuantMapping", "infer_mapping_from_model"]
+
+
+class SpinQuantMapping(BaseModel):
+    embedding: str
+
+    attn_q: str
+    attn_k: str
+    attn_v: str
+    attn_o: str
+    attn_head_dim: Optional[int] = Field(default=None)
+
+    mlp_in: List[str]  # up_proj, gate_proj
+    mlp_out: List[str]  # down_proj
+
+    lm_head: str
+
+    @field_validator("mlp_in", "mlp_out", mode="before")
+    def cast_to_list(cls, value):
+        if isinstance(value, str):
+            return [value]
+
+        return value
+
+
+_default_mappings = SpinQuantMapping(
+    embedding="re:.*embed_tokens$",
+    attn_q="re:.*q_proj$",
+    attn_k="re:.*k_proj$",
+    attn_v="re:.*v_proj$",
+    attn_o="re:.*o_proj$",
+    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
+    mlp_out="re:.*down_proj$",
+    lm_head="lm_head",
+)
+
+
+SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = {
+    "LlamaForCausalLM": _default_mappings,
+}
+
+
+def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping:
+    architecture = model.__class__.__name__
+    if architecture not in SPINQUANT_MAPPING_REGISTRY:
+        logger.info(
+            f"Unrecognized model architecture {architecture}. "
+            "Falling back to default mappings"
+        )
+
+    return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings)
diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py
new file mode 100644
index 000000000..0752f6986
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py
@@ -0,0 +1,50 @@
+from typing import Dict, List
+
+from loguru import logger
+from pydantic import BaseModel, field_validator
+from transformers import PreTrainedModel
+
+__all__ = ["infer_norm_mapping_from_model"]
+
+
+class NormMapping(BaseModel):
+    norm: str
+    linears: List[str]
+
+    @field_validator("linears", mode="before")
+    def cast_to_list(cls, value):
+        if isinstance(value, str):
+            return [value]
+
+        return value
+
+
+_default_mappings = [
+    NormMapping(
+        norm="re:.*input_layernorm$",
+        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
+    ),
+    NormMapping(
+        norm="re:.*post_attention_layernorm$",
+        linears=["re:.*up_proj$", "re:.*gate_proj$"],
+    ),
+    NormMapping(
+        norm="model.norm",
+        linears=["lm_head"],
+    ),
+]
+
+NORM_MAPPING_REGISTRY: Dict[str, List[NormMapping]] = {
+    "LlamaForCausalLM": _default_mappings,
+}
+
+
+def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]:
+    architecture = model.__class__.__name__
+    if architecture not in NORM_MAPPING_REGISTRY:
+        logger.info(
+            f"Unrecognized model architecture {architecture}. "
+            "Falling back to default mappings"
+        )
+
+    return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings)
diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py
index 587f7ca69..7ad6d56dc 100644
--- a/src/llmcompressor/pipelines/data_free/pipeline.py
+++ b/src/llmcompressor/pipelines/data_free/pipeline.py
@@ -5,6 +5,7 @@
 
 from llmcompressor.core.session_functions import LifecycleCallbacks
 from llmcompressor.pipelines.registry import CalibrationPipeline
+from llmcompressor.utils.dev import dispatch_for_generation
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -27,5 +28,9 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
+        # some ops are still performed on the model by modifiers
+        # we want those ops to occur on the GPU
+        dispatch_for_generation(model)
+
         LifecycleCallbacks.calibration_epoch_start()
         LifecycleCallbacks.calibration_epoch_end()
diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py
index 2c1a54cf5..67d510d13 100644
--- a/src/llmcompressor/pipelines/registry.py
+++ b/src/llmcompressor/pipelines/registry.py
@@ -8,6 +8,7 @@
 
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -61,4 +62,8 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str:
         if not config.requires_calibration_data():
             return "datafree"
 
+    # TODO: remove this hardcoded check
+    if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier):
+        return "datafree"
+
     return "sequential"
diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py
index 005d89f99..5798f692c 100644
--- a/tests/llmcompressor/modeling/test_fuse.py
+++ b/tests/llmcompressor/modeling/test_fuse.py
@@ -1,13 +1,13 @@
 import pytest
 import torch
 
-from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears
+from llmcompressor.modeling.fuse import fuse_norm_linears, normalize_embedding
 
 
 @pytest.mark.unit
-def test_center_embeddings():
+def test_normalize_embedding():
     embedding = torch.nn.Embedding(10, 10)
-    center_embeddings(embedding)
+    normalize_embedding(embedding)
 
     assert torch.allclose(
         embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5
diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py
new file mode 100644
index 000000000..660bab0ef
--- /dev/null
+++ b/tests/llmcompressor/modifiers/transform/test_correctness.py
@@ -0,0 +1,34 @@
+import pytest
+import torch
+from compressed_tensors.transform import apply_transform_config
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.modifiers.transform.template.quip import QUIP
+
+
+@pytest.mark.parametrize(
+    "dtype,exp_max,exp_mse",
+    [
+        (
+            torch.bfloat16,
+            1.1,
+            0.012,
+        ),  # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055)  # noqa: E501
+        (torch.float32, 4e-4, 2e-9),
+    ],
+)
+def test_apply_correctness(dtype, exp_max, exp_mse):
+    model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype
+    )
+
+    input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
+    with torch.no_grad():
+        true_output = model(**input)
+
+    apply_transform_config(model, QUIP)
+    with torch.no_grad():
+        output = model(**input)
+
+    assert torch.max(true_output.logits - output.logits) <= exp_max
+    assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse
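Because the registry change above routes a recipe containing only a SpinQuantModifier to the data-free pipeline, the offline R1/R2 rotations can in principle be applied without any calibration data. A hedged sketch of that transform-only flow, built from the APIs used elsewhere in this diff (the model stub and save path are illustrative, not taken from this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import SpinQuantModifier

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # illustrative model stub
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# transform-only recipe: no dataset is passed, so the pipeline registry
# resolves this run to the data-free pipeline
oneshot(
    model=model,
    recipe=[SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard")],
)

SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquant-only"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)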