Commit 85f40b5

bugfixes (#375)
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 3c55003 commit 85f40b5

File tree

2 files changed (+3 additions, -9 deletions)

  • src/compressed_tensors

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 5 deletions
@@ -152,11 +152,7 @@ def apply_quantization_config(
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_children=True,
-        include_attn=True,
-    ):  # child modules and attention modules
+    for name, submodule in model.named_modules():  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
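
For context, the replacement relies on torch.nn.Module.named_modules(), which already walks every submodule in the module tree, attention modules included, so the dedicated iterator is redundant here. A minimal sketch of the new iteration behavior (the toy model below is illustrative, not from this repository):

    import torch.nn as nn

    # Toy model: a child Linear layer plus an attention module.
    model = nn.Sequential(nn.Linear(4, 4), nn.MultiheadAttention(4, 2))

    # named_modules() yields the root and every nested submodule,
    # so child modules and attention modules are both visited.
    for name, submodule in model.named_modules():
        print(name, type(submodule).__name__)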

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 4 deletions
@@ -99,10 +99,10 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location.value}"
         transform = self.create_transform(module, args)
-        register_offload_module(module, transform_name, transform)  # (1)

         # register input transformation hook
         if args.location == TransformLocation.INPUT:
+            register_offload_module(module, transform_name, transform)

             def input_hook(_, args):
                 input = args[0]
@@ -130,6 +130,7 @@ def input_hook(_, args):

         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
+            register_offload_module(module, transform_name, transform)

             def output_hook(_, _input, output):
                 return transform(output)
@@ -140,9 +141,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()

-        # (1) even in the `weight` cases, this submodule attachment is needed in order
-        # to support saving in the frozen state
-

 class TransformBase(Module, ABC):
     """
