@@ -107,29 +107,70 @@ It is up to the user to correctly update the hyperparameters of your trainable.
107
107
108
108
.. code-block:: python
109
109
110
- class PytorchTrainable (tune .Trainable ):
111
- """ Train a Pytorch ConvNet."""
110
+ from time import sleep
111
+ import ray
112
+ from ray import tune
113
+ from ray.tune.tuner import Tuner
112
114
115
+
116
def expensive_setup():
    """Simulate a costly, one-time initialization (e.g. loading data or a model).

    Prints a marker and sleeps for one second so the cost is visible in logs.
    With ``reuse_actors=True`` this runs only once per reused Tune actor.
    """
    print("EXPENSIVE SETUP")
    sleep(1)
119
+
120
+
121
class QuadraticTrainable(tune.Trainable):
    """Trainable minimizing a simple quadratic; optimum at hparam1=3, hparam2=5.

    Demonstrates actor reuse: ``setup`` performs an expensive one-time
    initialization, and ``reset_config`` lets Tune reuse the same actor
    for a new trial without paying that cost again.
    """

    def setup(self, config):
        # Called once when the actor is created. With reuse_actors=True the
        # expensive setup below runs only once per reused actor.
        self.config = config
        expensive_setup()  # use reuse_actors=True to only run this once
        self.max_steps = 5
        self.step_count = 0

    def step(self):
        """Run one training iteration and return the metrics dict."""
        # Extract hyperparameters from the config
        h1 = self.config["hparam1"]
        h2 = self.config["hparam2"]

        # Simple quadratic objective whose optimum is at hparam1=3, hparam2=5
        loss = (h1 - 3) ** 2 + (h2 - 5) ** 2

        metrics = {"loss": loss}

        # Signal completion after max_steps iterations so the trial ends.
        self.step_count += 1
        if self.step_count > self.max_steps:
            metrics["done"] = True

        # Return the computed loss as the metric
        return metrics

    def reset_config(self, new_config):
        # Update the configuration for a new trial while reusing the actor;
        # returning True tells Tune the reset succeeded.
        self.config = new_config
        return True
131
149
132
150
151
ray.init()


# Run 10 sequential trials on one reused actor: with reuse_actors=True the
# expensive setup in QuadraticTrainable.setup runs only once, and subsequent
# trials reconfigure the live actor via reset_config.
tuner_with_reuse = Tuner(
    QuadraticTrainable,
    param_space={
        "hparam1": tune.uniform(-10, 10),
        "hparam2": tune.uniform(-10, 10),
    },
    tune_config=tune.TuneConfig(
        num_samples=10,
        max_concurrent_trials=1,
        reuse_actors=True,  # Enable actor reuse and avoid expensive setup
    ),
    run_config=ray.tune.RunConfig(
        verbose=0,
        checkpoint_config=ray.tune.CheckpointConfig(checkpoint_at_end=False),
    ),
)
tuner_with_reuse.fit()
171
+
172
+
173
+
133
174
Comparing Tune's Function API and Class API
134
175
-------------------------------------------
135
176
0 commit comments