diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 6ec0c192..cb57c237 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -47,6 +47,9 @@
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    get_execution_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -98,6 +101,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """

+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
@@ -261,6 +267,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )
+
+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -270,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:

         This function determines which weight keys are missing based on the applied
         compression techniques.
-
         :param model: The PyTorch model to check for missing keys.
         :return: A list of missing keys expected in the compressed state_dict.
         """
@@ -362,8 +369,124 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

         return list(unexpected_keys)

+    # ----- model memory compression/decompression pathways ----- #
+
+    def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in place,
+        this method is more memory-efficient than `self.compress`.
+
+        :param model: model containing parameters to compress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support compression on same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                    )
+
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with compressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.COMPRESSED
+
+    def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in place,
+        this method does not require loading compression parameters from disk.
+
+        :param model: model containing parameters to decompress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support decompression on same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    # generates (mod_path, {param_name: param_val})
+                    # of compressed params and used params, but not unused params
+                    # some used params are removed by get_unexpected_file_keys
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delete_offload_parameter(module, name)
+
+                # replace with decompressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
+
+    # ----- state dict compression pathways ----- #
+
     def compress(
-        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        self,
+        model: Module,
+        state_dict: Optional[Dict[str, Tensor]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict or model with sparsity and/or quantization
@@ -379,7 +502,9 @@ def compress(
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+                state_dict,
+                names_to_scheme=module_to_scheme,
+                show_progress=show_progress,
             )

         # TODO: consider sparse compression to also be compression
@@ -397,6 +522,7 @@ def compress(
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
+                show_progress=show_progress,
             )

         # HACK: Override the dtype_byte_size function in transformers to
@@ -406,6 +532,8 @@

         return state_dict

+    # ----- disk decompression pathways ----- #
+
     def decompress(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py
index e768e30c..b426cb97 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -23,7 +23,6 @@
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -71,6 +70,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
+        show_progress: bool = False,
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -79,18 +79,21 @@
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization args for each quantized weight, needed
             for quantize function to calculate bit depth
+        :param show_progress: whether to show tqdm progress
         :return: compressed state dict
         """
+        uncompressed_names = list(model_state.keys())
         compressed_dict = {}
         save_device = "cpu"

-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+        # compress values
+        desc = "Compressing with quantization"
+        for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]

             # compress weights
             if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
+                prefix = name.removesuffix("weight")

                 # gather qparams
                 scale = model_state.get(prefix + "weight_scale", None)
@@ -182,7 +185,7 @@
             )

         else:
-            yield from self._decompress_from_state_dict(
+            yield from self.decompress_from_state_dict(
                 path_to_model_or_tensors, names_to_scheme
             )
@@ -209,7 +212,11 @@ def _decompress_from_path(
                 weight_data["weight"] = decompressed
                 yield module_path, weight_data

-    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, torch.Tensor],
+        names_to_scheme: Dict[str, QuantizationScheme],
+    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
@@ -219,7 +226,7 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
                 weight_data[param_name] = param_value

             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[module_path]
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
diff --git a/src/compressed_tensors/compressors/sparse_compressors/base.py b/src/compressed_tensors/compressors/sparse_compressors/base.py
index 33075819..e29b8284 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -16,7 +16,11 @@
 from typing import Dict, Generator, Optional, Set, Tuple

 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -63,6 +67,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         compression_targets: Optional[Set[str]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
@@ -76,7 +81,11 @@ def compress(
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(
+            model_state.items(),
+            desc="Compressing with sparsity",
+            disable=(not show_progress),
+        ):
             if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue
@@ -124,15 +133,15 @@
             self.compression_param_names,
             return_unmatched_params=True,
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)

             decompressed = self.decompress_weight(weight_data)
-            yield merge_names(weight_name, "weight"), decompressed
+            yield merge_names(module_path, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
             should_skip = False
@@ -146,6 +155,35 @@
                 value = f.get_tensor(ignored_param_name)
                 yield ignored_param_name, value

+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
+        """
+        Decompress the state dict of a module (or model)
+
+        Unlike `self.decompress`, this function does not need to explicitly skip params
+        via params_to_skip_load because it is more convenient for its only caller
+        (ModelCompressor.decompress_model) to retrieve all unused param keys
+
+        :param state_dict: state dict containing parameters to decompress
+        :return: Generator of (param_path, param_val)
+        """
+        weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names, return_unmatched_params=True
+        )
+
+        for module_path in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[module_path].items():
+                weight_data[param_name] = param_value
+
+            decompressed = self.decompress_weight(weight_data)
+            yield merge_names(module_path, "weight"), decompressed
+
+        for ignored_param_path, ignored_param_value in ignored_params.items():
+            yield ignored_param_path, ignored_param_value
+
     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
         """
diff --git a/src/compressed_tensors/compressors/sparse_compressors/dense.py b/src/compressed_tensors/compressors/sparse_compressors/dense.py
index 2550d616..0ec2b5f6 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/dense.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/dense.py
@@ -40,3 +40,10 @@ def decompress(
         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
+
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
+        for key, value in state_dict.items():
+            yield key, value
diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
index b21fb7fd..7a97faa3 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -202,11 +202,7 @@ def sparse24_bitmask_decompress(
     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
     decompressed_tensor = decompressed_tensor.to(values.device)
     values = values.flatten()
-    if decompressed_tensor.dtype == FP8_DTYPE:
-        decompressed_tensor[bytemasks_unpacked] = values
-        decompressed_tensor = decompressed_tensor.cuda()
-    else:
-        decompressed_tensor[bytemasks_unpacked] = values
+    decompressed_tensor[bytemasks_unpacked] = values
     return decompressed_tensor
diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
index a8487cb7..7b8fea02 100644
--- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
+++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
@@ -125,6 +125,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
+        show_progress: bool = False,
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -134,6 +135,7 @@
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization scheme for each quantized weight, needed
             for quantize function to calculate bit depth
+        :param show_progress: whether to show tqdm progress
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)
@@ -144,7 +146,9 @@
             f"Compressing model with {len(model_state)} parameterized layers..."
         )

-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(
+            model_state.items(), desc="Compressing model", disable=(not show_progress)
+        ):
             if name.endswith(weight_suffix):
                 prefix = name[: -(len(weight_suffix))]
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index 65f992b6..d24df2fc 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -23,6 +23,7 @@
     initialize_module_for_quantization,
 )
 from compressed_tensors.utils import register_offload_parameter
+from compressed_tensors.utils.offload import get_execution_device
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear
@@ -60,7 +61,7 @@ def from_linear(
         """
         module.__class__ = CompressedLinear
         module.compressor = BaseCompressor.load_from_registry(quantization_format)
-        device = next(module.parameters()).device
+        init_device = get_execution_device(module)

         # this will initialize all the scales and zero points
         initialize_module_for_quantization(
@@ -79,7 +80,7 @@
         # populate compressed weights and quantization parameters
         for name, (shape, dtype) in compression_params.items():
             param = Parameter(
-                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+                torch.empty(shape, device=init_device, dtype=dtype), requires_grad=False
             )
             register_offload_parameter(module, name, param)
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index af0b9159..a842d00e 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -38,7 +38,6 @@
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
-    "remove_suffix",
 ]

 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -329,9 +328,3 @@ def unpack_bitmasks(
     )

     return unpacked_bitmasks_torch
-
-
-def remove_suffix(value: str, suffix: str) -> str:
-    # can replace with str.removesuffix in python3.9+
-    assert value.endswith(suffix)
-    return value[: -len(suffix)]
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index f66680e2..13a8eb48 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -35,6 +35,7 @@
     "is_quantization_param",
 ]

+NestedStateDictType = Dict[str, Dict[str, Tensor]]
 WeightMappingType = Dict[str, str]
 NestedWeightMappingType = Dict[str, WeightMappingType]

@@ -249,8 +250,10 @@ def get_nested_weight_mappings(


 def get_nested_mappings_from_state_dict(
-    state_dict, params_to_nest: Iterable[str]
-) -> NestedWeightMappingType:
+    state_dict: Dict[str, Tensor],
+    params_to_nest: Iterable[str],
+    return_unmatched_params: bool = False,
+) -> Union[NestedStateDictType, Tuple[NestedStateDictType, Dict[str, Tensor]]]:
     """
     Takes a state dict and returns a nested mapping from uncompressed
     parameterized layer names to the value of
@@ -266,16 +269,26 @@ def get_nested_mappings_from_state_dict(
     :param state_dict: state dict of the model
     :param params_to_nest: Iterable of parameter names to nest.
     :return: Nested mapping of parameterized layer names to the value of
-        each layer's compression parameters.
+        each layer's compression parameters. If `return_unmatched_params`, then
+        also return a dictionary mapping unused parameter names to their values
     """
     nested_weight_mappings = {}
+    unmatched_params = {}
+
     for key in state_dict.keys():
+        matched = False
         for param_name in params_to_nest:
             module_path = match_param_name(key, param_name)
             if module_path:
                 if module_path not in nested_weight_mappings:
                     nested_weight_mappings[module_path] = {}
                 nested_weight_mappings[module_path][param_name] = state_dict[key]
+                matched = True
+        if return_unmatched_params and not matched:
+            unmatched_params[key] = state_dict[key]
+
+    if return_unmatched_params:
+        return nested_weight_mappings, unmatched_params

     return nested_weight_mappings
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index e0049f27..b1d040f4 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -24,6 +24,7 @@
 from compressed_tensors.quantization import QuantizationConfig
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
+from transformers import AutoModelForCausalLM


 def sparsity_config():
@@ -365,3 +366,115 @@ def _get_combined_config(s_config, q_config):
     combined["sparsity_config"] = s_config

     return combined
+
+
+@pytest.mark.parametrize(
+    "model_stub,q_format,s_config",
+    [
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "float-quantized",
+            None,
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            None,
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "float-quantized",
+            "sparse-24-bitmask",
+        ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "pack-quantized",
+            None,
+        ),
+    ],
+)
+def test_compress_model(model_stub, q_format, s_config, tmpdir):
+    model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
+
+    # compress model by eagerly compressing state dict
+    true_compressed = dict(compressor.compress(model))
+    true_compressed = {key: value.clone() for key, value in true_compressed.items()}
+
+    # compress model directly
+    compressor.compress_model(model)
+    compressed = dict(model.state_dict())
+
+    # equivalent to eagerly compressing state dict
+    assert compressed.keys() == true_compressed.keys()
+    for key in compressed.keys():
+        assert compressed[key].dtype == true_compressed[key].dtype
+        assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
+
+
+@pytest.mark.parametrize(
+    "model_stub,comp_stub",
+    [
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
+        ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-compressed",
+        ),
+    ],
+)
+def test_decompress_model(model_stub, comp_stub):
+    from transformers.utils.quantization_config import CompressedTensorsConfig
+
+    # decompress from disk
+    # NOTE: transformers adds extra zero points if run_compressed=False or w/ sparsity
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L131-L133
+    # however, decompression does not add zero points in non-asymmetric cases
+    # in order to normalize for this effect in this test, we remove empty weight zps
+    true_decompressed_model = AutoModelForCausalLM.from_pretrained(
+        comp_stub,
+        quantization_config=CompressedTensorsConfig(run_compressed=False),
+        torch_dtype=torch.float32,
+    )
+    true_decompressed = dict(true_decompressed_model.state_dict())
+    true_decompressed = remove_empty_weight_zero_points(true_decompressed)  # see above
+
+    # decompress from memory
+    # NOTE there is no other way to load a compressed model into memory, since
+    # there is no way to turn off decompression for sparse models
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L133
+    model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
+    compressor = ModelCompressor.from_pretrained(comp_stub)
+    compressor.compress_model(model)
+    compressor.decompress_model(model)
+    decompressed = dict(model.state_dict())
+
+    # remove keys not in model definition
+    # NOTE it would be better if compressors only returned keys to keep, rather than
+    # relying on the model structure + missing keys to catch and remove them later
+    model_keys = true_decompressed_model.state_dict().keys()
+    decompressed = {key: val for key, val in decompressed.items() if key in model_keys}
+
+    # equivalent to decompressing from disk
+    assert decompressed.keys() == true_decompressed.keys()
+    for key in decompressed.keys():
+        assert decompressed[key].dtype == true_decompressed[key].dtype
+        assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
+
+
+def remove_empty_weight_zero_points(state_dict):
+    return {
+        name: value
+        for name, value in state_dict.items()
+        if not (name.endswith("weight_zero_point") and torch.all(value == 0))
+    }
diff --git a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py
index 0e28f004..1e95c5b4 100644
--- a/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py
+++ b/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py
@@ -47,6 +47,8 @@ def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes):

 def validate_compression(dense_matrix, decompressed_tensor):
     """Validate that the decompressed tensor matches the original dense matrix."""
+    if decompressed_tensor.dtype == FP8_DTYPE:
+        decompressed_tensor = decompressed_tensor.to("cuda")
     dense_matrix = dense_matrix.to(decompressed_tensor.device)
     assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch"
     assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch"
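
Usage note (not part of the patch): a minimal sketch of the new in-memory pathways added above. The checkpoint stub is a placeholder, and the call pattern mirrors test_compress_model / test_decompress_model in this diff, which pass the sparsity config and quantization format positionally to ModelCompressor.from_pretrained_model.

# Illustrative only -- hypothetical model stub, formats taken from the tests above.
import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

model = AutoModelForCausalLM.from_pretrained("your-org/your-model", torch_dtype=torch.float32)

# build a compressor for the desired sparsity/quantization formats
compressor = ModelCompressor.from_pretrained_model(model, "sparse-24-bitmask", "float-quantized")

# in-memory round trip: module parameters are replaced in place
compressor.compress_model(model)    # modules now hold compressed parameters (status COMPRESSED)
compressor.decompress_model(model)  # modules restored to dense parameters (status FROZEN)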
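
Likewise, a small sketch of the new return_unmatched_params flag on get_nested_mappings_from_state_dict, using toy tensors. The exact key grouping depends on match_param_name, so the expected output shown in the comments is an assumption.

# Illustrative only -- toy state dict with hypothetical module names.
import torch
from compressed_tensors.utils.safetensors_load import get_nested_mappings_from_state_dict

state_dict = {
    "model.layers.0.mlp.weight_packed": torch.zeros(8, 4, dtype=torch.int32),
    "model.layers.0.mlp.weight_scale": torch.ones(8, 1),
    "model.layers.0.mlp.bias": torch.zeros(8),
}

nested, unmatched = get_nested_mappings_from_state_dict(
    state_dict,
    params_to_nest=["weight_packed", "weight_scale"],
    return_unmatched_params=True,
)
# expected (assuming "<module>.<param>" matching):
#   nested    == {"model.layers.0.mlp": {"weight_packed": ..., "weight_scale": ...}}
#   unmatched == {"model.layers.0.mlp.bias": ...}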