Skip to content

Commit d699cc1

Browse files
authored
Fix unreachable warning (#1939)
* Fix unreachable warning * Update error message
1 parent c423074 commit d699cc1

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

python/mlx/nn/losses.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
# Copyright © 2023 Apple Inc.
22

33
import math
4-
from typing import Literal, Optional
4+
from typing import Literal, Optional, get_args
55

66
import mlx.core as mx
77

88
Reduction = Literal["none", "mean", "sum"]
99

1010

1111
def _reduce(loss: mx.array, reduction: Reduction = "none"):
12+
if reduction not in get_args(Reduction):
13+
raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
14+
1215
if reduction == "mean":
1316
return mx.mean(loss)
1417
elif reduction == "sum":
1518
return mx.sum(loss)
1619
elif reduction == "none":
1720
return loss
18-
else:
19-
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
2021

2122

2223
def cross_entropy(

0 commit comments

Comments
 (0)