diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index dba8f557085..71e6dee3043 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -149,7 +149,7 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor: If the node contains many fake tensors, return the first one. """ if isinstance( - node.meta["val"], (Sequence, torch.fx.immutable_collections.immutable_list) + node.meta["val"], (list, tuple, torch.fx.immutable_collections.immutable_list) ): fake_tensor = node.meta["val"][0] else: