From dd66f10b835827b1bea28e7f4c80490155a2b370 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 13:15:38 -0400 Subject: [PATCH 1/2] add propagate_error arg Signed-off-by: Kyle Sayers --- src/llmcompressor/args/dataset_arguments.py | 11 ++++ src/llmcompressor/pipelines/basic/pipeline.py | 14 +++-- .../pipelines/layer_sequential/pipeline.py | 58 +++++++++++-------- .../pipelines/sequential/pipeline.py | 31 ++++++---- 4 files changed, 75 insertions(+), 39 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 949933f97..19408db50 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -171,6 +171,7 @@ class DatasetArguments(CustomDatasetArguments): "will execute code present on the Hub on your local machine." }, ) + # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", metadata={ @@ -196,3 +197,13 @@ class DatasetArguments(CustomDatasetArguments): "definition" }, ) + propagate_error: Optional[bool] = field( + default=True, + metadata={ + "help": "A True value means that the activations used to calibrate layers " + "will reflect the error induced by the quantization/optimization of " + "previous layers of the model. A False value means that activations will " + "be the same as activations produced by the original, full precision base " + "model. 
Defaults to True" + }, + ) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 605358ae9..99db90877 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,3 +1,4 @@ +import contextlib from typing import TYPE_CHECKING, Union import torch @@ -10,6 +11,7 @@ from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils import calibration_forward_context, dispatch_for_generation +from llmcompressor.utils.helpers import DisableQuantization if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -42,11 +44,15 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() + # disable gradients, kv cache, etc. with calibration_forward_context(model): - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): - batch = apply_pad_mask_to_batch(batch) - batch = tensors_to_device(batch, model_device) - model(**batch) + with DisableQuantization( + model + ) if not dataset_args.propagate_error else contextlib.nullcontext(): + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): + batch = apply_pad_mask_to_batch(batch) + batch = tensors_to_device(batch, model_device) + model(**batch) LifecycleCallbacks.calibration_epoch_end() diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index d8ad73a10..b2ee8a910 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING import torch -import tqdm from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm from llmcompressor.core import LifecycleCallbacks, active_session from llmcompressor.modifiers.utils.hooks import HooksMixin @@ 
-68,7 +68,8 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with calibration_forward_context(model), DisableQuantization(model): + # disable gradients, kv cache, etc. + with calibration_forward_context(model): # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( model, layers[0], dataloader @@ -83,31 +84,42 @@ def __call__( # reduce memory movement by keeping modules onloaded with disable_offloading(): # do a preliminary pass to trigger modifier hooks - for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): - inputs = intermediates.fetch(batch_idx) - layer(**inputs) + with DisableQuantization(model): + for index in tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(index) + output = layer(**inputs) - LifecycleCallbacks.sequential_epoch_end() + if not dataset_args.propagate_error: + if layer_index < num_layers - 1: + next_layer = layers[layer_index + 1] + output = to_next_layer_kwargs(output, next_layer) + output = maybe_inject_pos_embeddings( + output, next_layer, inputs + ) - # this pass does not trigger modifier hooks - # and is only used for capturing outputs from - # newly compressed modules - with HooksMixin.disable_hooks(): - for batch_idx in tqdm.tqdm( - range(len(dataloader)), desc=prop_desc - ): - inputs = intermediates.fetch(batch_idx) - output = layer(**inputs) + intermediates.delete(index) + intermediates.update(index, output) - if layer_index < num_layers - 1: - next_layer = layers[layer_index + 1] - output = to_next_layer_kwargs(output, next_layer) - output = maybe_inject_pos_embeddings( - output, next_layer, inputs - ) + # trigger layer optimization + LifecycleCallbacks.sequential_epoch_end() - intermediates.delete(batch_idx) - intermediates.update(batch_idx, output) + # this pass does not trigger modifier hooks + # and is only used for capturing outputs of newly compressed modules + if dataset_args.propagate_error: + with 
HooksMixin.disable_hooks(): + for index in tqdm(range(len(dataloader)), desc=prop_desc): + inputs = intermediates.fetch(index) + output = layer(**inputs) + + if layer_index < num_layers - 1: + next_layer = layers[layer_index + 1] + output = to_next_layer_kwargs(output, next_layer) + output = maybe_inject_pos_embeddings( + output, next_layer, inputs + ) + + intermediates.delete(index) + intermediates.update(index, output) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end() diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 8cefeb0cf..0ef8ba5f4 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -67,7 +67,8 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - with calibration_forward_context(model), DisableQuantization(model): + # disable gradients, kv cache, etc. + with calibration_forward_context(model): # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader) @@ -79,22 +80,28 @@ def __call__( # reduce memory movement by keeping modules onloaded with disable_offloading(): # do a preliminary pass to trigger modifier hooks - for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - subgraph.forward(model, **inputs) + with DisableQuantization(model): + for index in tqdm(range(len(dataloader)), desc=calib_desc): + inputs = activations.fetch(index, subgraph.input_names) + output = subgraph.forward(model, **inputs) + + if not dataset_args.propagate_error: + activations.update(index, output) + activations.delete(index, subgraph.consumed_names) + # trigger layer optimization LifecycleCallbacks.sequential_epoch_end() # this pass does not trigger modifier hooks # and is only used for capturing outputs of newly compressed modules - with HooksMixin.disable_hooks(): - for batch_idx in 
tqdm(range(len(dataloader)), desc=prop_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - output = subgraph.forward(model, **inputs) - - if subgraph_index < num_subgraphs - 1: - activations.update(batch_idx, output) - activations.delete(batch_idx, subgraph.consumed_names) + if dataset_args.propagate_error: + with HooksMixin.disable_hooks(): + for index in tqdm(range(len(dataloader)), desc=prop_desc): + inputs = activations.fetch(index, subgraph.input_names) + output = subgraph.forward(model, **inputs) + + activations.update(index, output) + activations.delete(index, subgraph.consumed_names) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end() From b8b182c2cea64c20c1edaca7c3d5562c58e05ea9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 20 Jun 2025 17:56:45 -0400 Subject: [PATCH 2/2] remove nullable dataset args Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/basic/pipeline.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 99db90877..5991540cd 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,5 +1,5 @@ import contextlib -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch import tqdm @@ -25,7 +25,7 @@ class BasicPipeline(CalibrationPipeline): def __call__( model: torch.nn.Module, dataloader: DataLoader, - dataset_args: Union["DatasetArguments", None], + dataset_args: "DatasetArguments", ): """ Run a basic data pipeline. @@ -58,5 +58,7 @@ def __call__( def run_calibration(model: torch.nn.Module, dataloader: DataLoader): + from llmcompressor.args.dataset_arguments import DatasetArguments + pipeline = BasicPipeline() - pipeline(model, dataloader, None) + pipeline(model, dataloader, DatasetArguments())