Skip to content

Commit 13f6142

Browse files
committed
add embgam multiclass support
1 parent 7b02814 commit 13f6142

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

imodelsx/embgam/embgam.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def __init__(
7070

7171
def fit(self, X: ArrayLike, y: ArrayLike, verbose=True,
7272
cache_linear_coefs: bool = True,
73-
cache_embs_dir: str=None,
74-
):
73+
cache_embs_dir: str = None,
74+
):
7575
'''Extract embeddings then fit linear model
7676
7777
Parameters
@@ -195,7 +195,7 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
195195
embs = self.normalizer.transform(embs)
196196

197197
# save coefs
198-
coef_embs = self.linear.coef_.squeeze()
198+
coef_embs = self.linear.coef_.squeeze().transpose()
199199
linear_coef = embs @ coef_embs
200200
self.coefs_dict_ = {
201201
**coefs_dict_old,
@@ -225,16 +225,22 @@ def predict(self, X, warn=True):
225225
if isinstance(self, RegressorMixin):
226226
return preds
227227
elif isinstance(self, ClassifierMixin):
228-
return ((preds + self.linear.intercept_) > 0).astype(int)
228+
if preds.ndim > 1: # multiclass classification
229+
return np.argmax(preds, axis=1)
230+
else:
231+
return (preds + self.linear.intercept_ > 0).astype(int)
229232

230233
def predict_proba(self, X, warn=True):
231234
if not isinstance(self, ClassifierMixin):
232235
raise Exception(
233236
"predict_proba only available for EmbGAMClassifier")
234237
check_is_fitted(self)
235238
preds = self._predict_cached(X, warn=warn)
236-
logits = np.vstack(
237-
(1 - preds, preds)).transpose()
239+
if preds.ndim > 1: # multiclass classification
240+
logits = preds
241+
else:
242+
logits = np.vstack(
243+
(1 - preds, preds)).transpose()
238244
return softmax(logits, axis=1)
239245

240246
def _predict_cached(self, X, warn):

0 commit comments

Comments
 (0)