From bdd4b76b5c2f755829b59d472c19320793044448 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Wed, 20 Aug 2025 09:02:12 +0000 Subject: [PATCH] remove linear op assert --- internlm/model/modules/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 3ae6a5a1..828b9cf8 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -912,13 +912,13 @@ def __init__( # pylint: disable=W0231, W0233 torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype) ) self.tp_dim = 2 - assert self.weight.shape[self.tp_dim] != out_features + # assert self.weight.shape[self.tp_dim] != out_features elif split_mode == "row": self.weight = nn.Parameter( torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype) ) self.tp_dim = 1 - assert self.weight.shape[self.tp_dim] != in_features + # assert self.weight.shape[self.tp_dim] != in_features elif split_mode == "weight": self.weight = nn.Parameter( torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype)