Skip to content

feat: Include citations in response when using prompt flow #1089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ python-test: ## 🧪 Run Python unit + functional tests

unittest: ## 🧪 Run the unit tests
@echo -e "\e[34m$@\e[0m" || true
@poetry run pytest -m "not azure and not functional" $(optional_args)
@poetry run pytest -vvv -m "not azure and not functional" $(optional_args)

unittest-frontend: build-frontend ## 🧪 Unit test the Frontend webapp
@echo -e "\e[34m$@\e[0m" || true
Expand Down
24 changes: 23 additions & 1 deletion code/backend/batch/utilities/orchestrator/prompt_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .orchestrator_base import OrchestratorBase
from ..common.answer import Answer
from ..common.source_document import SourceDocument
from ..helpers.llm_helper import LLMHelper
from ..helpers.env_helper import EnvHelper

Expand Down Expand Up @@ -50,7 +51,13 @@ async def orchestrate(
raise RuntimeError(f"The request failed: {error}") from error

# Transform response into answer for further processing
answer = Answer(question=user_message, answer=result["chat_output"])
answer = Answer(
question=user_message,
answer=result["chat_output"],
source_documents=self.transform_citations_into_source_documents(
result["citations"]
),
)

# Call Content Safety tool on answer
if self.config.prompts.enable_content_safety:
Expand Down Expand Up @@ -91,3 +98,18 @@ def transform_data_into_file(self, user_message, chat_history):
with tempfile.NamedTemporaryFile(delete=False) as file:
file.write(body)
return file.name

def transform_citations_into_source_documents(self, citations):
    """Convert the prompt flow citations mapping into ``SourceDocument`` objects.

    Args:
        citations: Mapping of citation id (e.g. ``"[doc1]"``) to a dict
            carrying ``"content"``, ``"filepath"`` and ``"chunk_id"`` keys,
            as returned by the prompt flow endpoint under
            ``result["citations"]``.

    Returns:
        list[SourceDocument]: One document per citation, in mapping order.
    """
    source_documents = []

    # Iterate the mapping directly: the original enumerate() produced an
    # unused index and then re-looked the value up by key.
    for doc_id, citation in citations.items():
        source_documents.append(
            SourceDocument(
                id=doc_id,
                content=citation.get("content"),
                source=citation.get("filepath"),
                # chunk_id may arrive as an int from the flow output;
                # normalise to str, defaulting to 0 when absent.
                chunk_id=str(citation.get("chunk_id", 0)),
            )
        )
    return source_documents
22 changes: 18 additions & 4 deletions code/tests/utilities/orchestrator/test_prompt_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,30 @@ async def test_orchestrate_returns_expected_chat_response(
expected_result = [
{
"role": "tool",
"content": '{"citations": [], "intent": "question"}',
"content": '{"citations": [{"content": "[None](some-filepath)\\n\\n\\nsome-content", "id": "[doc1]", "chunk_id": "1", "title": null, "filepath": "some-filepath", "url": "[None](some-filepath)", "metadata": {"offset": null, "source": "some-filepath", "markdown_url": "[None](some-filepath)", "title": null, "original_url": "some-filepath", "chunk": null, "key": "[doc1]", "filename": "some-filepath"}}, {"content": "[None](some-other-filepath)\\n\\n\\nsome-other-content", "id": "[doc2]", "chunk_id": "2", "title": null, "filepath": "some-other-filepath", "url": "[None](some-other-filepath)", "metadata": {"offset": null, "source": "some-other-filepath", "markdown_url": "[None](some-other-filepath)", "title": null, "original_url": "some-other-filepath", "chunk": null, "key": "[doc2]", "filename": "some-other-filepath"}}], "intent": "question"}',
"end_turn": False,
},
{
"role": "assistant",
"content": "answer",
"content": "answer[doc1][doc2]",
"end_turn": True,
},
]
chat_output = {"chat_output": "answer", "citations": ["", []]}
chat_output = {
"chat_output": "answer[doc1][doc2]",
"citations": {
"[doc1]": {
"content": "some-content",
"filepath": "some-filepath",
"chunk_id": 1,
},
"[doc2]": {
"content": "some-other-content",
"filepath": "some-other-filepath",
"chunk_id": 2,
},
},
}

orchestrator.transform_chat_history = MagicMock(return_value=[])
orchestrator.ml_client.online_endpoints.invoke = AsyncMock(return_value=chat_output)
Expand Down Expand Up @@ -142,7 +156,7 @@ async def test_orchestrate_returns_content_safety_response_for_unsafe_output(
):
# given
user_message = "question"
chat_output = {"chat_output": "bad-response", "citations": ["", []]}
chat_output = {"chat_output": "bad-response", "citations": {}}
content_safety_response = [
{
"role": "tool",
Expand Down
4 changes: 2 additions & 2 deletions infra/prompt-flow/create-prompt-flow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ az account set --subscription "$subscription_id"

set +e
tries=1
pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \
poetry run pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \
--workspace-name "$aml_workspace" --flow "$flow_dir" --set type=chat
while [ $? -ne 0 ]; do
tries=$((tries+1))
Expand All @@ -86,7 +86,7 @@ while [ $? -ne 0 ]; do

echo "Failed to create flow, will retry in 30 seconds"
sleep 30
pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \
poetry run pfazure flow create --subscription "$subscription_id" --resource-group "$resource_group" \
--workspace-name "$aml_workspace" --flow "$flow_dir" --set type=chat
done
set -e
Expand Down
6 changes: 0 additions & 6 deletions infra/prompt-flow/cwyd/answer_output.py

This file was deleted.

6 changes: 0 additions & 6 deletions infra/prompt-flow/cwyd/citation_output.py

This file was deleted.

29 changes: 2 additions & 27 deletions infra/prompt-flow/cwyd/flow.dag.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ inputs:
outputs:
chat_output:
type: string
reference: ${answer_output.output}
reference: ${chat_with_context.output}
is_chat_output: true
citations:
type: string
reference: ${output_parser.output}
reference: ${generate_prompt_context.output}
nodes:
- name: lookup
type: python
Expand Down Expand Up @@ -95,31 +95,6 @@ nodes:
api: chat
module: promptflow.tools.aoai
use_variants: false
- name: output_parser
type: python
source:
type: code
path: output_parser.py
inputs:
answer: ${chat_with_context.output}
sources: ${generate_prompt_context.output}
use_variants: false
- name: answer_output
type: python
source:
type: code
path: answer_output.py
inputs:
output: ${output_parser.output}
use_variants: false
- name: citation_output
type: python
source:
type: code
path: citation_output.py
inputs:
output: ${output_parser.output}
use_variants: false
node_variants: {}
environment:
python_requirements_txt: requirements.txt
18 changes: 7 additions & 11 deletions infra/prompt-flow/cwyd/generate_prompt_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@tool
def generate_prompt_context(search_result: List[dict]) -> str:
retrieved_docs = []
retrieved_docs = {}
for index, item in enumerate(search_result):

entity = SearchResultEntity.from_dict(item)
Expand All @@ -14,14 +14,10 @@ def generate_prompt_context(search_result: List[dict]) -> str:
filepath = additional_fields.get("source")
chunk_id = additional_fields.get("chunk_id", additional_fields.get("chunk", ""))

retrieved_docs.append(
{
f"[doc{index+1}]": {
"content": content,
"filepath": filepath,
"chunk_id": chunk_id,
}
}
)
retrieved_docs[f"[doc{index+1}]"] = {
"content": content,
"filepath": filepath,
"chunk_id": chunk_id,
}

return {"retrieved_documents": retrieved_docs}
return retrieved_docs
47 changes: 0 additions & 47 deletions infra/prompt-flow/cwyd/output_parser.py

This file was deleted.

Loading