3 files changed: +3, -34 lines.

File 1 of 3: consolidate the manual hook removal and dispatch into a single `offloaded_dispatch` call.
```diff
 from typing import Optional
 
 import torch
-from accelerate.hooks import remove_hook_from_module
 from compressed_tensors.utils import offloaded_dispatch
 from loguru import logger
 from torch.utils.data import DataLoader
@@ -128,8 +127,9 @@ def __init__(
 
         # offload to cpu if possible
         if "cuda" in str(model_args.oneshot_device) and torch.cuda.is_available():
-            remove_hook_from_module(model_args.model, recurse=True)
-            offloaded_dispatch(model_args.model, model_args.oneshot_device)
+            offloaded_dispatch(
+                model_args.model, execution_device=model_args.oneshot_device
+            )
         else:
             logger.warning("CUDA is not available! Compressing model on CPU instead")
 
```
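Taken on its own, the new call site looks like this. A minimal sketch, assuming a `compressed_tensors` release whose `offloaded_dispatch` accepts the target device via the `execution_device` keyword and handles any previously attached hooks itself; the toy `Linear` stands in for `model_args.model`:

```python
import torch
from compressed_tensors.utils import offloaded_dispatch

model = torch.nn.Linear(8, 8)            # stand-in for model_args.model
oneshot_device = torch.device("cuda:0")  # stand-in for model_args.oneshot_device

if "cuda" in str(oneshot_device) and torch.cuda.is_available():
    # Parameters stay offloaded (e.g. on CPU) and are moved to the
    # execution device only while each submodule actually runs.
    offloaded_dispatch(model, execution_device=oneshot_device)
```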
File 2 of 3: drop the hook cleanup from `post_process`.

```diff
 from pathlib import PosixPath
 from typing import Optional, Tuple
 
-from accelerate.hooks import remove_hook_from_module
 from loguru import logger
 from torch.nn import Module
 from transformers import (
@@ -106,9 +105,6 @@ def post_process(
             "Ex. `oneshot(..., output_dir=...)`"
         )
 
-    # Remove any existing hooks (maybe added by oneshot sequential onloading)
-    remove_hook_from_module(model_args.model, recurse=True)
-
     # Reset the one-time-use session upon completion
     if recipe_args is not None and recipe_args.clear_sparse_session:
         reset_session()
```
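For context, the deleted call is accelerate's hook-removal API. A minimal sketch of what it did, with a toy module standing in for `model_args.model`, assuming the hooks were attached by an earlier dispatch:

```python
import torch
from accelerate.hooks import remove_hook_from_module

model = torch.nn.Linear(8, 8)  # stand-in for model_args.model

# Detach any accelerate hooks (e.g. AlignDevicesHook) from the module and,
# with recurse=True, from all of its submodules.
remove_hook_from_module(model, recurse=True)
```

With `offloaded_dispatch` presumably managing hook cleanup on re-dispatch, both manual call sites become redundant.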
File 3 of 3: delete the `quantization_memory_requirement` helper.

```diff
@@ -104,33 +104,6 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]
     return None
 
 
-def quantization_memory_requirement(model: torch.nn.Module) -> int:
-    """
-    Determines the max number of bytes needed to store quantization scale and zp data
-
-    :param model: model to calculate requirements for
-    :return: number of bytes required to reserve for quantization
-    """
-
-    total_elements = 0
-    for _, module in model.named_modules():
-        if isinstance(module, Linear):
-            for param in module.parameters():
-                # assume the max of group 128 and static scale/zp
-                # TODO: base this on the recipe instead of assuming max
-
-                # potentially just bias term
-                max_quant_shape = param.shape[0] // 128
-
-                if len(param.size()) > 1:  # weights
-                    max_quant_shape *= param.shape[1]
-
-                total_elements += max_quant_shape * 4
-
-    bytes_ratio = 32 // 16  # assuming float16
-    return total_elements * bytes_ratio
-
-
 def infer_sparse_targets_and_ignores(
     model: torch.nn.Module,
     sparsity_structure: str,
```
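For the record, this is the arithmetic the deleted helper performed, worked through for a single hypothetical 4096x4096 `Linear` weight (the layer size is illustrative, not from the source):

```python
out_features, in_features = 4096, 4096  # hypothetical Linear weight shape

# One scale/zp element per group of 128 along dim 0, for every column,
# mirroring param.shape[0] // 128 * param.shape[1] in the removed helper.
max_quant_shape = (out_features // 128) * in_features   # 131,072 elements

total_elements = max_quant_shape * 4          # the helper's fixed 4x multiplier
bytes_required = total_elements * (32 // 16)  # assumed float16: 2 bytes each

print(bytes_required)  # 1048576 bytes, i.e. ~1 MiB reserved for this layer
```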