
Commit b127f5d

Author: jacky
Commit message: update
Parent: 07490d8

File tree: 4 files changed (+8, -6 lines)


xtuner/v1/datasets/sft_tokenize_fn/openai.py (2 additions, 2 deletions)

@@ -37,8 +37,8 @@ def __init__(
         self.max_length = max_length

     def __call__(self, item: dict | list, **kwargs) -> DataItem | CacheItem:
-        if isinstance(item, dict) and "messages" in item:
-            item = item["messages"]
+        if isinstance(item, dict) and ("messages" in item or "dialogs" in item):
+            item = item.get("messages", item.get("dialogs"))
         messages = ChatMessages(messages=item)
         tokenized = messages.tokenize(self.tokenizer, self.chat_template)
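
This change lets the OpenAI-format tokenize fn accept samples that store their conversation under a "dialogs" key as well as the original "messages" key. A minimal standalone sketch of the new lookup (the sample dicts are hypothetical; in the real code the result is then fed to ChatMessages):

```python
def extract_turns(item):
    """Mirrors the updated __call__ lookup: accept a raw list of turns,
    or a dict keyed by either "messages" or "dialogs"."""
    if isinstance(item, dict) and ("messages" in item or "dialogs" in item):
        # .get() falls back to the "dialogs" value only when "messages" is absent.
        item = item.get("messages", item.get("dialogs"))
    return item

# All three shapes now resolve to the same list of chat turns.
turns = [{"role": "user", "content": "hi"}]
assert extract_turns(turns) == turns
assert extract_turns({"messages": turns}) == turns
assert extract_turns({"dialogs": turns}) == turns
```

One subtlety of the .get chain: when a sample carries both keys, "messages" wins, and a present-but-None "messages" value is returned as None rather than falling through to "dialogs".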

xtuner/v1/ops/attn_imp.py (1 addition, 1 deletion)

@@ -213,7 +213,7 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torc
         if flash_attn_exception is not None:
             traceback.print_exception(flash_attn_exception)
             raise flash_attn_exception
-        attention_output = flash_attn_varlen_func(q, k, v, **kwargs)
+        attention_output = flash_attn_varlen_func(q, k, v, window_size=window_size, **kwargs)
     else:
         if flash_sink_attn_exception is not None:
             traceback.print_exception(flash_sink_attn_exception)
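
Before this fix, the varlen branch accepted window_size but never passed it on, so sliding-window requests silently ran as full attention. A sketch of the bug and the fix, with a stub standing in for flash_attn_varlen_func (the real function needs a CUDA build of flash-attn; in that API, (-1, -1) means no windowing):

```python
def flash_attn_varlen_func_stub(q, k, v, window_size=(-1, -1), **kwargs):
    # Stand-in for flash_attn.flash_attn_varlen_func: just report the
    # window it would have used.
    return window_size

def flash_attention_old(q, k, v, window_size=(-1, -1), **kwargs):
    # Bug: window_size is accepted but dropped, so the callee sees its default.
    return flash_attn_varlen_func_stub(q, k, v, **kwargs)

def flash_attention_new(q, k, v, window_size=(-1, -1), **kwargs):
    # Fix: forward window_size explicitly.
    return flash_attn_varlen_func_stub(q, k, v, window_size=window_size, **kwargs)

assert flash_attention_old(None, None, None, window_size=(4096, 0)) == (-1, -1)   # ignored
assert flash_attention_new(None, None, None, window_size=(4096, 0)) == (4096, 0)  # honored
```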

xtuner/v1/rl/base/worker.py (1 addition, 1 deletion)

@@ -187,7 +187,7 @@ def _build_ref_model(
         if ref_model_fsdp_cfg is None:
             ref_model_fsdp_cfg = FSDPConfig(recompute_ratio=0, cpu_offload=False, requires_grad=False)
         model = model.fully_shard(ref_model_fsdp_cfg, float8_handler)
-        model.from_hf(hf_path=load_from)
+        model.from_hf(hf_path=load_from, strict=False)
         model.eval()
         if float8_handler is not None:
             # As the ref model is not updated, we only compute params' scales once
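
Loading the reference model with strict=False tolerates checkpoints whose keys do not exactly match the sharded module: mismatches are reported instead of raising. Assuming model.from_hf follows the same convention as torch.nn.Module.load_state_dict, a toy illustration of that semantics (the extra key is made up):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
state = model.state_dict()
state["extra.bias"] = torch.zeros(4)  # hypothetical key the model does not have

# strict=True (the default) raises on the unexpected key;
# strict=False loads what it can and reports the mismatches.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['extra.bias']
```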

xtuner/v1/train/trainer.py (4 additions, 2 deletions)

@@ -986,8 +986,10 @@ def _log_step(
         grad_norm: float,
     ):
         """Log the training step information."""
-        tgs = step_consumed_tokens / step_time
-        e2e_tgs = total_consumed_tokens / train_time
+        if step_consumed_tokens == 0:
+            logger.warning("step_consumed_tokens is 0 due to padding")
+        tgs = max(1, step_consumed_tokens / step_time)
+        e2e_tgs = max(1, total_consumed_tokens / train_time)
         lr = self._lr_scheduler.get_last_lr()[0]

         remaining_steps = self.total_step - self.cur_step
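
The guard keeps the throughput metrics (tgs, tokens per second) sensible on steps that consumed no tokens, e.g. an all-padding batch: it warns and floors the rate at 1 instead of logging 0. A standalone sketch of the same arithmetic (logger name and sample numbers are illustrative):

```python
import logging

logger = logging.getLogger("trainer")

def throughput(step_consumed_tokens: int, step_time: float) -> float:
    # Mirror the patched _log_step: warn on empty steps and clamp
    # the rate so it never drops to 0.
    if step_consumed_tokens == 0:
        logger.warning("step_consumed_tokens is 0 due to padding")
    return max(1, step_consumed_tokens / step_time)

print(throughput(8192, 2.0))  # 4096.0
print(throughput(0, 2.0))     # 1, plus a warning
```

Note that max(1, ...) also masks genuinely tiny but nonzero rates below 1 token/s; the patch accepts that in exchange for never reporting zero throughput.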
