diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py index 88a0e5be9..1b12c1891 100644 --- a/torch_frame/nn/encoder/stype_encoder.py +++ b/torch_frame/nn/encoder/stype_encoder.py @@ -732,6 +732,7 @@ def encode_forward( # -> [batch_size, out_channels] x_lin = feat.values[:, start_idx:end_idx] @ self.weight_list[idx] x_lins.append(x_lin) + start_idx = end_idx # [batch_size, num_cols, out_channels] x = torch.stack(x_lins, dim=1) # [batch_size, num_cols, out_channels] + [num_cols, out_channels]