Skip to content

Commit 82c9ddb

Browse files
committed
Add tests, ensure that labels are set correctly
1 parent 174abc4 commit 82c9ddb

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def test_trainer_evaluate_with_strings(model: SetFitModel):
472472
# The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."
473473
model.predict(["another positive sentence"])
474474

475+
assert set(model.labels) == {"positive", "negative"}
476+
475477

476478
def test_trainer_evaluate_multilabel_f1():
477479
dataset = Dataset.from_dict({"text_new": ["", "a", "b", "ab"], "label_new": [[0, 0], [1, 0], [0, 1], [1, 1]]})
@@ -606,6 +608,7 @@ def test_evaluate_with_strings(model: SetFitModel) -> None:
606608
trainer.train()
607609
metrics = trainer.evaluate()
608610
assert "accuracy" in metrics
611+
assert set(model.labels) == {"positive", "negative"}
609612

610613

611614
def test_trainer_wrong_args(model: SetFitModel, tmp_path: Path) -> None:

0 commit comments

Comments
 (0)