|
4 | 4 |
|
5 | 5 | # Imports from standard library
|
6 | 6 | from typing import List, Optional
|
7 |
| - |
8 |
| -# Imports from Langchain |
| 7 | +from functools import reduce |
| 8 | +import operator |
9 | 9 | from langchain.prompts import PromptTemplate
|
10 | 10 | from langchain_core.output_parsers import JsonOutputParser
|
11 | 11 | from tqdm import tqdm
|
@@ -68,20 +68,51 @@ def execute(self, state: dict) -> dict:
|
68 | 68 |
|
69 | 69 | self.logger.info(f"--- Executing {self.node_name} Node ---")
|
70 | 70 |
|
71 |
| - template_answer = "" |
| 71 | + # merge the scraped answers into a single string |
| 72 | + |
| 73 | + |
| 74 | + # Initialize the output parser |
| 75 | + if self.node_config.get("schema", None) is not None: |
| 76 | + output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) |
| 77 | + else: |
| 78 | + output_parser = JsonOutputParser() |
| 79 | + |
| 80 | + format_instructions = output_parser.get_format_instructions() |
| 81 | + |
| 82 | + template_answer = """ |
| 83 | + You are a website scraper and you have just scraped some content from multiple websites.\n |
| 84 | + You are now asked to provide an answer to a USER PROMPT based on the content you have scraped.\n |
| 85 | + You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n |
| 86 | + The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n |
| 87 | + OUTPUT INSTRUCTIONS: {format_instructions}\n |
| 88 | + USER PROMPT: {user_prompt}\n |
| 89 | + WEBSITE CONTENT: {website_content} |
| 90 | + """ |
| 91 | + |
| 92 | + input_keys = self.get_input_keys(state) |
| 93 | + # Fetching data from the state based on the input keys |
| 94 | + input_data = [state[key] for key in input_keys] |
| 95 | + |
| 96 | + user_prompt = input_data[0] |
| 97 | + # answers is a list of strings |
| 98 | + answers, relevant_links = zip(*input_data[1]) |
| 99 | + |
| 100 | + answers_str = "" |
| 101 | + for i, answer in enumerate(answers): |
| 102 | + answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n" |
72 | 103 |
|
73 | 104 | answers = str(state.get("answer"))
|
74 | 105 | relevant_links = str(state.get("relevant_links"))
|
75 | 106 | answer = {}
|
76 | 107 |
|
77 | 108 | merge_prompt = PromptTemplate(
|
78 | 109 | template=template_answer,
|
79 |
| - #input_variables=["context", "question"], |
80 |
| - #partial_variables={"format_instructions": format_instructions}, |
| 110 | + input_variables=["context", "question"], |
| 111 | + partial_variables={"format_instructions": format_instructions}, |
81 | 112 | )
|
82 | 113 |
|
83 |
| - #answer = merge_prompt.invoke({"question": user_prompt}) |
| 114 | + answer = merge_prompt.invoke({"question": user_prompt}) |
84 | 115 |
|
85 |
| - state.update({"relevant_links": "TODO"}) |
86 |
| - state.update({"answer": "TODO"}) |
| 116 | + state.update({"relevant_links": reduce(operator.ior, relevant_links, {})}) |
| 117 | + state.update({"answer": answer}) |
87 | 118 | return state
|
0 commit comments