Skip to content

Commit 9fd74f7

Browse files
Update dynamic_weight_manager.py
1 parent 05c670e commit 9fd74f7

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
5151

5252
logger.info(
5353
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
54-
f" rank={self.rank}, ranks={self.nranks}, "
55-
f" load ipc weight from {self.ipc_path}.")
54+
f" rank={self.rank}, ranks={self.nranks}")
5655

5756
@paddle.no_grad()
5857
def _capture_model_state(self):
@@ -114,21 +113,25 @@ def _update_ipc_snapshot(self):
114113
logger.warning(
115114
"load model from no_reshard weight, maybe need more GPU memory"
116115
)
117-
logger.info("IPC snapshot update parameters completed")
116+
logger.info(
117+
f"IPC snapshot update parameters completed from {model_path}")
118118

119119
def _update_ipc(self):
120120
"""Update using standard IPC strategy (requires Training Worker)."""
121121
ipc_meta = paddle.load(self.ipc_path)
122122
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
123123
self._update_model_from_state(state_dict, "raw")
124-
logger.info("IPC update parameters completed")
124+
logger.info(
125+
f"IPC update parameters completed from file: {self.ipc_path}")
125126

126127
def _update_ipc_no_reshard(self):
127128
"""Update using no-reshard IPC strategy (faster but uses more memory)."""
128129
ipc_meta = paddle.load(self.ipc_path)
129130
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
130131
self.models[0].set_state_dict(state_dict)
131-
logger.info("IPC no-reshard update parameters completed")
132+
logger.info(
133+
f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
134+
)
132135

133136
def load_model(self) -> nn.Layer:
134137
"""Standard model loading without IPC."""
@@ -159,6 +162,8 @@ def clear_parameters(self, pid: int = 0) -> None:
159162
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
160163
src_type: str):
161164
"""Update model parameters from given state dictionary."""
165+
if len(state_dict) == 0:
166+
raise ValueError(f"No parameter found in state dict {state_dict}")
162167
update_count = 0
163168
for name, new_param in state_dict.items():
164169
if name not in self.state_dict:

0 commit comments

Comments
 (0)