We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c423074 commit d699cc1Copy full SHA for d699cc1
python/mlx/nn/losses.py
@@ -1,22 +1,23 @@
1
# Copyright © 2023 Apple Inc.
2
3
import math
4
-from typing import Literal, Optional
+from typing import Literal, Optional, get_args
5
6
import mlx.core as mx
7
8
Reduction = Literal["none", "mean", "sum"]
9
10
11
def _reduce(loss: mx.array, reduction: Reduction = "none"):
12
+ if reduction not in get_args(Reduction):
13
+ raise ValueError(f"Invalid reduction. Must be one of {get_args(Reduction)}.")
14
+
15
if reduction == "mean":
16
return mx.mean(loss)
17
elif reduction == "sum":
18
return mx.sum(loss)
19
elif reduction == "none":
20
return loss
- else:
- raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
21
22
23
def cross_entropy(
0 commit comments