diff --git a/class_balanced_loss.py b/class_balanced_loss.py index 179274d..4432520 100644 --- a/class_balanced_loss.py +++ b/class_balanced_loss.py @@ -84,7 +84,7 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam if loss_type == "focal": cb_loss = focal_loss(labels_one_hot, logits, weights, gamma) elif loss_type == "sigmoid": - cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights) + cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weight = weights) elif loss_type == "softmax": pred = logits.softmax(dim = 1) cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)