@@ -70,8 +70,8 @@ def __init__(
70
70
71
71
def fit (self , X : ArrayLike , y : ArrayLike , verbose = True ,
72
72
cache_linear_coefs : bool = True ,
73
- cache_embs_dir : str = None ,
74
- ):
73
+ cache_embs_dir : str = None ,
74
+ ):
75
75
'''Extract embeddings then fit linear model
76
76
77
77
Parameters
@@ -195,7 +195,7 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
195
195
embs = self .normalizer .transform (embs )
196
196
197
197
# save coefs
198
- coef_embs = self .linear .coef_ .squeeze ()
198
+ coef_embs = self .linear .coef_ .squeeze (). transpose ()
199
199
linear_coef = embs @ coef_embs
200
200
self .coefs_dict_ = {
201
201
** coefs_dict_old ,
@@ -225,16 +225,22 @@ def predict(self, X, warn=True):
225
225
if isinstance (self , RegressorMixin ):
226
226
return preds
227
227
elif isinstance (self , ClassifierMixin ):
228
- return ((preds + self .linear .intercept_ ) > 0 ).astype (int )
228
+ if preds .ndim > 1 : # multiclass classification
229
+ return np .argmax (preds , axis = 1 )
230
+ else :
231
+ return (preds + self .linear .intercept_ > 0 ).astype (int )
229
232
230
233
def predict_proba(self, X, warn=True):
    '''Return class probabilities for X.

    Softmaxes the cached linear predictions. For multiclass models the
    cached predictions already contain one score per class; for binary
    models a two-column score matrix is assembled from the single score.

    Raises
    ------
    Exception
        If called on a non-classifier (regressors have no probabilities).
    '''
    # Probabilities only make sense for the classifier variant.
    if not isinstance(self, ClassifierMixin):
        raise Exception(
            "predict_proba only available for EmbGAMClassifier")
    check_is_fitted(self)
    preds = self._predict_cached(X, warn=warn)
    # preds.ndim > 1 => multiclass: scores are already per-class logits.
    # Otherwise stack (1 - p, p) columns for the binary case.
    logits = preds if preds.ndim > 1 else np.vstack((1 - preds, preds)).transpose()
    return softmax(logits, axis=1)
239
245
240
246
def _predict_cached (self , X , warn ):
0 commit comments