@@ -3,7 +3,7 @@
 
 import torch
 from torch import nn
-from torch.ao.pruning import BaseSparsifier
+from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
 from torch.ao.quantization import QConfig, default_placeholder_observer
 from torch.ao.quantization.quantize import _remove_qconfig
 
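For context, `get_arg_info_from_tensor_fqn` (the newly imported helper) resolves a tensor FQN such as "fc1.weight" into its owning module and tensor name. A minimal sketch of what it returns, assuming a toy model whose submodule names ("fc1", "fc2") are purely illustrative:

    import torch
    from torch import nn
    from torch.ao.pruning import get_arg_info_from_tensor_fqn

    # Toy model with named submodules (names are hypothetical).
    model = nn.Sequential()
    model.add_module("fc1", nn.Linear(8, 8))
    model.add_module("fc2", nn.Linear(8, 4))

    info = get_arg_info_from_tensor_fqn(model, "fc1.weight")
    # The returned dict splits the FQN into its parts:
    #   info["module_fqn"]  == "fc1"
    #   info["module"]      is model.fc1 (None if the FQN resolves to nothing)
    #   info["tensor_name"] == "weight"
    #   info["tensor_fqn"]  == "fc1.weight"
    assert info["module"] is model.fc1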
@@ -47,9 +47,27 @@ def __init__(
     def prepare(self, model: nn.Module, config: List[Dict]) -> None:
         # activation: use PerChannelNormObserver
         # use no-op placeholder weight observer
-        model.qconfig = QConfig(
-            activation=PerChannelNormObserver, weight=default_placeholder_observer
-        )  # type: ignore[assignment]
+        if config is None:
+            # If no config is provided, apply the qconfig to the entire model
+            model.qconfig = QConfig(
+                activation=PerChannelNormObserver, weight=default_placeholder_observer
+            )  # type: ignore[assignment]
+        else:
+            for module_config in config:
+                tensor_fqn = module_config.get("tensor_fqn", None)
+                if tensor_fqn is None:
+                    raise ValueError("Each config must contain a 'tensor_fqn'.")
+
+                # Extract module information from tensor_fqn
+                info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+                module = info_from_tensor_fqn["module"]
+
+                # Apply the qconfig directly to the module if it exists
+                if module is not None:
+                    module.qconfig = QConfig(
+                        activation=PerChannelNormObserver,
+                        weight=default_placeholder_observer,
+                    )  # type: ignore[assignment]
         torch.ao.quantization.prepare(model, inplace=True)
 
         # call superclass prepare
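The net effect: with `config=None` the whole model is observed as before, while a config list scopes observation to the modules that own the listed tensors. A sketch of the new per-tensor path, mirroring the else-branch above; the model, the config entry, and the use of `MinMaxObserver` as a stand-in for `PerChannelNormObserver` (which is defined in the surrounding file, not shown here) are all assumptions for illustration:

    import torch
    from torch import nn
    from torch.ao.pruning import get_arg_info_from_tensor_fqn
    from torch.ao.quantization import MinMaxObserver, QConfig, default_placeholder_observer

    model = nn.Sequential()
    model.add_module("fc1", nn.Linear(8, 8))
    model.add_module("fc2", nn.Linear(8, 4))

    # Per-tensor config in the format the patched prepare() expects.
    config = [{"tensor_fqn": "fc1.weight"}]

    for module_config in config:
        module = get_arg_info_from_tensor_fqn(model, module_config["tensor_fqn"])["module"]
        if module is not None:
            # MinMaxObserver is only a stand-in for PerChannelNormObserver here.
            module.qconfig = QConfig(
                activation=MinMaxObserver, weight=default_placeholder_observer
            )

    torch.ao.quantization.prepare(model, inplace=True)
    # Only fc1 carried a qconfig, so only fc1 gets an observer attached.
    assert hasattr(model.fc1, "activation_post_process")
    assert not hasattr(model.fc2, "activation_post_process")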