Commit 895a123

fix muon document (#21079)
* fix muon argument
* fix muon argument
* change behavior
* add some test
* add some test
* fix
* fix
1 parent 3da0abd commit 895a123
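
In short: exclude_layers previously had to be a single string that was lowercased and substring-matched against variable paths; it now accepts a list of regex patterns, each searched against the raw variable path. The new behavior is covered by the added tests (see the diffs below).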

2 files changed (+89 -6 lines)
keras/src/optimizers/muon.py (6 additions, 6 deletions)
@@ -1,3 +1,5 @@
+import re
+
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -124,10 +126,7 @@ def __init__(
         self.ns_steps = ns_steps
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
-        # exclude_layers is a keyword at variable path
-        # so it must be a string
-        assert isinstance(exclude_layers, str) or exclude_layers is None
-        self.exclude_layers = exclude_layers.lower()
+        self.exclude_layers = exclude_layers or []

     def _should_use_adamw(self, variable):
         # To use it with 4D convolutional filters,
@@ -137,8 +136,9 @@ def _should_use_adamw(self, variable):
             return True
         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True
-        if self.exclude_layers in variable.path.lower():
-            return True
+        for keyword in self.exclude_layers:
+            if re.search(keyword, variable.path):
+                return True
         return False

     def build(self, var_list):
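
With this change, exclude_layers accepts a list of regex patterns, each searched against a variable's path (previously it had to be a single string, lowercased before a substring check). A minimal sketch of the new routing logic in isolation; the variable paths below are hypothetical:

import re

# Mirrors the new matching loop in _should_use_adamw: any pattern found
# anywhere in the variable path routes that variable to the AdamW update
# instead of the Muon update. Paths here are hypothetical examples.
exclude_layers = ["embedding", "dense_2"]

def excluded(path):
    return any(re.search(keyword, path) for keyword in exclude_layers)

print(excluded("model/dense_2/kernel"))  # True
print(excluded("model/dense_1/kernel"))  # False

Note that the new check searches the raw variable.path rather than path.lower(), so patterns are case-sensitive unless written otherwise (e.g. "(?i)dense").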

keras/src/optimizers/muon_test.py (83 additions, 0 deletions)
@@ -0,0 +1,83 @@
+import numpy as np
+
+from keras.src import backend
+from keras.src import ops
+from keras.src import testing
+from keras.src.layers import Dense
+from keras.src.layers import Embedding
+from keras.src.optimizers.muon import Muon
+
+
+class MuonTest(testing.TestCase):
+    def test_config(self):
+        optimizer = Muon(
+            learning_rate=0.5,
+            epsilon=1e-5,
+        )
+        self.run_class_serialization_test(optimizer)
+
+    def test_Newton_Schulz(self):
+        optimizer = Muon()
+        tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]])
+        except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])
+        output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5)
+        self.assertAllClose(output, except_output, rtol=1e-3, atol=1e-3)
+
+    def test_adamw_single_step(self):
+        optimizer = Muon()
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
+        vars = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars")
+        optimizer.build([vars])
+        optimizer._adamw_update_step(grads, vars, 0.5)
+        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)
+
+    def test_should_use_adamw(self):
+        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        optimizer = Muon(exclude_layers=["var"])
+        self.assertAllClose(
+            True,
+            optimizer._should_use_adamw(vars),
+        )
+        embeding = Embedding(2, 2)
+        embeding.build()
+        self.assertAllClose(
+            True,
+            optimizer._should_use_adamw(embeding.weights[0]),
+        )
+        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        optimizer = Muon()
+        self.assertAllClose(
+            False,
+            optimizer._should_use_adamw(vars),
+        )
+        dense = Dense(2)
+        dense.build([None, 2])
+        self.assertAllClose(
+            False,
+            optimizer._should_use_adamw(dense.weights[0]),
+        )
+
+    def test_muon_single_step(self):
+        optimizer = Muon(
+            learning_rate=0.5,
+            weight_decay=0,
+        )
+        grads = ops.array([[1.0, 6.0], [7.0, 2.0]])
+        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        optimizer.build([vars])
+        optimizer._muon_update_step(grads, vars, 0.5)
+        self.assertAllClose(
+            vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
+        )
+
+    def test_clip_norm(self):
+        optimizer = Muon(clipnorm=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
+
+    def test_clip_value(self):
+        optimizer = Muon(clipvalue=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
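
As a sanity check of the expected value in test_clip_norm: with clipnorm=1, the gradient [100.0, 100.0] has L2 norm 100 * sqrt(2) ≈ 141.42, so rescaling it to unit norm leaves sqrt(2)/2 ≈ 0.7071 in each entry. A standalone NumPy sketch of that arithmetic (not the Keras clipping code itself):

import numpy as np

# Reproduce the arithmetic behind test_clip_norm: with clipnorm=1,
# a gradient whose L2 norm exceeds 1 is rescaled down to unit norm.
g = np.array([100.0, 100.0])
norm = np.linalg.norm(g)             # 100 * sqrt(2) ~= 141.42
clipped = g * min(1.0, 1.0 / norm)   # scale factor is 1/norm since norm > 1
print(clipped)                       # [0.70710678 0.70710678] == 2**0.5 / 2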
