Skip to content

Commit 6494abb

Browse files
committed
output category values instead of their indices
Fixes #6
1 parent 6d9b7ef commit 6494abb

File tree

4 files changed

+42
-14
lines changed

4 files changed

+42
-14
lines changed

ebm2onnx/convert.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,29 @@ def to_onnx(model, dtype, name="ebm",
168168
# compute scores, predict and proba
169169
g = graph.merge(*parts)
170170
if type(model) is ExplainableBoostingClassifier:
171+
class_type = onnx.TensorProto.STRING if model.classes_.dtype.type is np.str_ else onnx.TensorProto.INT64
172+
classes=model.classes_
173+
if class_type == onnx.TensorProto.STRING:
174+
classes=[ c.encode("utf-8") for c in classes]
175+
171176
g, scores_output_name = ebm.compute_class_score(model.intercept_, explain_name)(g)
172177
g_scores = graph.strip_to_transients(g)
173178
if len(model.classes_) == 2: # binary classification
174-
g = ebm.predict_class(binary=True, prediction_name=prediction_name)(g)
175-
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64, [None])
179+
g = ebm.predict_class(
180+
classes=classes, class_type=class_type,
181+
binary=True, prediction_name=prediction_name
182+
)(g)
183+
g = graph.add_output(g, g.transients[0].name, class_type, [None])
176184
if predict_proba is True:
177185
gp = ebm.predict_proba(binary=True, probabilities_name=probabilities_name)(g_scores)
178186
g = graph.merge(graph.clear_transients(g), gp)
179187
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.FLOAT, [None, len(model.classes_)])
180188
else:
181-
g = ebm.predict_class(binary=False, prediction_name=prediction_name)(g)
182-
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64, [None])
189+
g = ebm.predict_class(
190+
classes=classes, class_type=class_type,
191+
binary=False, prediction_name=prediction_name
192+
)(g)
193+
g = graph.add_output(g, g.transients[0].name, class_type, [None])
183194
if predict_proba is True:
184195
gp = ebm.predict_proba(binary=False, probabilities_name=probabilities_name)(g_scores)
185196
g = graph.merge(graph.clear_transients(g), gp)

ebm2onnx/ebm.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def _compute_class_score(g):
115115
return _compute_class_score
116116

117117

118-
def predict_class(binary, prediction_name):
118+
def predict_class(classes, class_type, binary, prediction_name):
119+
_classes = classes
119120
def _predict_class(g):
120121
if binary is True:
121122
init_zeros = graph.create_initializer(
@@ -130,7 +131,13 @@ def _predict_class(g):
130131
[1], [0],
131132
)
132133

133-
g = ops.argmax(axis=1)(g)
134+
classes = graph.create_initializer(
135+
g, "classes", class_type,
136+
[len(_classes)], _classes,
137+
)
138+
139+
g = ops.argmax(axis=1)(g) # fetch class index with highest score
140+
g = ops.gather_nd()(graph.merge(classes, g)) # retrieve class name from index
134141
g = ops.reshape()(graph.merge(g, init_reshape))
135142
g = ops.identity(prediction_name, suffix=False)(g)
136143
return g

tests/test_convert.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def train_titanic_regression(interactions):
5959
return model, x_test, y_test
6060

6161

62-
def train_bank_churners_multiclass_classification():
62+
def train_bank_churners_multiclass_classification(encode_label=True):
6363
df = pd.read_csv(
6464
os.path.join('examples','BankChurners.csv'),
6565
)
@@ -69,8 +69,11 @@ def train_bank_churners_multiclass_classification():
6969
label_column = "Income_Category"
7070

7171
y = df[[label_column]]
72-
le = LabelEncoder()
73-
y_enc = le.fit_transform(y)
72+
if encode_label:
73+
le = LabelEncoder()
74+
y_enc = le.fit_transform(y)
75+
else:
76+
y_enc = y
7477
x = df[feature_columns]
7578
x_train, x_test, y_train, y_test = train_test_split(x, y_enc)
7679
model = ExplainableBoostingClassifier(interactions=0, feature_types=feature_types)
@@ -197,8 +200,9 @@ def test_predict_binary_classification_with_categorical(interactions, explain):
197200
assert np.allclose(pred_ebm, pred_onnx[0])
198201

199202

200-
def test_predict_multiclass_classification():
201-
model_ebm, x_test, y_test = train_bank_churners_multiclass_classification()
203+
@pytest.mark.parametrize("encode_label", [False, True])
204+
def test_predict_multiclass_classification(encode_label):
205+
model_ebm, x_test, y_test = train_bank_churners_multiclass_classification(encode_label=encode_label)
202206
pred_ebm = model_ebm.predict(x_test)
203207

204208
model_onnx = ebm2onnx.to_onnx(
@@ -218,7 +222,7 @@ def test_predict_multiclass_classification():
218222
'Credit_Limit': x_test['Credit_Limit'].values,
219223
})
220224

221-
assert np.allclose(pred_ebm, pred_onnx[0])
225+
assert (pred_ebm == pred_onnx[0]).all()
222226

223227

224228
def test_predict_proba_multiclass_classification():

tests/test_ebm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ def test_predict_class_binary():
215215
g = graph.create_graph()
216216
i = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 1])
217217

218-
g = ebm.predict_class(binary=True, prediction_name="prediction")(i)
218+
g = ebm.predict_class(
219+
classes=[0, 1], class_type=onnx.TensorProto.INT64,
220+
binary=True, prediction_name="prediction"
221+
)(i)
219222
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64, [None])
220223

221224
assert_model_result(g,
@@ -230,7 +233,10 @@ def test_predict_multiclass_binary():
230233
g = graph.create_graph()
231234
i = graph.create_input(g, "i", onnx.TensorProto.FLOAT, [None, 3])
232235

233-
g = ebm.predict_class(binary=False, prediction_name="prediction")(i)
236+
g = ebm.predict_class(
237+
classes=[0, 1, 2], class_type=onnx.TensorProto.INT64,
238+
binary=False, prediction_name="prediction"
239+
)(i)
234240
g = graph.add_output(g, g.transients[0].name, onnx.TensorProto.INT64, [None])
235241

236242
assert_model_result(g,

0 commit comments

Comments
 (0)