From a0d41677babfa7d288d7911d70405e2f67eea0ca Mon Sep 17 00:00:00 2001 From: Adam Dougal Date: Fri, 14 Jun 2024 17:03:19 +0100 Subject: [PATCH] feat: Include citations in response when using prompt flow - This enabled me to remove a bunch of citation parsing from the flow - This also enabled me to remove the citation output files - Added verbose output for tests to aid debugging - Fix issue with pfazure not being installed Required by https://github.com/Azure-Samples/chat-with-your-data-solution-accelerator/issues/406 --- Makefile | 2 +- .../utilities/orchestrator/prompt_flow.py | 24 +++++++++- .../orchestrator/test_prompt_flow.py | 22 +++++++-- infra/prompt-flow/create-prompt-flow.sh | 4 +- infra/prompt-flow/cwyd/answer_output.py | 6 --- infra/prompt-flow/cwyd/citation_output.py | 6 --- infra/prompt-flow/cwyd/flow.dag.template.yaml | 29 +----------- .../cwyd/generate_prompt_context.py | 18 +++---- infra/prompt-flow/cwyd/output_parser.py | 47 ------------------- 9 files changed, 53 insertions(+), 105 deletions(-) delete mode 100644 infra/prompt-flow/cwyd/answer_output.py delete mode 100644 infra/prompt-flow/cwyd/citation_output.py delete mode 100644 infra/prompt-flow/cwyd/output_parser.py diff --git a/Makefile b/Makefile index 77880d882..24c047929 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ python-test: ## 🧪 Run Python unit + functional tests unittest: ## 🧪 Run the unit tests @echo -e "\e[34m$@\e[0m" || true - @poetry run pytest -m "not azure and not functional" $(optional_args) + @poetry run pytest -vvv -m "not azure and not functional" $(optional_args) unittest-frontend: build-frontend ## 🧪 Unit test the Frontend webapp @echo -e "\e[34m$@\e[0m" || true diff --git a/code/backend/batch/utilities/orchestrator/prompt_flow.py b/code/backend/batch/utilities/orchestrator/prompt_flow.py index 15cf1a518..4f6cb85d2 100644 --- a/code/backend/batch/utilities/orchestrator/prompt_flow.py +++ b/code/backend/batch/utilities/orchestrator/prompt_flow.py @@ -5,6 +5,7 @@ from .orchestrator_base import OrchestratorBase from ..common.answer import Answer +from ..common.source_document import SourceDocument from ..helpers.llm_helper import LLMHelper from ..helpers.env_helper import EnvHelper @@ -50,7 +51,13 @@ async def orchestrate( raise RuntimeError(f"The request failed: {error}") from error # Transform response into answer for further processing - answer = Answer(question=user_message, answer=result["chat_output"]) + answer = Answer( + question=user_message, + answer=result["chat_output"], + source_documents=self.transform_citations_into_source_documents( + result["citations"] + ), + ) # Call Content Safety tool on answer if self.config.prompts.enable_content_safety: @@ -91,3 +98,18 @@ def transform_data_into_file(self, user_message, chat_history): with tempfile.NamedTemporaryFile(delete=False) as file: file.write(body) return file.name + + def transform_citations_into_source_documents(self, citations): + source_documents = [] + + for _, doc_id in enumerate(citations): + citation = citations[doc_id] + source_documents.append( + SourceDocument( + id=doc_id, + content=citation.get("content"), + source=citation.get("filepath"), + chunk_id=str(citation.get("chunk_id", 0)), + ) + ) + return source_documents diff --git a/code/tests/utilities/orchestrator/test_prompt_flow.py b/code/tests/utilities/orchestrator/test_prompt_flow.py index 902b59d0b..6b2025c72 100644 --- a/code/tests/utilities/orchestrator/test_prompt_flow.py +++ b/code/tests/utilities/orchestrator/test_prompt_flow.py @@ -95,16 +95,30 @@ async def test_orchestrate_returns_expected_chat_response( expected_result = [ { "role": "tool", - "content": '{"citations": [], "intent": "question"}', + "content": '{"citations": [{"content": "[None](some-filepath)\\n\\n\\nsome-content", "id": "[doc1]", "chunk_id": "1", "title": null, "filepath": "some-filepath", "url": "[None](some-filepath)", "metadata": {"offset": null, "source": "some-filepath", "markdown_url": "[None](some-filepath)", "title": null, "original_url": "some-filepath", "chunk": null, "key": "[doc1]", "filename": "some-filepath"}}, {"content": "[None](some-other-filepath)\\n\\n\\nsome-other-content", "id": "[doc2]", "chunk_id": "2", "title": null, "filepath": "some-other-filepath", "url": "[None](some-other-filepath)", "metadata": {"offset": null, "source": "some-other-filepath", "markdown_url": "[None](some-other-filepath)", "title": null, "original_url": "some-other-filepath", "chunk": null, "key": "[doc2]", "filename": "some-other-filepath"}}], "intent": "question"}', "end_turn": False, }, { "role": "assistant", - "content": "answer", + "content": "answer[doc1][doc2]", "end_turn": True, }, ] - chat_output = {"chat_output": "answer", "citations": ["", []]} + chat_output = { + "chat_output": "answer[doc1][doc2]", + "citations": { + "[doc1]": { + "content": "some-content", + "filepath": "some-filepath", + "chunk_id": 1, + }, + "[doc2]": { + "content": "some-other-content", + "filepath": "some-other-filepath", + "chunk_id": 2, + }, + }, + } orchestrator.transform_chat_history = MagicMock(return_value=[]) orchestrator.ml_client.online_endpoints.invoke = AsyncMock(return_value=chat_output) @@ -142,7 +156,7 @@ async def test_orchestrate_returns_content_safety_response_for_unsafe_output( ): # given user_message = "question" - chat_output = {"chat_output": "bad-response", "citations": ["", []]} + chat_output = {"chat_output": "bad-response", "citations": {}} content_safety_response = [ { "role": "tool", diff --git a/infra/prompt-flow/create-prompt-flow.sh b/infra/prompt-flow/create-prompt-flow.sh index 016e0d011..0e4d9d155 100755 --- a/infra/prompt-flow/create-prompt-flow.sh +++ b/infra/prompt-flow/create-prompt-flow.sh @@ -75,7 +75,7 @@ az account set --subscription "$subscription_id" set +e tries=1 -pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \ +poetry run pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \ --workspace-name "$aml_workspace" --flow "$flow_dir" --set type=chat while [ $? -ne 0 ]; do tries=$((tries+1)) @@ -86,7 +86,7 @@ while [ $? -ne 0 ]; do echo "Failed to create flow, will retry in 30 seconds" sleep 30 - pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \ + poetry run pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \ --workspace-name "$aml_workspace" --flow "$flow_dir" --set type=chat done set -e diff --git a/infra/prompt-flow/cwyd/answer_output.py b/infra/prompt-flow/cwyd/answer_output.py deleted file mode 100644 index 737fcaa8d..000000000 --- a/infra/prompt-flow/cwyd/answer_output.py +++ /dev/null @@ -1,6 +0,0 @@ -from promptflow import tool - - -@tool -def my_python_tool(output) -> str: - return output[0] diff --git a/infra/prompt-flow/cwyd/citation_output.py b/infra/prompt-flow/cwyd/citation_output.py deleted file mode 100644 index 2c436b645..000000000 --- a/infra/prompt-flow/cwyd/citation_output.py +++ /dev/null @@ -1,6 +0,0 @@ -from promptflow import tool - - -@tool -def my_python_tool(output) -> str: - return output[1] diff --git a/infra/prompt-flow/cwyd/flow.dag.template.yaml b/infra/prompt-flow/cwyd/flow.dag.template.yaml index 50e005c04..2fce37536 100755 --- a/infra/prompt-flow/cwyd/flow.dag.template.yaml +++ b/infra/prompt-flow/cwyd/flow.dag.template.yaml @@ -15,11 +15,11 @@ inputs: outputs: chat_output: type: string - reference: ${answer_output.output} + reference: ${chat_with_context.output} is_chat_output: true citations: type: string - reference: ${output_parser.output} + reference: ${generate_prompt_context.output} nodes: - name: lookup type: python @@ -95,31 +95,6 @@ nodes: api: chat module: promptflow.tools.aoai use_variants: false -- name: output_parser - type: python - source: - type: code - path: output_parser.py - inputs: - answer: ${chat_with_context.output} - sources: ${generate_prompt_context.output} - use_variants: false -- name: answer_output - type: python - source: - type: code - path: answer_output.py - inputs: - output: ${output_parser.output} - use_variants: false -- name: citation_output - type: python - source: - type: code - path: citation_output.py - inputs: - output: ${output_parser.output} - use_variants: false node_variants: {} environment: python_requirements_txt: requirements.txt diff --git a/infra/prompt-flow/cwyd/generate_prompt_context.py b/infra/prompt-flow/cwyd/generate_prompt_context.py index c7f47f451..1e7de4e56 100755 --- a/infra/prompt-flow/cwyd/generate_prompt_context.py +++ b/infra/prompt-flow/cwyd/generate_prompt_context.py @@ -5,7 +5,7 @@ @tool def generate_prompt_context(search_result: List[dict]) -> str: - retrieved_docs = [] + retrieved_docs = {} for index, item in enumerate(search_result): entity = SearchResultEntity.from_dict(item) @@ -14,14 +14,10 @@ def generate_prompt_context(search_result: List[dict]) -> str: filepath = additional_fields.get("source") chunk_id = additional_fields.get("chunk_id", additional_fields.get("chunk", "")) - retrieved_docs.append( - { - f"[doc{index+1}]": { - "content": content, - "filepath": filepath, - "chunk_id": chunk_id, - } - } - ) + retrieved_docs[f"[doc{index+1}]"] = { + "content": content, + "filepath": filepath, + "chunk_id": chunk_id, + } - return {"retrieved_documents": retrieved_docs} + return retrieved_docs diff --git a/infra/prompt-flow/cwyd/output_parser.py b/infra/prompt-flow/cwyd/output_parser.py deleted file mode 100644 index d516c6212..000000000 --- a/infra/prompt-flow/cwyd/output_parser.py +++ /dev/null @@ -1,47 +0,0 @@ -from promptflow import tool -import re - - -def _clean_up_answer(answer: str): - return answer.replace(" ", " ") - - -def _get_source_docs_from_answer(answer): - # extract all [docN] from answer and extract N, and just return the N's as a list of ints - results = re.findall(r"\[doc(\d+)\]", answer) - return [int(i) for i in results] - - -def _replace_last(text, old, new): - """Replaces the last occurence of a substring in a string - - This is done by reversing the string using [::-1], replacing the first occurence of the reversed substring, and - reversing the string again. - """ - return (text[::-1].replace(old[::-1], new[::-1], 1))[::-1] - - -def _make_doc_references_sequential(answer, doc_ids): - for i, idx in enumerate(doc_ids): - answer = _replace_last(answer, f"[doc{idx}]", f"[doc{i+1}]") - return answer - - -@tool -def my_python_tool(answer: str, sources: dict) -> str: - answer = _clean_up_answer(answer) - doc_ids = _get_source_docs_from_answer(answer) - answer = _make_doc_references_sequential(answer, doc_ids) - - source_documents = sources.get("retrieved_documents", []) - citations = [] - for i in doc_ids: - idx = i - 1 - - if idx >= len(source_documents): - continue - - doc = source_documents[idx] - citations.append(doc) - - return answer, citations