Skip to content

Commit fb98c56

Browse files
BeachWang, drcege, HYLcool, yxdyc
authored
Dev/manage meta (#518)
* - add insight mining * meta tags aggregator * naive reverse grouper * * resolve the bugs when running insight mining in multiprocessing mode * * update unittests * * update unittests * * update unittests * tags specified field * * update readme for analyzer * doc done * * use more detailed key * + add reference * move mm tags * move meta key * done * test done * rm nested set * Update constant.py minor fix * rename agg to batch meta * export in naive reverse grouper --------- Co-authored-by: null <3213204+drcege@users.noreply.github.com> Co-authored-by: gece.gc <gece.gc@alibaba-inc.com> Co-authored-by: lielin.hyl <lielin.hyl@alibaba-inc.com> Co-authored-by: Daoyuan Chen <67475544+yxdyc@users.noreply.github.com>
1 parent 1fe821f commit fb98c56

File tree

54 files changed

+1230
-627
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1230
-627
lines changed

configs/config_all.yaml

Lines changed: 46 additions & 28 deletions
Large diffs are not rendered by default.

data_juicer/ops/aggregator/entity_attribute_aggregator.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from data_juicer.ops.base_op import OPERATORS, Aggregator
88
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
9-
is_string_list, nested_access,
10-
nested_set)
9+
is_string_list)
10+
from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys
1111
from data_juicer.utils.model_utils import get_model, prepare_model
1212

