Skip to content

Commit f9e2ab0

Browse files
authored
[docs, tune] replace reuse actors example with a fuller demonstration (#51234)
makes the Tune docs for re-using actors an e2e example --------- Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
1 parent ec774fe commit f9e2ab0

File tree

1 file changed

+56
-15
lines changed

1 file changed

+56
-15
lines changed

doc/source/tune/api/trainable.rst

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,70 @@ It is up to the user to correctly update the hyperparameters of your trainable.
107107

108108
.. code-block:: python
109109
110-
class PytorchTrainable(tune.Trainable):
111-
"""Train a Pytorch ConvNet."""
110+
from time import sleep
111+
import ray
112+
from ray import tune
113+
from ray.tune.tuner import Tuner
112114
115+
116+
def expensive_setup():
117+
print("EXPENSIVE SETUP")
118+
sleep(1)
119+
120+
121+
class QuadraticTrainable(tune.Trainable):
113122
def setup(self, config):
114-
self.train_loader, self.test_loader = get_data_loaders()
115-
self.model = ConvNet()
116-
self.optimizer = optim.SGD(
117-
self.model.parameters(),
118-
lr=config.get("lr", 0.01),
119-
momentum=config.get("momentum", 0.9))
123+
self.config = config
124+
expensive_setup() # use reuse_actors=True to only run this once
125+
self.max_steps = 5
126+
self.step_count = 0
120127
121-
def reset_config(self, new_config):
122-
for param_group in self.optimizer.param_groups:
123-
if "lr" in new_config:
124-
param_group["lr"] = new_config["lr"]
125-
if "momentum" in new_config:
126-
param_group["momentum"] = new_config["momentum"]
128+
def step(self):
129+
# Extract hyperparameters from the config
130+
h1 = self.config["hparam1"]
131+
h2 = self.config["hparam2"]
132+
133+
# Compute a simple quadratic objective where the optimum is at hparam1=3 and hparam2=5
134+
loss = (h1 - 3) ** 2 + (h2 - 5) ** 2
135+
136+
metrics = {"loss": loss}
137+
138+
self.step_count += 1
139+
if self.step_count > self.max_steps:
140+
metrics["done"] = True
127141
128-
self.model = ConvNet()
142+
# Return the computed loss as the metric
143+
return metrics
144+
145+
def reset_config(self, new_config):
146+
# Update the configuration for a new trial while reusing the actor
129147
self.config = new_config
130148
return True
131149
132150
151+
ray.init()
152+
153+
154+
tuner_with_reuse = Tuner(
155+
QuadraticTrainable,
156+
param_space={
157+
"hparam1": tune.uniform(-10, 10),
158+
"hparam2": tune.uniform(-10, 10),
159+
},
160+
tune_config=tune.TuneConfig(
161+
num_samples=10,
162+
max_concurrent_trials=1,
163+
reuse_actors=True, # Enable actor reuse and avoid expensive setup
164+
),
165+
run_config=ray.tune.RunConfig(
166+
verbose=0,
167+
checkpoint_config=ray.tune.CheckpointConfig(checkpoint_at_end=False),
168+
),
169+
)
170+
tuner_with_reuse.fit()
171+
172+
173+
133174
Comparing Tune's Function API and Class API
134175
-------------------------------------------
135176

0 commit comments

Comments
 (0)