Skip to content

Commit e7b7dc9

Browse files
authored
Add option to start training paused (#3420)
1 parent e2a0915 commit e7b7dc9

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

nerfstudio/engine/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
8686
"""Optionally log gradients during training"""
8787
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
8888
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
89+
start_paused: bool = False
90+
"""Whether to start the training in a paused state."""
8991

9092

9193
class Trainer:
@@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
121123
self.device += f":{local_rank}"
122124
self.mixed_precision: bool = self.config.mixed_precision
123125
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
124-
self.training_state: Literal["training", "paused", "completed"] = "training"
126+
self.training_state: Literal["training", "paused", "completed"] = (
127+
"paused" if self.config.start_paused else "training"
128+
)
125129
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
126130
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)
127131

@@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None:
361365
assert self.viewer_state and self.pipeline.datamanager.train_dataset
362366
self.viewer_state.init_scene(
363367
train_dataset=self.pipeline.datamanager.train_dataset,
364-
train_state="training",
368+
train_state=self.training_state,
365369
eval_dataset=self.pipeline.datamanager.eval_dataset,
366370
)
367371

nerfstudio/viewer/viewer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def __init__(
103103
self.output_type_changed = True
104104
self.output_split_type_changed = True
105105
self.step = 0
106-
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
107-
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
106+
self.train_btn_state: Literal["training", "paused", "completed"] = (
107+
"training" if self.trainer is None else self.trainer.training_state
108+
)
109+
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
108110
self.last_move_time = 0
109111
# track the index of the camera that was last clicked
110112
self.current_camera_idx = 0
@@ -174,7 +176,11 @@ def __init__(
174176
)
175177
self.resume_train.on_click(lambda _: self.toggle_pause_button())
176178
self.resume_train.on_click(lambda han: self._toggle_training_state(han))
177-
self.resume_train.visible = False
179+
if self.train_btn_state == "training":
180+
self.resume_train.visible = False
181+
else:
182+
self.pause_train.visible = False
183+
178184
# Add buttons to toggle training image visibility
179185
self.hide_images = self.viser_server.gui.add_button(
180186
label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None

nerfstudio/viewer_legacy/server/viewer_state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ def __init__(
116116
self.output_type_changed = True
117117
self.output_split_type_changed = True
118118
self.step = 0
119-
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
120-
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
119+
self.train_btn_state: Literal["training", "paused", "completed"] = (
120+
"training" if self.trainer is None else self.trainer.training_state
121+
)
122+
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
121123

122124
self.camera_message = None
123125

0 commit comments

Comments
 (0)