-
Notifications
You must be signed in to change notification settings - Fork 281
[Ready] Add API Call & Example OPs #463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
63d430a
add api call
drcege 6720da4
add call_api ops
drcege 8daa6e1
clean
drcege ef11951
minor update
drcege 5597d5c
more tests
drcege 4b6e769
update tests
drcege 835be22
Merge branch 'main' into dev/api_model
drcege 325a753
update prompts
drcege 4f04bdd
fix unittest
drcege 0adbdcd
update tests
drcege 0aa4069
add docs
drcege f007532
minor fix
drcege 9aa7390
Merge branch 'main' into dev/api_model
drcege ee4f461
add API processor
drcege 9bbfe47
Merge branch 'main' into dev/api_model
drcege b00b182
refine API processor
drcege b718de7
refine
drcege 4d1670f
fix bugs
drcege 9e11aa3
fix tests
drcege 347bc0f
refine tests
drcege File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import re | ||
from typing import Dict, Optional | ||
|
||
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
OP_NAME = 'calibrate_qa_mapper' | ||
|
||
|
||
# TODO: LLM-based inference. | ||
@UNFORKABLE.register_module(OP_NAME) | ||
@OPERATORS.register_module(OP_NAME) | ||
class CalibrateQAMapper(Mapper): | ||
""" | ||
Mapper to calibrate question-answer pairs based on reference text. | ||
""" | ||
|
||
# avoid leading whitespace | ||
DEFAULT_SYSTEM_PROMPT = ('请根据提供的【参考信息】对【问题】和【回答】进行校准,使其更加详细、准确。\n' | ||
'按照以下格式输出:\n' | ||
'【问题】\n' | ||
'校准后的问题\n' | ||
'【回答】\n' | ||
'校准后的回答') | ||
DEFAULT_INPUT_TEMPLATE = '{reference}\n{qa_pair}' | ||
DEFAULT_REFERENCE_TEMPLATE = '【参考信息】\n{}' | ||
DEFAULT_QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}' | ||
DEFAULT_OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*)' | ||
|
||
def __init__(self, | ||
api_model: str = 'gpt-4o', | ||
*, | ||
api_url: Optional[str] = None, | ||
api_key: Optional[str] = None, | ||
response_path: Optional[str] = None, | ||
system_prompt: Optional[str] = None, | ||
input_template: Optional[str] = None, | ||
reference_template: Optional[str] = None, | ||
qa_pair_template: Optional[str] = None, | ||
output_pattern: Optional[str] = None, | ||
api_params: Optional[Dict] = None, | ||
**kwargs): | ||
""" | ||
Initialization method. | ||
:param api_model: API model name. | ||
:param api_url: API URL. Defaults to DJ_API_URL environment variable. | ||
:param api_key: API key. Defaults to DJ_API_KEY environment variable. | ||
:param response_path: Path to extract content from the API response. | ||
Defaults to 'choices.0.message.content'. | ||
:param system_prompt: System prompt for the calibration task. | ||
:param input_template: Template for building the model input. | ||
:param reference_template: Template for formatting the reference text. | ||
:param qa_pair_template: Template for formatting question-answer pairs. | ||
:param output_pattern: Regular expression for parsing model output. | ||
:param api_params: Extra parameters passed to the API call. | ||
:param kwargs: Extra keyword arguments. | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT | ||
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE | ||
self.reference_template = reference_template or \ | ||
self.DEFAULT_REFERENCE_TEMPLATE | ||
self.qa_pair_template = qa_pair_template or \ | ||
self.DEFAULT_QA_PAIR_TEMPLATE | ||
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN | ||
|
||
self.api_params = api_params or {} | ||
self.model_key = prepare_model(model_type='api', | ||
api_model=api_model, | ||
api_url=api_url, | ||
api_key=api_key, | ||
response_path=response_path) | ||
|
||
def build_input(self, sample): | ||
reference = self.reference_template.format(sample[self.text_key]) | ||
qa_pair = self.qa_pair_template.format(sample[self.query_key], | ||
sample[self.response_key]) | ||
input_prompt = self.input_template.format(reference=reference, | ||
qa_pair=qa_pair) | ||
return input_prompt | ||
|
||
def parse_output(self, raw_output): | ||
match = re.match(self.output_pattern, raw_output) | ||
if match: | ||
return match.group(1).strip(), match.group(2).strip() | ||
else: | ||
return None, None | ||
|
||
def process_single(self, sample=None, rank=None): | ||
client = get_model(self.model_key, rank=rank) | ||
|
||
messages = [{ | ||
'role': 'system', | ||
'content': self.system_prompt | ||
}, { | ||
'role': 'user', | ||
'content': self.build_input(sample) | ||
}] | ||
output = client(messages, **self.api_params) | ||
drcege marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
parsed_q, parsed_a = self.parse_output(output) | ||
drcege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if parsed_q: | ||
sample[self.query_key] = parsed_q | ||
if parsed_a: | ||
sample[self.response_key] = parsed_a | ||
|
||
return sample |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE | ||
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper | ||
|
||
OP_NAME = 'calibrate_query_mapper' | ||
|
||
|
||
# TODO: LLM-based inference. | ||
@UNFORKABLE.register_module(OP_NAME) | ||
@OPERATORS.register_module(OP_NAME) | ||
class CalibrateQueryMapper(CalibrateQAMapper): | ||
""" | ||
Mapper to calibrate query in question-answer pairs based on reference text. | ||
""" | ||
|
||
DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\ | ||
使其更加详细、准确,且仍可以由原答案回答。只输出校准后的问题,不要输出多余内容。' | ||
|
||
def parse_output(self, raw_output): | ||
return raw_output.strip(), None |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE | ||
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper | ||
|
||
OP_NAME = 'calibrate_response_mapper' | ||
|
||
|
||
# TODO: LLM-based inference. | ||
@UNFORKABLE.register_module(OP_NAME) | ||
@OPERATORS.register_module(OP_NAME) | ||
class CalibrateResponseMapper(CalibrateQAMapper): | ||
""" | ||
Mapper to calibrate response in question-answer pairs based on reference text. | ||
""" # noqa: E501 | ||
|
||
DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\ | ||
使其更加详细、准确,且仍可以回答原问题。只输出校准后的回答,不要输出多余内容。' | ||
|
||
def parse_output(self, raw_output): | ||
return None, raw_output.strip() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.