diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 3ae6a5a1..828b9cf8 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -912,13 +912,13 @@ def __init__( # pylint: disable=W0231, W0233 torch.empty(num_groups, in_features, local_multiple * multiple_of, device=device, dtype=dtype) ) self.tp_dim = 2 - assert self.weight.shape[self.tp_dim] != out_features + # assert self.weight.shape[self.tp_dim] != out_features elif split_mode == "row": self.weight = nn.Parameter( torch.empty(num_groups, local_multiple * multiple_of, out_features, device=device, dtype=dtype) ) self.tp_dim = 1 - assert self.weight.shape[self.tp_dim] != in_features + # assert self.weight.shape[self.tp_dim] != in_features elif split_mode == "weight": self.weight = nn.Parameter( torch.empty(local_multiple * multiple_of, out_features, device=device, dtype=dtype)