GenerateAnswerNodeKLevel Module
"""
from typing import List, Optional
+from langchain.prompts import PromptTemplate
+from tqdm import tqdm
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.runnables import RunnableParallel
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_mistralai import ChatMistralAI
+from langchain_aws import ChatBedrock
+from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from .base_node import BaseNode
+from ..prompts import (
+    TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
+    TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+)

class GenerateAnswerNodeKLevel(BaseNode):
    """
@@ -33,18 +45,92 @@ def __init__(

        self.llm_model = node_config["llm_model"]
        self.embedder_model = node_config.get("embedder_model", None)
-        self.verbose = (
-            False if node_config is None else node_config.get("verbose", False)
-        )
+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
+        self.additional_info = node_config.get("additional_info")

    def execute(self, state: dict) -> dict:
+        input_keys = self.get_input_keys(state)
+        input_data = [state[key] for key in input_keys]
+        user_prompt = input_data[0]
+
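+        # Select an output parser and format instructions based on the configured
+        # schema and the LLM provider: OpenAI/Mistral models use native structured
+        # output, Bedrock models skip the parser, and the rest fall back to a
+        # Pydantic or plain JSON parser.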
+        if self.node_config.get("schema", None) is not None:
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema=self.node_config["schema"]
+                )
+                output_parser = get_structured_output_parser(self.node_config["schema"])
+                format_instructions = "NA"
+            else:
+                if not isinstance(self.llm_model, ChatBedrock):
+                    output_parser = get_pydantic_output_parser(self.node_config["schema"])
+                    format_instructions = output_parser.get_format_instructions()
+                else:
+                    output_parser = None
+                    format_instructions = ""
+        else:
+            if not isinstance(self.llm_model, ChatBedrock):
+                output_parser = JsonOutputParser()
+                format_instructions = output_parser.get_format_instructions()
+            else:
+                output_parser = None
+                format_instructions = ""
+
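+        # Markdown-specific templates are used for OpenAI/Azure models outside of
+        # script-creator mode, when forced, or when scraping markdown directly.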
+        if (
+            (isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator)
+            or (self.force and not self.script_creator)
+            or self.is_md_scraper
+        ):
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
+            template_chunks_prompt = TEMPLATE_CHUNKS_MD
+            template_merge_prompt = TEMPLATE_MERGE_MD
+        else:
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
+            template_chunks_prompt = TEMPLATE_CHUNKS
+            template_merge_prompt = TEMPLATE_MERGE
+
+        if self.additional_info is not None:
+            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
+            template_chunks_prompt = self.additional_info + template_chunks_prompt
+            template_merge_prompt = self.additional_info + template_merge_prompt
+
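+        # Retrieve the vector store client prepared by a previous node and query it
+        # with the user's question to fetch the most relevant chunks.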
        client = state["vectorial_db"]

-        answer = client.query(
-            collection_name="demo_collection",
-            query_text="This is a query document"
+        answer_db = client.query(
+            collection_name="vectorial_collection",
+            query_text=state["question"]
        )

+        # assumes client.query() returns an iterable of retrieved chunks
+        results_db = [elem for elem in answer_db]
+
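+        # Build one prompt | llm chain per retrieved chunk and run them in parallel.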
+        chains_dict = {}
+        for i, chunk in enumerate(tqdm(results_db,
+                                       desc="Processing chunks", disable=not self.verbose)):
+            prompt = PromptTemplate(
+                template=template_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": chunk,
+                                   "chunk_id": i + 1,
+                                   # the chunk templates also reference {format_instructions}
+                                   "format_instructions": format_instructions,
+                                   }
+            )
+            chain_name = f"chunk{i+1}"
+            chains_dict[chain_name] = prompt | self.llm_model
+
+        async_runner = RunnableParallel(**chains_dict)
+        batch_results = async_runner.invoke({"question": user_prompt})
+
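+        # Merge the per-chunk partial answers into a single final answer, applying
+        # the output parser when one was selected above.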
+        merge_prompt = PromptTemplate(
+            template=template_merge_prompt,
+            input_variables=["context", "question"],
+            partial_variables={"format_instructions": format_instructions}
+        )
+
+        merge_chain = merge_prompt | self.llm_model
+        if output_parser:
+            merge_chain = merge_chain | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
        state["answer"] = answer

        return state