8 changes: 6 additions & 2 deletions data_juicer/ops/mapper/__init__.py
@@ -7,7 +7,8 @@
                image_captioning_from_gpt4v_mapper, image_captioning_mapper,
                image_diffusion_mapper, image_face_blur_mapper,
                image_tagging_mapper, nlpaug_en_mapper, nlpcda_zh_mapper,
-               optimize_query_mapper, punctuation_normalization_mapper,
+               optimize_qa_mapper, optimize_query_mapper,
+               optimize_response_mapper, punctuation_normalization_mapper,
                remove_bibliography_mapper, remove_comments_mapper,
                remove_header_mapper, remove_long_words_mapper,
                remove_non_chinese_character_mapper,
@@ -45,7 +46,9 @@
 from .image_tagging_mapper import ImageTaggingMapper
 from .nlpaug_en_mapper import NlpaugEnMapper
 from .nlpcda_zh_mapper import NlpcdaZhMapper
+from .optimize_qa_mapper import OptimizeQAMapper
 from .optimize_query_mapper import OptimizeQueryMapper
+from .optimize_response_mapper import OptimizeResponseMapper
 from .punctuation_normalization_mapper import PunctuationNormalizationMapper
 from .remove_bibliography_mapper import RemoveBibliographyMapper
 from .remove_comments_mapper import RemoveCommentsMapper
@@ -87,7 +90,8 @@
     'GenerateQAFromTextMapper', 'ImageBlurMapper',
     'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper',
     'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper',
-    'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQueryMapper',
+    'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper',
+    'OptimizeQueryMapper', 'OptimizeResponseMapper',
     'PunctuationNormalizationMapper', 'RemoveBibliographyMapper',
     'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
     'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
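A quick sanity check of the wiring above: after this change, all three ops should be importable straight from data_juicer.ops.mapper. The names come from the updated __all__; the snippet itself is illustrative and not part of the diff.

    from data_juicer.ops.mapper import (OptimizeQAMapper, OptimizeQueryMapper,
                                        OptimizeResponseMapper)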
90 changes: 22 additions & 68 deletions data_juicer/ops/mapper/optimize_query_mapper.py
@@ -1,32 +1,29 @@
 from typing import Dict, Optional

-from loguru import logger
-
-from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
+from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
+from data_juicer.ops.mapper import OptimizeQAMapper
-from data_juicer.utils.lazy_loader import LazyLoader
-from data_juicer.utils.model_utils import get_model, prepare_model
-
-torch = LazyLoader('torch', 'torch')
-vllm = LazyLoader('vllm', 'vllm')
-
-DEFAULT_SYSTEM_PROMPT = '请优化这个指令,将其修改为一个更详细具体的指令。'

 OP_NAME = 'optimize_query_mapper'


 # TODO: Extend LLM-based OPs into API-based implementation.
 @UNFORKABLE.register_module(OP_NAME)
 @OPERATORS.register_module(OP_NAME)
-class OptimizeQueryMapper(Mapper):
-    """Mapper to optimize instruction query.
-    Recommended model list: [
-        alibaba-pai/Qwen2-1.5B-Instruct-Refine
-        alibaba-pai/Qwen2-7B-Instruct-Refine
-    ]
+class OptimizeQueryMapper(OptimizeQAMapper):
+    """
+    Mapper to optimize only query in question-answer pairs.
     """

+    DEFAULT_SYSTEM_PROMPT = '优化问答对中的问题,将其更加详细具体,但仍可以由原答案回答。只输出优化后的问题,不要输出多余内容。'
+
     _accelerator = 'cuda'

     def __init__(self,
+                 *,
                  hf_model: str = 'alibaba-pai/Qwen2-7B-Instruct-Refine',
                  trust_remote_code: bool = False,
                  system_prompt: Optional[str] = None,
@@ -35,7 +32,6 @@ def __init__(self,
                  max_model_len: Optional[int] = None,
                  max_num_seqs: int = 256,
                  sampling_params: Dict = {},
-                 *args,
                  **kwargs):
         """
         Initialization method.
@@ -56,60 +52,18 @@ def __init__(self,
         :param args: extra args
         :param kwargs: extra args
         """
-        super().__init__(*args, **kwargs)
-        self.num_proc = 1
-
-        if system_prompt is None:
-            system_prompt = DEFAULT_SYSTEM_PROMPT
-        self.system_prompt = system_prompt
-        self.enable_vllm = enable_vllm
-
-        if enable_vllm:
-            assert torch.cuda.device_count() >= 1, 'must be executed in CUDA'
-            if not tensor_parallel_size:
-                tensor_parallel_size = torch.cuda.device_count()
-                logger.info(f'Set tensor_parallel_size to \
-                    {tensor_parallel_size} for vllm.')
-            self.model_key = prepare_model(
-                model_type='vllm',
-                pretrained_model_name_or_path=hf_model,
-                trust_remote_code=trust_remote_code,
-                tensor_parallel_size=tensor_parallel_size,
-                max_model_len=max_model_len,
-                max_num_seqs=max_num_seqs)
-            self.sampling_params = vllm.SamplingParams(**sampling_params)
-        else:
-            self.model_key = prepare_model(
-                model_type='huggingface',
-                pretrained_model_name_or_path=hf_model,
-                trust_remote_code=trust_remote_code)
-            self.sampling_params = sampling_params
-
-    def process_single(self, sample=None, rank=None):
-        model, processor = get_model(self.model_key, rank=rank)
-
-        messages = [{
-            'role': 'system',
-            'content': self.system_prompt
-        }, {
-            'role': 'user',
-            'content': sample[self.query_key]
-        }]
-        input_prompt = processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True)
-
-        if self.enable_vllm:
-            response = model.generate([input_prompt], self.sampling_params)
-            output = response[0].outputs[0].text
-        else:
-            inputs = processor(input_prompt,
-                               return_tensors='pt').to(model.device)
-            response = model.generate(**inputs,
-                                      eos_token_id=processor.eos_token_id,
-                                      **self.sampling_params)
-            output = processor.decode(response.cpu()[0],
-                                      skip_special_tokens=True)
-
-        sample[self.query_key] = output
-
-        return sample
+        super().__init__(hf_model=hf_model,
+                         trust_remote_code=trust_remote_code,
+                         system_prompt=system_prompt,
+                         enable_vllm=enable_vllm,
+                         tensor_parallel_size=tensor_parallel_size,
+                         max_model_len=max_model_len,
+                         max_num_seqs=max_num_seqs,
+                         sampling_params=sampling_params,
+                         **kwargs)
+
+    def build_input(self, sample):
+        return sample[self.query_key]
+
+    def parse_output(self, raw_output):
+        return raw_output.strip(), None
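Net effect of this file: OptimizeQueryMapper stops owning model preparation and inference. The HF/vLLM setup and the chat-template generation loop that used to live in __init__ and process_single move into the shared parent OptimizeQAMapper, and the subclass now only overrides two hooks: build_input (feed the model just the query) and parse_output (return the optimized query plus None, which by implication leaves the response field untouched). For reference, the removed module-level prompt translates roughly as "Please optimize this instruction into a more detailed and specific one", while the new class-level DEFAULT_SYSTEM_PROMPT reads "Optimize the question in the question-answer pair so it is more detailed and specific, yet still answerable by the original answer. Output only the optimized question, nothing else."

The sketch below illustrates the template-method contract this diff implies. It is not the real OptimizeQAMapper source; the _generate stub and the "None means leave unchanged" convention are assumptions read off the subclass code.

from typing import Optional, Tuple


class OptimizeQAMapperSketch:
    """Illustrative stand-in for OptimizeQAMapper; not the real class."""

    # Keys follow data-juicer's query/response sample convention.
    query_key = 'query'
    response_key = 'response'

    def _generate(self, prompt: str) -> str:
        # Stand-in for the HF/vLLM generation the real parent performs
        # (chat template + model.generate), hidden from subclasses.
        return prompt.upper()

    def build_input(self, sample: dict) -> str:
        # Hook overridden by subclasses; OptimizeQueryMapper returns only
        # the query so the model optimizes just the question.
        return sample[self.query_key]

    def parse_output(self, raw_output: str) -> Tuple[Optional[str], Optional[str]]:
        # Hook returning (new_query, new_response); None is assumed to mean
        # "leave this field as-is", matching OptimizeQueryMapper's
        # `return raw_output.strip(), None`.
        return raw_output.strip(), None

    def process_single(self, sample: dict, rank=None) -> dict:
        raw = self._generate(self.build_input(sample))
        query, response = self.parse_output(raw)
        if query is not None:
            sample[self.query_key] = query
        if response is not None:
            sample[self.response_key] = response
        return sample


if __name__ == '__main__':
    demo = {'query': 'Write a poem.', 'response': 'Roses are red...'}
    print(OptimizeQAMapperSketch().process_single(demo))

Under that contract, OptimizeResponseMapper presumably overrides the same two hooks in the other direction, which is what makes registering the three-op family in __init__.py so cheap.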