@@ -1,6 +1,5 @@
-"""
-GenerateAnswerNode Module
-"""
+import requests
+import json
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -10,7 +9,10 @@
 from tqdm import tqdm
 from ..utils.logging import get_logger
 from .base_node import BaseNode
-from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+from ..prompts import (
+    TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
+    TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+)

 class GenerateAnswerNode(BaseNode):
     """
@@ -39,110 +41,130 @@ def __init__(
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)

-        self.llm_model = node_config["llm_model"]
+        self.llm_model = node_config.get("llm_model")
+        if isinstance(self.llm_model, ChatOllama):
+            self.llm_model.format = "json"

-        if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format="json"
-
-        self.verbose = (
-            True if node_config is None else node_config.get("verbose", False)
-        )
-        self.force = (
-            False if node_config is None else node_config.get("force", False)
-        )
-        self.script_creator = (
-            False if node_config is None else node_config.get("script_creator", False)
-        )
-        self.is_md_scraper = (
-            False if node_config is None else node_config.get("is_md_scraper", False)
-        )
-
-        self.additional_info = node_config.get("additional_info")
+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
+        self.additional_info = node_config.get("additional_info", "")
+        self.api_key = node_config.get("config", {}).get("llm", {}).get("api_key", "")

     def execute(self, state: dict) -> dict:
-        """
-        Generates an answer by constructing a prompt from the user's input and the scraped
-        content, querying the language model, and parsing its response.
-
-        Args:
-            state (dict): The current state of the graph. The input keys will be used
-                to fetch the correct data from the state.
-
-        Returns:
-            dict: The updated state with the output key containing the generated answer.
-
-        Raises:
-            KeyError: If the input keys are not found in the state, indicating
-                that the necessary information for generating an answer is missing.
-        """
-
         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)
-        # Fetching data from the state based on the input keys
-        input_data = [state[key] for key in input_keys]
-        user_prompt = input_data[0]
-        doc = input_data[1]
-
-        # Initialize the output parser
-        if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
-        else:
-            output_parser = JsonOutputParser()
+        user_prompt, doc = [state[key] for key in input_keys]

+        schema = self.node_config.get("schema")
+        output_parser = JsonOutputParser(pydantic_object=schema) if schema else JsonOutputParser()
         format_instructions = output_parser.get_format_instructions()

-        if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
-            template_chunks_prompt = TEMPLATE_CHUNKS_MD
-            template_merge_prompt = TEMPLATE_MERGE_MD
-        else:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
-            template_chunks_prompt = TEMPLATE_CHUNKS
-            template_merge_prompt = TEMPLATE_MERGE
-
-        if self.additional_info is not None:
-            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
-            template_chunks_prompt = self.additional_info + template_chunks_prompt
-            template_merge_prompt = self.additional_info + template_merge_prompt
-
-        if len(doc) == 1:
-            prompt = PromptTemplate(
-                template=template_no_chunks_prompt ,
-                input_variables=["question"],
-                partial_variables={"context": doc,
-                                   "format_instructions": format_instructions})
-            chain = prompt | self.llm_model | output_parser
-            answer = chain.invoke({"question": user_prompt})
-
-            state.update({self.output[0]: answer})
+        if isinstance(self.llm_model, ChatOpenAI) and (not self.script_creator or self.force) or self.is_md_scraper:
+            templates = {
+                'no_chunks': TEMPLATE_NO_CHUNKS_MD,
+                'chunks': TEMPLATE_CHUNKS_MD,
+                'merge': TEMPLATE_MERGE_MD
+            }
+
+            url = "https://api.openai.com/v1/chat/completions"
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}"
+            }
+
+            if len(doc) == 1:
+                prompt = templates['no_chunks'].format(
+                    question=user_prompt,
+                    context=doc[0],
+                    format_instructions=format_instructions
+                )
+                response = requests.post(url, headers=headers, json={
+                    "model": self.llm_model.model_name,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "temperature": 0
+                }, timeout=10)
+                response_text = response.json()['choices'][0]['message']['content']
+                cleaned_response = json.loads(response_text.replace('\\n', '').replace('\\', ''))
+                state.update({self.output[0]: cleaned_response})
+                return state
+
+            chunks_responses = []
+            for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+                prompt = templates['chunks'].format(
+                    question=user_prompt,
+                    context=chunk,
+                    chunk_id=i + 1,
+                    format_instructions=format_instructions
+                )
+                response = requests.post(url, headers=headers, json={
+                    "model": self.llm_model.model_name,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "temperature": 0
+                }, timeout=10)
+                chunk_response = response.json()['choices'][0]['message']['content']
+                cleaned_chunk_response = json.loads(chunk_response.replace('\\n', '').replace('\\', ''))
+                chunks_responses.append(cleaned_chunk_response)
+
+            merge_context = " ".join([json.dumps(chunk) for chunk in chunks_responses])
+            merge_prompt = templates['merge'].format(
+                question=user_prompt,
+                context=merge_context,
+                format_instructions=format_instructions
+            )
+            response = requests.post(url, headers=headers, json={
+                "model": self.llm_model.model_name,
+                "messages": [{"role": "user", "content": merge_prompt}],
+                "temperature": 0
+            }, timeout=10)
+            response_text = response.json()['choices'][0]['message']['content']
+            cleaned_response = json.loads(response_text.replace('\\n', '').replace('\\', ''))
+            state.update({self.output[0]: cleaned_response})
             return state

-        chains_dict = {}
-        for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-
-            prompt = PromptTemplate(
-                template=TEMPLATE_CHUNKS,
-                input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
-            chain_name = f"chunk{i+1}"
-            chains_dict[chain_name] = prompt | self.llm_model | output_parser
-
-        async_runner = RunnableParallel(**chains_dict)
-
-        batch_results = async_runner.invoke({"question": user_prompt})
-
-        merge_prompt = PromptTemplate(
-            template = template_merge_prompt ,
+        else:
+            templates = {
+                'no_chunks': TEMPLATE_NO_CHUNKS,
+                'chunks': TEMPLATE_CHUNKS,
+                'merge': TEMPLATE_MERGE
+            }
+
+            if self.additional_info:
+                templates = {key: self.additional_info + template for key, template in templates.items()}
+
+            if len(doc) == 1:
+                prompt = PromptTemplate(
+                    template=templates['no_chunks'],
+                    input_variables=["question"],
+                    partial_variables={"context": doc, "format_instructions": format_instructions}
+                )
+                chain = prompt | self.llm_model | output_parser
+                answer = chain.invoke({"question": user_prompt})
+                state.update({self.output[0]: answer})
+                return state
+
+            chains_dict = {}
+            for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+                prompt = PromptTemplate(
+                    template=templates['chunks'],
+                    input_variables=["question"],
+                    partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
+                )
+                chain_name = f"chunk{i+1}"
+                chains_dict[chain_name] = prompt | self.llm_model | output_parser
+
+            async_runner = RunnableParallel(**chains_dict)
+            batch_results = async_runner.invoke({"question": user_prompt})
+
+            merge_prompt = PromptTemplate(
+                template=templates['merge'],
             input_variables=["context", "question"],
-            partial_variables={"format_instructions": format_instructions},
+                partial_variables={"format_instructions": format_instructions}
         )
+            merge_chain = merge_prompt | self.llm_model | output_parser
+            answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})

-        merge_chain = merge_prompt | self.llm_model | output_parser
-        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
-
-        state.update({self.output[0]: answer})
-        return state
+            state.update({self.output[0]: answer})
+            return state
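
For context, a minimal sketch (not part of the diff) of the node_config shape the new constructor lookups assume. Only the key names are taken from the .get() calls added in __init__ above; the values and the llm_instance variable are illustrative placeholders, not the library's documented API.

# Hypothetical node_config for GenerateAnswerNode after this change.
node_config = {
    "llm_model": llm_instance,        # assumed: a ChatOpenAI or ChatOllama object built elsewhere
    "schema": None,                   # optional pydantic model passed to JsonOutputParser
    "verbose": True,
    "force": False,
    "script_creator": False,
    "is_md_scraper": False,
    "additional_info": "",
    "config": {
        "llm": {
            "api_key": "sk-...",      # read via node_config["config"]["llm"]["api_key"] and sent
        }                             # as the Bearer token in the requests.post calls
    },
}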