
[Performance] Add memory compression and decompression pathways #301

Merged
kylesayrs merged 4 commits into main from kylesayrs/reduce-quantized-compression-memory on May 14, 2025

Conversation

kylesayrs
Contributor

@kylesayrs kylesayrs commented Apr 16, 2025

Purpose

  • Reduce memory requirements when compressing a model

Memory Visualization

[Memory timeline charts omitted. "Compression Memory Improvement" compares state dict compression against model (in-memory) compression for the Quantized, Sparse, Sparse24, and Stacked formats; "Model Compression and Decompression" shows the combined compression + decompression memory timelines for the same formats.]
Demonstration Script
import torch
from pttp import TensorProfiler
from transformers import AutoModelForCausalLM, AutoConfig
from compressed_tensors.compressors import ModelCompressor

# (model_stub, comp_stub) pairs: an uncompressed checkpoint and the matching
# compressed checkpoint whose config describes how to (de)compress it
for name, model_stub, comp_stub in [
    (
        "quantized-only",
        "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
        "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
    ),
    (
        "sparse-only",
        "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
        "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
    ),
    (
        "stacked",
        "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
        "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
    ),
]:
    with TensorProfiler() as prof:
        prof.mark_event("Start load")
        config = AutoConfig.from_pretrained(model_stub)
        config.tie_word_embeddings = False
        model = AutoModelForCausalLM.from_pretrained(
            model_stub, torch_dtype=torch.float32, device_map="cuda:0", config=config
        )
        compressor = ModelCompressor.from_pretrained(comp_stub)

        # compress and decompress the model in memory while tracking tensor allocations
        prof.mark_event("Start compress")
        compressor.compress_model(model)

        prof.mark_event("Start decompress")
        compressor.decompress_model(model)

    prof.save_memory_timeline(f"cdc_{name}.png")

Prerequisites

Changes

  • Implement compress_model and decompress_model, which both act on a model in memory rather than on a state dict or a model on disk (see the sketch after this list)
    • Because compress_model compresses each module independently, add a show_progress argument to the compress methods to suppress per-module tqdm output
    • Implement decompress_from_state_dict for sparsity compressors
    • Extend get_nested_mappings_from_state_dict to support returning unmatched params, similar to get_nested_weight_mappings
      • I personally dislike this usage, since it creates multiple sources of truth about which modules should be compressed. IMO, a module should be (de)compressed if and only if it is listed in the config. This function adds a second source of truth, namely that a module should be compressed if and only if it has the relevant compression params. Currently we use both, which means taking their intersection.
  • Misc
    • Fix a bug in decompress_from_state_dict where the scheme was retrieved instead of the weight args
    • Rename a variable: weight_name actually referred to a module path, not a weight name
    • Change the sparse24 decompressor behavior where, in fp8 cases, the weight was moved to an arbitrary CUDA device, violating the assumption that all ops are performed on the CPU
    • Remove the remove_suffix util, which can be replaced with str.removesuffix as of Python 3.9+ (the minimum version we support; double-check with @dsikka @rahul-tuli)
    • Use get_execution_device when initializing params for CompressedLinear
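
For orientation, here is a minimal sketch of the in-memory compression idea described above: walk the model module by module and replace each weight with its compressed representation in place, so only one module's tensors are duplicated at any time. This is illustrative only; the compress_weight helper and the targets argument are assumptions made for the sketch, not the library's actual API.

import torch

def compress_model_sketch(model: torch.nn.Module, compressor, targets: set):
    # illustrative only: compress each targeted module's weight in place
    for name, module in model.named_modules():
        # only (de)compress modules listed in the compression config
        if name not in targets or not hasattr(module, "weight"):
            continue
        # hypothetical helper returning e.g. {"weight_packed": ..., "weight_scale": ...}
        compressed = compressor.compress_weight(name, module.weight)
        # free the uncompressed weight before registering the compressed tensors,
        # so peak memory grows by at most one module's worth of tensors
        del module.weight
        for param_name, tensor in compressed.items():
            module.register_parameter(
                param_name, torch.nn.Parameter(tensor, requires_grad=False)
            )

decompress_model follows the same pattern in reverse, reconstructing an uncompressed weight for each module from its compressed tensors.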

Testing

  • Added test_compress_model, which tests that in-memory compression is equivalent to state dict compression (sketched below)
  • Added test_decompress_model, which tests that HFQuantizer decompression (from disk) is equivalent to decompression from memory
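
The equivalence check in test_compress_model boils down to comparing the outputs of the two pathways. The following is a hedged sketch of that idea (not the actual test code), assuming ModelCompressor.compress(model) is the pre-existing state dict pathway and returns the compressed state dict.

import copy
import torch

def check_compression_equivalence(model, compressor):
    # state dict pathway: produces a compressed state dict from a copy of the model
    dict_compressed = compressor.compress(copy.deepcopy(model))

    # in-memory pathway: compresses the model's modules in place
    compressor.compress_model(model)
    memory_compressed = model.state_dict()

    # both pathways should produce the same tensor names and values
    assert dict_compressed.keys() == memory_compressed.keys()
    for key in dict_compressed:
        assert torch.equal(dict_compressed[key], memory_compressed[key]), key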

@kylesayrs kylesayrs changed the title from "[WIP]: Simplify map_module_to_scheme" to "[WIP]: Reduce memory requirements" Apr 21, 2025
@kylesayrs kylesayrs changed the base branch from main to kylesayrs/map_module_to_scheme April 22, 2025 20:22
@kylesayrs kylesayrs changed the title from "[WIP]: Reduce memory requirements" to "[Performance] Reduce compression memory requirements via structure change" Apr 22, 2025
Collaborator

@dsikka dsikka left a comment

Is this ready for review? Still draft

@kylesayrs kylesayrs marked this pull request as ready for review April 23, 2025 19:32
@rahul-tuli
Member

Could we add a test to compress a model with sparsity+quantization?

Base automatically changed from kylesayrs/map_module_to_scheme to main April 28, 2025 15:16
@kylesayrs kylesayrs requested a review from dsikka April 28, 2025 16:00
@rahul-tuli
Member

LGTM pending conflict, good work!

Member

@rahul-tuli rahul-tuli left a comment

Looks good pending verification that sparse only models can be compressed using these changes!

Collaborator

@dsikka dsikka left a comment

Why do we need to use CompressedLinear for compression? What if we're compressing something that isn't a linear layer?

@kylesayrs kylesayrs marked this pull request as draft May 5, 2025 15:20
@kylesayrs kylesayrs changed the title from "[Performance] Reduce compression memory requirements via structure change" to "[Performance] Add memory compression and decompression pathways" May 7, 2025
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@kylesayrs kylesayrs force-pushed the kylesayrs/reduce-quantized-compression-memory branch from 0e9544d to b2cad7e May 8, 2025 01:38
@kylesayrs kylesayrs marked this pull request as ready for review May 8, 2025 02:05
@dsikka
Collaborator

dsikka commented May 8, 2025

Looks good pending verification that sparse only models can be compressed using these changes!

Can you share sparse + fp8 model recipes where we have non-uniform sparsity and/or quantization cases?
@rahul-tuli

cc @kylesayrs

@kylesayrs kylesayrs marked this pull request as draft May 8, 2025 14:06
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@kylesayrs kylesayrs marked this pull request as ready for review May 8, 2025 14:16
Contributor

@brian-dellabetta brian-dellabetta left a comment

It's beautiful, Kyle 🥇. Love the detailed summary and charts showing the improvement.

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@kylesayrs kylesayrs requested review from rahul-tuli and dsikka May 12, 2025 17:06
Contributor

@brian-dellabetta brian-dellabetta left a comment

LGTM!

Member

@rahul-tuli rahul-tuli left a comment

LGTM!

@kylesayrs kylesayrs merged commit f192f68 into main May 14, 2025
1 check passed
@kylesayrs kylesayrs deleted the kylesayrs/reduce-quantized-compression-memory branch May 14, 2025 13:45