We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 379cb75 commit 665dac0Copy full SHA for 665dac0
torchao/utils.py
@@ -314,6 +314,8 @@ def unwrap_tensor_subclass(model, filter_fn=None):
314
and type(child.weight) is not torch.nn.Parameter
315
and isinstance(child.weight, torch.Tensor)
316
and issubclass(type(child.weight), torch.Tensor)
317
+ and isinstance(child.weight, TorchAOBaseTensor)
318
+ and not parametrize.is_parametrized(child)
319
):
320
parametrize.register_parametrization(
321
child, "weight", UnwrapTensorSubclass()
0 commit comments