1313
from .nested_aggregator import NestedAggregator
@@ -53,8 +53,8 @@ def __init__(self,
5353
api_model: str = 'gpt-4o',
5454
entity: str = None,
5555
attribute: str = None,
56-
input_key: str = None,
57-
output_key: str = None,
56+
input_key: str = MetaKeys.event_description,
57+
output_key: str = BatchMetaKeys.entity_attribute,
5858
word_limit: PositiveInt = 100,
5959
max_token_num: Optional[PositiveInt] = None,
6060
*,
@@ -73,12 +73,10 @@ def __init__(self,
7373
:param api_model: API model name.
7474
:param entity: The given entity.
7575
:param attribute: The given attribute.
76-
:param input_key: The input field key in the samples. Support for
77-
nested keys such as "__dj__stats__.text_len". It is text_key
78-
in default.
79-
:param output_key: The output field key in the samples. Support for
80-
nested keys such as "__dj__stats__.text_len". It is same as the
81-
input_key in default.
76+
:param input_key: The input key in the meta field of the samples.
77+
It is "event_description" in default.
78+
:param output_key: The output key in the aggregation field of the
79+
samples. It is "entity_attribute" in default.
8280
:param word_limit: Prompt the output length.
8381
:param max_token_num: The max token num of the total tokens of the
8482
sub documents. Without limitation if it is None.
@@ -103,8 +101,8 @@ def __init__(self,
103101

104102
self.entity = entity
105103
self.attribute = attribute
106-
self.input_key = input_key or self.text_key
107-
self.output_key = output_key or self.input_key
104+
self.input_key = input_key
105+
self.output_key = output_key
108106
self.word_limit = word_limit
109107
self.max_token_num = max_token_num
110108

@@ -131,7 +129,7 @@ def __init__(self,
131129
**model_params)
132130

133131
self.try_num = try_num
134-
self.nested_sum = NestedAggregator(model=api_model,
132+
self.nested_sum = NestedAggregator(api_model=api_model,
135133
max_token_num=max_token_num,
136134
api_endpoint=api_endpoint,
137135
response_path=response_path,
@@ -185,12 +183,21 @@ def attribute_summary(self, sub_docs, rank=None):
185183

186184
def process_single(self, sample=None, rank=None):
187185

186+
if self.output_key in sample[Fields.batch_meta]:
187+
return sample
188+
189+
if Fields.meta not in sample or self.input_key not in sample[
190+
Fields.meta][0]:
191+
logger.warning('The input key does not exist in the sample!')
192+
return sample
193+
194+
sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
188195
# if not batched sample
189-
sub_docs = nested_access(sample, self.input_key)
190196
if not is_string_list(sub_docs):
197+
logger.warning('Require string meta as input!')
191198
return sample
192199

193-
sample = nested_set(sample, self.output_key,
194-
self.attribute_summary(sub_docs, rank=rank))
200+
sample[Fields.batch_meta][self.output_key] = self.attribute_summary(
201+
sub_docs, rank=rank)
195202

196203
return sample

data_juicer/ops/aggregator/most_relavant_entities_aggregator.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from pydantic import PositiveInt
66

77
from data_juicer.ops.base_op import OPERATORS, Aggregator
8-
from data_juicer.utils.common_utils import (is_string_list, nested_access,
9-
nested_set)
8+
from data_juicer.utils.common_utils import is_string_list
9+
from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys
1010
from data_juicer.utils.model_utils import get_model, prepare_model
1111

1212
from ..common import split_text_by_punctuation
@@ -44,8 +44,8 @@ def __init__(self,
4444
api_model: str = 'gpt-4o',
4545
entity: str = None,
4646
query_entity_type: str = None,
47-
input_key: str = None,
48-
output_key: str = None,
47+
input_key: str = MetaKeys.event_description,
48+
output_key: str = BatchMetaKeys.most_relavant_entities,
4949
max_token_num: Optional[PositiveInt] = None,
5050
*,
5151
api_endpoint: Optional[str] = None,
@@ -62,12 +62,10 @@ def __init__(self,
6262
:param api_model: API model name.
6363
:param entity: The given entity.
6464
:param query_entity_type: The type of queried relavant entities.
65-
:param input_key: The input field key in the samples. Support for
66-
nested keys such as "__dj__stats__.text_len". It is text_key
67-
in default.
68-
:param output_key: The output field key in the samples. Support for
69-
nested keys such as "__dj__stats__.text_len". It is same as the
70-
input_key in default.
65+
:param input_key: The input key in the meta field of the samples.
66+
It is "event_description" in default.
67+
:param output_key: The output key in the aggregation field of the
68+
samples. It is "most_relavant_entities" in default.
7169
:param max_token_num: The max token num of the total tokens of the
7270
sub documents. Without limitation if it is None.
7371
:param api_endpoint: URL endpoint for the API.
@@ -91,8 +89,8 @@ def __init__(self,
9189

9290
self.entity = entity
9391
self.query_entity_type = query_entity_type
94-
self.input_key = input_key or self.text_key
95-
self.output_key = output_key or self.input_key
92+
self.input_key = input_key
93+
self.output_key = output_key
9694
self.max_token_num = max_token_num
9795

9896
system_prompt_template = system_prompt_template or \
@@ -167,13 +165,22 @@ def query_most_relavant_entities(self, sub_docs, rank=None):
167165

168166
def process_single(self, sample=None, rank=None):
169167

168+
if self.output_key in sample[Fields.batch_meta]:
169+
return sample
170+
171+
if Fields.meta not in sample or self.input_key not in sample[
172+
Fields.meta][0]:
173+
logger.warning('The input key does not exist in the sample!')
174+
return sample
175+
176+
sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
177+
170178
# if not batched sample
171-
sub_docs = nested_access(sample, self.input_key)
172179
if not is_string_list(sub_docs):
173180
return sample
174181

175-
sample = nested_set(
176-
sample, self.output_key,
177-
self.query_most_relavant_entities(sub_docs, rank=rank))
182+
sample[Fields.batch_meta][
183+
self.output_key] = self.query_most_relavant_entities(sub_docs,
184+
rank=rank)
178185

179186
return sample

data_juicer/ops/aggregator/nested_aggregator.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from data_juicer.ops.base_op import OPERATORS, Aggregator
77
from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
8-
is_string_list, nested_access)
8+
is_string_list)
9+
from data_juicer.utils.constant import Fields, MetaKeys
910
from data_juicer.utils.model_utils import get_model, prepare_model
1011

1112
OP_NAME = 'nested_aggregator'
@@ -47,7 +48,7 @@ class NestedAggregator(Aggregator):
4748

4849
def __init__(self,
4950
api_model: str = 'gpt-4o',
50-
input_key: str = None,
51+
input_key: str = MetaKeys.event_description,
5152
output_key: str = None,
5253
max_token_num: Optional[PositiveInt] = None,
5354
*,
@@ -63,12 +64,10 @@ def __init__(self,
6364
"""
6465
Initialization method.
6566
:param api_model: API model name.
66-
:param input_key: The input field key in the samples. Support for
67-
nested keys such as "__dj__stats__.text_len". It is text_key
68-
in default.
69-
:param output_key: The output field key in the samples. Support for
70-
nested keys such as "__dj__stats__.text_len". It is same as the
71-
input_key in default.
67+
:param input_key: The input key in the meta field of the samples.
68+
It is "event_description" in default.
69+
:param output_key: The output key in the aggregation field in the
70+
samples. It is same as the input_key in default.
7271
:param max_token_num: The max token num of the total tokens of the
7372
sub documents. Without limitation if it is None.
7473
:param api_endpoint: URL endpoint for the API.
@@ -165,11 +164,21 @@ def recursive_summary(self, sub_docs, rank=None):
165164

166165
def process_single(self, sample=None, rank=None):
167166

167+
if self.output_key in sample[Fields.batch_meta]:
168+
return sample
169+
170+
if Fields.meta not in sample or self.input_key not in sample[
171+
Fields.meta][0]:
172+
logger.warning('The input key does not exist in the sample!')
173+
return sample
174+
175+
sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
176+
168177
# if not batched sample
169-
sub_docs = nested_access(sample, self.input_key)
170178
if not is_string_list(sub_docs):
171179
return sample
172180

173-
sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank)
181+
sample[Fields.batch_meta][self.output_key] = self.recursive_summary(
182+
sub_docs, rank=rank)
174183

175184
return sample

data_juicer/ops/base_op.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,17 @@ def process_single(self, sample):
633633

634634
def run(self, dataset, *, exporter=None, tracer=None):
635635
dataset = super(Aggregator, self).run(dataset)
636+
# add batched meta field for OPs that produce aggregations
637+
if Fields.batch_meta not in dataset.features:
638+
from data_juicer.core.data import add_same_content_to_new_column
639+
dataset = dataset.map(add_same_content_to_new_column,
640+
fn_kwargs={
641+
'new_column_name': Fields.batch_meta,
642+
'initial_value': {}
643+
},
644+
num_proc=self.runtime_np(),
645+
batch_size=self.batch_size,
646+
desc='Adding new column for aggregation')
636647
new_dataset = dataset.map(
637648
self.process,
638649
num_proc=self.runtime_np(),

data_juicer/ops/filter/video_tagging_from_frames_filter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from pydantic import PositiveInt
55

6-
from data_juicer.utils.constant import Fields
6+
from data_juicer.utils.constant import Fields, MetaKeys
77

88
from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE,
99
Filter)
@@ -30,7 +30,7 @@ def __init__(self,
3030
contain: str = 'any',
3131
frame_sampling_method: str = 'all_keyframes',
3232
frame_num: PositiveInt = 3,
33-
tag_field_name: str = Fields.video_frame_tags,
33+
tag_field_name: str = MetaKeys.video_frame_tags,
3434
any_or_all: str = 'any',
3535
*args,
3636
**kwargs):
@@ -55,8 +55,8 @@ def __init__(self,
5555
the first and the last frames will be extracted. If it's larger
5656
than 2, in addition to the first and the last frames, other frames
5757
will be extracted uniformly within the video duration.
58-
:param tag_field_name: the field name to store the tags. It's
59-
"__dj__video_frame_tags__" in default.
58+
:param tag_field_name: the key name to store the tags in the meta
59+
field. It's "video_frame_tags" in default.
6060
:param any_or_all: keep this sample with 'any' or 'all' strategy of
6161
all videos. 'any': keep this sample if any videos meet the
6262
condition. 'all': keep this sample only if all videos meet the
Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,48 @@
1+
import json
2+
import os
3+
4+
from data_juicer.utils.constant import Fields
5+
from data_juicer.utils.file_utils import create_directory_if_not_exists
6+
17
from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict
28

39

410
@OPERATORS.register_module('naive_reverse_grouper')
511
class NaiveReverseGrouper(Grouper):
612
"""Split batched samples to samples. """
713

8-
def __init__(self, *args, **kwargs):
14+
def __init__(self, batch_meta_export_path=None, *args, **kwargs):
915
"""
1016
Initialization method.
1117
18+
:param batch_meta_export_path: the path to export the batch meta.
19+
Just drop the batch meta if it is None.
1220
:param args: extra args
1321
:param kwargs: extra args
1422
"""
1523
super().__init__(*args, **kwargs)
24+
self.batch_meta_export_path = batch_meta_export_path
1625

1726
def process(self, dataset):
1827

1928
if len(dataset) == 0:
2029
return dataset
2130

2231
samples = []
32+
batch_metas = []
2333
for sample in dataset:
34+
if Fields.batch_meta in sample:
35+
batch_metas.append(sample[Fields.batch_meta])
36+
sample = {
37+
k: sample[k]
38+
for k in sample if k != Fields.batch_meta
39+
}
2440
samples.extend(convert_dict_list_to_list_dict(sample))
41+
if self.batch_meta_export_path is not None:
42+
create_directory_if_not_exists(
43+
os.path.dirname(self.batch_meta_export_path))
44+
with open(self.batch_meta_export_path, 'w') as f:
45+
for batch_meta in batch_metas:
46+
f.write(json.dumps(batch_meta, ensure_ascii=False) + '\n')
2547

2648
return samples

data_juicer/ops/mapper/dialog_intent_detection_mapper.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,21 @@
44
from loguru import logger
55
from pydantic import NonNegativeInt, PositiveInt
66

7-
from data_juicer.ops.base_op import OPERATORS, Mapper
8-
from data_juicer.utils.common_utils import nested_set
7+
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
98
from data_juicer.utils.constant import Fields, MetaKeys
109
from data_juicer.utils.model_utils import get_model, prepare_model
1110

1211
OP_NAME = 'dialog_intent_detection_mapper'
1312

1413

1514
# TODO: LLM-based inference.
15+
@TAGGING_OPS.register_module(OP_NAME)
1616
@OPERATORS.register_module(OP_NAME)
1717
class DialogIntentDetectionMapper(Mapper):
1818
"""
1919
Mapper to generate user's intent labels in dialog. Input from
2020
history_key, query_key and response_key. Output lists of
21-
labels and analysis for queries in the dialog, which is
22-
store in 'dialog_intent_labels' and
23-
'dialog_intent_labels_analysis' in Data-Juicer meta field.
21+
labels and analysis for queries in the dialog.
2422
"""
2523

2624
DEFAULT_SYSTEM_PROMPT = (
@@ -60,6 +58,8 @@ def __init__(self,
6058
intent_candidates: Optional[List[str]] = None,
6159
max_round: NonNegativeInt = 10,
6260
*,
61+
labels_key: str = MetaKeys.dialog_intent_labels,
62+
analysis_key: str = MetaKeys.dialog_intent_labels_analysis,
6363
api_endpoint: Optional[str] = None,
6464
response_path: Optional[str] = None,
6565
system_prompt: Optional[str] = None,
@@ -82,6 +82,11 @@ def __init__(self,
8282
intent labels of the open domain if it is None.
8383
:param max_round: The max num of round in the dialog to build the
8484
prompt.
85+
:param labels_key: The key name in the meta field to store the
86+
output labels. It is 'dialog_intent_labels' in default.
87+
:param analysis_key: The key name in the meta field to store the
88+
corresponding analysis. It is 'dialog_intent_labels_analysis'
89+
in default.
8590
:param api_endpoint: URL endpoint for the API.
8691
:param response_path: Path to extract content from the API response.
8792
Defaults to 'choices.0.message.content'.
@@ -111,6 +116,8 @@ def __init__(self,
111116

112117
self.intent_candidates = intent_candidates
113118
self.max_round = max_round
119+
self.labels_key = labels_key
120+
self.analysis_key = analysis_key
114121

115122
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
116123
self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
@@ -167,6 +174,11 @@ def parse_output(self, response):
167174
return analysis, labels
168175

169176
def process_single(self, sample, rank=None):
177+
178+
meta = sample[Fields.meta]
179+
if self.labels_key in meta and self.analysis_key in meta:
180+
return sample
181+
170182
client = get_model(self.model_key, rank=rank)
171183

172184
analysis_list = []
@@ -208,9 +220,7 @@ def process_single(self, sample, rank=None):
208220
history.append(self.labels_template.format(labels=labels))
209221
history.append(self.response_template.format(response=qa[1]))
210222

211-
analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}' # noqa: E501
212-
sample = nested_set(sample, analysis_key, analysis_list)
213-
labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}'
214-
sample = nested_set(sample, labels_key, labels_list)
223+
meta[self.labels_key] = labels_list
224+
meta[self.analysis_key] = analysis_list
215225

216226
return sample

0 commit comments

Comments
 (0)