Skip to content

Commit 80f8c37

Browse files
authored
fix val step
1 parent 82f417c commit 80f8c37

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

examples/llm/txt2kg_rag.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,14 @@ def train(args, train_loader, val_loader):
589589
model.eval()
590590
with torch.no_grad():
591591
for step, batch in enumerate(val_loader):
592+
new_qs = []
593+
for i, q in enumerate(batch["question"]):
594+
# insert VectorRAG context
595+
new_qs.append(
596+
prompt_template.format(
597+
question=q,
598+
context="\n".join(batch.text_context[i])))
599+
batch.question = new_qs
592600
if args.skip_graph_rag:
593601
batch.desc = ""
594602
loss = get_loss(model, batch)

0 commit comments

Comments
 (0)