
Commit b8d8ac0

fix NotImplementedError: get_type is not implemented (#2133)
Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent 51a143e commit b8d8ac0

File tree: 2 files changed, 29 additions and 5 deletions (+29 / -5 lines changed)

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 22 additions & 0 deletions
@@ -591,6 +591,13 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
         setattr(self, "w2_weight", None)
         self.forward = self.forward_orig
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 # This patched module is called by the vllm-mixtral FusedMoE layer
 # we wrap each expert weight with this module since FusedMoE has a single tensor for all experts weights
@@ -853,6 +860,13 @@ def update_measure(self, prev, cur, dim, idx, inp_seq_len):
         measure_output((output,), self._mod_extra_config.outputs)
         return output
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
 
 class PatchedVLLMKVCache(PatchedModuleBase):
     # Module to patch VLLMKVCache module from llama model
@@ -891,6 +905,14 @@ def fetch_from_cache(self, cache, blocks, permutations=None):
         output_cache = self.orig_fetch_from_cache(quant_cache, blocks)
         return self.dequant_output(output_cache)
 
+    def extra_repr(self) -> str:
+        return extra_representation(
+            self.extra_repr_org(),
+            self.class_name_org,
+            get_current_repr(self),
+        )
+
+
 def init_conv(instance, mod_extra_config):
     if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]:
         instance.quant_input = instance._mod_extra_config.inputs[0]
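
Each of the three patched modules above gains the same extra_repr override. Below is a minimal, self-contained sketch of how the pieces are assumed to fit together: extra_representation and get_current_repr are real helpers in helper_modules.py, but the simplified bodies here are illustrative assumptions, not the library's implementation; only the call shape is taken from the diff.

# Simplified stand-ins for the helpers used by the patched modules above.
# The real implementations live in helper_modules.py; these bodies are assumptions.
def get_current_repr(module):
    # Assumed: summarize the module's current quantization state.
    return f"quantization_mode={getattr(module, 'quantization_mode', None)}"

def extra_representation(org_repr, org_class_name, current_repr):
    # Assumed: combine the original repr, original class name, and current state.
    parts = [p for p in (org_repr, f"original={org_class_name}", current_repr) if p]
    return ", ".join(parts)

class PatchedExample:
    class_name_org = "Linear"        # hypothetical original class name
    quantization_mode = "QUANTIZE"   # hypothetical mode value

    def extra_repr_org(self):
        return "in_features=16, out_features=16"

    # Same override pattern as added to the patched modules in this commit.
    def extra_repr(self) -> str:
        return extra_representation(
            self.extra_repr_org(),
            self.class_name_org,
            get_current_repr(self),
        )

print(PatchedExample().extra_repr())
# in_features=16, out_features=16, original=Linear, quantization_mode=QUANTIZE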

neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py

Lines changed: 7 additions & 5 deletions
@@ -175,32 +175,34 @@ def forward_quant(self, *args, **kwargs):
 
     @classmethod
     def get_module_info(cls) -> ModuleInfo:
-        """Return the module info for the module.
+        """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
+        Return the module info for the module, which is used to determine the scaling methods for the module.
 
         For example, for linear module, the module info is: ModuleInfo(type="linear", patched_module=cls).
         """
         return ModuleInfo(type=cls.get_type(), patched_module=cls)
 
     @classmethod
-    @abstractmethod
     def get_type(cls) -> str:
-        """Return the type of the patched module.
+        """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
+        Return the type of the patched module, which is used to determine the scaling methods for the module.
 
         Multiple patched modules can have the same type, and share the same scaling methods.
         """
         raise NotImplementedError("`get_type` is not implemented")
 
     @classmethod
-    @abstractmethod
     def get_module_type(cls) -> ModuleType:
-        """Return the module type for the module.
+        """Only necessary for the newly registered patched module that doesn't in _mod_default_dict.
+        Return the module type for the module, which is used to determine the number of inputs, outputs, and parameters of the module.
 
         The module type is used to determine the number of inputs, outputs, and parameters of the module.
         For example, for linear module, the module type is: ModuleType(1, ["weight"], 1, False).
         """
         raise NotImplementedError("`get_module_type` is not implemented")
 
     def extra_repr(self):
+        """This extra_repr is only for the newly registered patched module that doesn't in _mod_default_dict."""
         return f"quantization_mode={self.quantization_mode}, " + \
                f"module_info={self.get_module_info()}, " + \
                f"module_type={self.get_module_type()}"
