
Commit e90a759

HYLcool and yxdyc authored
Add general fused op (#626)
* + add general_fused_op
* + add parameter descriptions for two ffmpeg wrapper ops in config_all.yaml
* - remove lazy_loading for bs4
* Update op_fusion.py
* * run pre-commit

Co-authored-by: Daoyuan Chen <67475544+yxdyc@users.noreply.github.com>
1 parent: f9d5f93

5 files changed: +242 −24 lines

configs/config_all.yaml

Lines changed: 14 additions & 4 deletions
@@ -77,11 +77,16 @@ hpo_config: null # path to a configur
 # process schedule: a list of several process operators with their arguments
 process:
   # Mapper ops. Most of these ops need no arguments.
-  - audio_add_gaussian_noise_mapper:      # Mapper to add Gaussian noise to audio.
-      min_amplitude: 0.001                # Default: 0.001. Minimum noise amplification factor.
-      max_amplitude: 0.015                # Default: 0.015. Maximum noise amplification factor.
-      p: 0.5                              # Default: 0.5 (range: [0.0, 1.0]). The probability of applying this transform.
+  - audio_add_gaussian_noise_mapper:      # Mapper to add Gaussian noise to audio.
+      min_amplitude: 0.001                # Default: 0.001. Minimum noise amplification factor.
+      max_amplitude: 0.015                # Default: 0.015. Maximum noise amplification factor.
+      p: 0.5                              # Default: 0.5 (range: [0.0, 1.0]). The probability of applying this transform.
   - audio_ffmpeg_wrapped_mapper:          # simple wrapper for FFmpeg audio filters
+      filter_name: null                   # ffmpeg audio filter name, e.g. 'atrim'.
+      filter_kwargs: null                 # keyword arguments passed to the ffmpeg filter, e.g. {'end': 6}.
+      global_args: null                   # list arguments passed to the ffmpeg command line, e.g. ['-progress'].
+      capture_stderr: true                # whether to capture stderr.
+      overwrite_output: true              # whether to overwrite the output file.
   - calibrate_qa_mapper:                  # calibrate question-answer pairs based on reference text.
       api_model: 'gpt-4o'                 # API model name.
       api_endpoint: null                  # URL endpoint for the API.
@@ -543,6 +548,11 @@ process:
       blur_type: 'gaussian'               # type of blur kernel, one of ['mean', 'box', 'gaussian']
       radius: 2                           # radius of blur kernel
   - video_ffmpeg_wrapped_mapper:          # simple wrapper for FFmpeg video filters
+      filter_name: null                   # ffmpeg video filter name, e.g. 'scale'.
+      filter_kwargs: null                 # keyword arguments passed to the ffmpeg filter, e.g. {'width': 224, 'height': 224}.
+      global_args: null                   # list arguments passed to the ffmpeg command line, e.g. ['-progress'].
+      capture_stderr: true                # whether to capture stderr.
+      overwrite_output: true              # whether to overwrite the output file.
   - video_remove_watermark_mapper:        # Remove the watermarks in videos given regions
       roi_strings: ['0,0,0.1,0.1']        # a given list of regions where the watermarks are located. Each can be formatted as "x1, y1, x2, y2", "(x1, y1, x2, y2)", or "[x1, y1, x2, y2]".
       roi_type: ratio                     # the roi string type. When the type is 'pixel', (x1, y1) and (x2, y2) are the pixel locations of the top-left and bottom-right corners. If roi_type is 'ratio', the coordinates are normalized by widths and heights.
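
As a quick sanity check on the newly documented parameters, here is a minimal sketch that builds both FFmpeg wrapper ops from process-style dicts (the same shape load_ops consumes in the tests below; the filter choices mirror the example values in the comments above, and the snippet itself is hypothetical, not part of this commit):

    from data_juicer.ops.load import load_ops

    # Both wrapper ops, configured with the example values documented above:
    # trim audio to its first 6 seconds, rescale video to 224x224.
    ops = load_ops([
        {'audio_ffmpeg_wrapped_mapper': {
            'filter_name': 'atrim',
            'filter_kwargs': {'end': 6},
        }},
        {'video_ffmpeg_wrapped_mapper': {
            'filter_name': 'scale',
            'filter_kwargs': {'width': 224, 'height': 224},
        }},
    ])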

data_juicer/ops/base_op.py

Lines changed: 15 additions & 14 deletions
@@ -241,9 +241,9 @@ def run(self, dataset):
         if not isinstance(dataset, NestedDataset):
             dataset = NestedDataset(dataset)
         # add meta field for OPs that produce tags
+        from data_juicer.core.data import add_same_content_to_new_column
         if self._name in TAGGING_OPS.modules \
                 and Fields.meta not in dataset.features:
-            from data_juicer.core.data import add_same_content_to_new_column
             dataset = dataset.map(add_same_content_to_new_column,
                                   fn_kwargs={
                                       'new_column_name': Fields.meta,
@@ -252,7 +252,20 @@ def run(self, dataset):
                                   num_proc=self.runtime_np(),
                                   batch_size=self.batch_size,
                                   desc='Adding new column for meta')
-        if self.index_key is not None:
+        # add stats field for Filters that produce stats
+        if isinstance(self, Filter) \
+                and self._name not in NON_STATS_FILTERS.modules \
+                and Fields.stats not in dataset.features:
+            dataset = dataset.map(add_same_content_to_new_column,
+                                  fn_kwargs={
+                                      'new_column_name': Fields.stats,
+                                      'initial_value': {}
+                                  },
+                                  num_proc=self.runtime_np(),
+                                  batch_size=self.batch_size,
+                                  desc='Adding new column for stats')
+        if self.index_key is not None \
+                and self.index_key not in dataset.features:

             def add_index(sample, idx):
                 sample[self.index_key] = idx
@@ -455,18 +468,6 @@ def process_single(self, sample):

     def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
         dataset = super(Filter, self).run(dataset)
-        # add stats field for Filters that produce stats
-        if self._name not in NON_STATS_FILTERS.modules \
-                and Fields.stats not in dataset.features:
-            from data_juicer.core.data import add_same_content_to_new_column
-            dataset = dataset.map(add_same_content_to_new_column,
-                                  fn_kwargs={
-                                      'new_column_name': Fields.stats,
-                                      'initial_value': {}
-                                  },
-                                  num_proc=self.runtime_np(),
-                                  batch_size=self.batch_size,
-                                  desc='Adding new column for stats')
         dataset = dataset.map(self.compute_stats,
                               num_proc=self.runtime_np(),
                               with_rank=self.use_cuda(),
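
Why the move matters: the stats-column bootstrap used to live in Filter.run, so only a full Filter.run call would create it. Hoisting it into OP.run means any caller that only invokes the base-class run, as GeneralFusedOP does below once per fused op, still gets the Fields.stats column. A minimal sketch of that calling pattern, using ops from this diff and a hypothetical two-sample dataset:

    from data_juicer.core import NestedDataset
    from data_juicer.ops.base_op import OP
    from data_juicer.ops.load import load_ops

    # Hypothetical dataset, just for illustration.
    ds = NestedDataset.from_list([{'text': 'aaa'}, {'text': 'abc'}])

    # The base-class run only prepares columns (meta/stats/index); it does
    # not compute stats or drop samples. After this change, it is enough to
    # give a Filter its Fields.stats column before batched processing.
    flt = load_ops([{'character_repetition_filter': {'max_ratio': 0.2}}])[0]
    ds = OP.run(flt, ds)  # ds now carries an empty stats entry per sample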

data_juicer/ops/mapper/extract_tables_from_html_mapper.py

Lines changed: 2 additions & 3 deletions
@@ -1,10 +1,9 @@
+import bs4
+
 from data_juicer.utils.constant import Fields, MetaKeys
-from data_juicer.utils.lazy_loader import LazyLoader

 from ..base_op import OPERATORS, TAGGING_OPS, Mapper

-bs4 = LazyLoader('bs4', 'bs4')
-
 OP_NAME = 'extract_tables_from_html_mapper'

data_juicer/ops/op_fusion.py

Lines changed: 74 additions & 2 deletions
@@ -3,11 +3,11 @@
 import numpy as np
 from loguru import logger

+from data_juicer.ops.base_op import OP, OPERATORS, Filter, Mapper
+from data_juicer.ops.load import load_ops
 from data_juicer.utils.constant import Fields, InterVars
 from data_juicer.utils.registry import Registry

-from .base_op import Filter
-
 # Type of intermediate vars
 # text
 INTER_LINES = Registry(InterVars.lines)
@@ -196,3 +196,75 @@ def process_batched(self, samples):
         else:
             res = this_res
         return res
+
+
+@OPERATORS.register_module('general_fused_op')
+class GeneralFusedOP(OP):
+    """An explicitly fused operator designed to execute multiple sequential
+    operations (OPs) on the same batch, enabling fine-grained control over
+    data processing."""
+
+    _batched_op = True
+
+    def __init__(self,
+                 batch_size: int = 1,
+                 fused_op_list: List = None,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.batch_size = batch_size
+        if fused_op_list is None:
+            fused_op_list = []
+        self.fused_ops = load_ops(fused_op_list)
+        self._name = 'GeneralFusedOP:(%s)' % ','.join(
+            [op._name for op in self.fused_ops])
+        # set accelerator to 'cuda' if any fused op's accelerator is 'cuda'
+        accelerator_methods = set([op.accelerator for op in self.fused_ops])
+        if 'cuda' in accelerator_methods:
+            self.accelerator = 'cuda'
+
+        # update num_proc with the min num_proc of all fusible filters
+        self.num_proc = min([op.runtime_np() for op in self.fused_ops]) \
+            if self.fused_ops else 1
+
+    def process_batched(self, samples, rank=None):
+        for op in self.fused_ops:
+            process_args = {'rank': rank} if op.accelerator == 'cuda' else {}
+            if isinstance(op, Mapper):
+                samples = op.process_batched(samples, **process_args)
+            elif isinstance(op, Filter):
+                samples = op.compute_stats_batched(samples, **process_args)
+                indicators = list(op.process_batched(samples))
+                new_samples = {}
+                for key in samples:
+                    new_samples[key] = [
+                        val for val, indicator in zip(samples[key], indicators)
+                        if indicator
+                    ]
+                samples = new_samples
+            else:
+                raise NotImplementedError(
+                    f'FusedOP does not support OP {op._name} of type '
+                    f'{type(op)} and only supports Mapper and Filter now.')
+        return samples
+
+    def run(self, dataset, *, exporter=None, tracer=None):
+        # prepare the dataset
+        from data_juicer.core.data import NestedDataset
+        if not isinstance(dataset, NestedDataset):
+            dataset = NestedDataset(dataset)
+        if not self.fused_ops:
+            return dataset
+        # initialize for different kinds of datasets
+        for op in self.fused_ops:
+            dataset = OP.run(op, dataset)
+
+        new_dataset = dataset.map(
+            self.process_batched,
+            num_proc=self.num_proc,
+            with_rank=self.use_cuda(),
+            batch_size=self.batch_size,
+            desc=self._name + '_process',
+        )
+        return new_dataset
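
For orientation, here is a minimal sketch of driving the new op directly. The op names and arguments are taken from the tests below; the two-sample dataset is hypothetical:

    from data_juicer.core import NestedDataset
    from data_juicer.ops.load import load_ops

    # One GeneralFusedOP wrapping a filter and two mappers; each batch of
    # 2 samples flows through all three ops before the next batch loads.
    fused = load_ops([{
        'general_fused_op': {
            'batch_size': 2,
            'fused_op_list': [
                {'language_id_score_filter': {'lang': 'en', 'min_score': 0.8}},
                {'whitespace_normalization_mapper': {}},
                {'punctuation_normalization_mapper': {}},
            ],
        }
    }])[0]

    ds = NestedDataset.from_list([{'text': 'This is a test.'},
                                  {'text': 'punc test。'}])
    res = fused.run(ds)  # filtered and normalized in a single pass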

tests/ops/test_op_fusion.py

Lines changed: 137 additions & 1 deletion
@@ -1,12 +1,27 @@
 import unittest

+from data_juicer.core import NestedDataset
+from data_juicer.ops.base_op import OP
 from data_juicer.ops.load import load_ops
-from data_juicer.ops.op_fusion import fuse_operators
+from data_juicer.ops.op_fusion import fuse_operators, GeneralFusedOP
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


 class OpFusionTest(DataJuicerTestCaseBase):

+    def _run_equal_config(self, original_process_list):
+        dataset = NestedDataset.from_list([
+            {'text': 'This is a test.'},
+            {'text': 'This is a test. This is a test. This is a test.'},
+            {'text': 'aaaaaaaaaaaaaaabbbbbbbbbbbbcccccccccccccc'},
+            {'text': 'punc test。'}
+        ])
+        unfused_op = load_ops(original_process_list)
+        fused_ops = fuse_operators(unfused_op)
+        res1 = dataset.process(fused_ops)
+        res2 = dataset.process(unfused_op)
+        self.assertDatasetEqual(res1, res2)
+
     def _run_op_fusion(self, original_process_list, target_process_list, probe_res=None):
         ops = load_ops(original_process_list)
         ops = fuse_operators(ops, probe_res)
@@ -232,6 +247,7 @@ def test_regular_config(self):
             }
         ]
         self._run_op_fusion(original_process, target_process)
+        self._run_equal_config(original_process)

     def test_only_mapper(self):
         original_process = [{
@@ -1961,5 +1977,125 @@ def test_different_intermediate_vars_with_probe_res(self):
         self._run_op_fusion(original_process, target_process, probe_res_list)


+class GeneralFusedOPTest(DataJuicerTestCaseBase):
+
+    def setUp(self) -> None:
+        self.dataset = NestedDataset.from_list([
+            {'text': 'This is a test.'},
+            {'text': 'This is a test. This is a test. This is a test.'},
+            {'text': 'aaaaaaaaaaaaaaabbbbbbbbbbbbcccccccccccccc'},
+            {'text': 'punc test。'}
+        ])
+
+    def _run_equal_config(self, fused_process, unfused_process):
+        fused_op = load_ops(fused_process)
+        self.assertEqual(len(fused_op), 1)
+        fused_op = fused_op[0]
+        unfused_op = load_ops(unfused_process)
+        self.assertIsInstance(fused_op, GeneralFusedOP)
+        self.assertEqual(len(fused_op.fused_ops), len(unfused_process))
+        res1 = self.dataset.process(fused_op)
+        res2 = self.dataset.process(unfused_op)
+        # invoke process_batched directly
+        for op in fused_op.fused_ops:
+            self.dataset = OP.run(op, self.dataset)
+        res3 = fused_op.process_batched(self.dataset.to_dict())
+        self.assertDatasetEqual(res1, res2)
+        self.assertEqual(res1.to_dict(), res3)
+
+    def test_regular_config(self):
+
+        original_process = [{
+            'language_id_score_filter': {
+                'lang': 'en',
+                'min_score': 0.8,
+                'text_key': 'text'
+            }
+        }, {
+            'whitespace_normalization_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'punctuation_normalization_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'fix_unicode_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'character_repetition_filter': {
+                'max_ratio': 0.106,
+                'min_ratio': 0.0,
+                'rep_len': 10,
+                'text_key': 'text'
+            }
+        }]
+        fused_process = [{
+            'general_fused_op': {
+                'batch_size': 2,
+                'fused_op_list': original_process,
+            }
+        }]
+        self._run_equal_config(fused_process, original_process)
+
+    def test_border_cases(self):
+
+        original_process = [{
+            'language_id_score_filter': {
+                'lang': 'en',
+                'min_score': 0.8,
+                'text_key': 'text'
+            }
+        }, {
+            'whitespace_normalization_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'punctuation_normalization_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'fix_unicode_mapper': {
+                'text_key': 'text'
+            }
+        }, {
+            'character_repetition_filter': {
+                'max_ratio': 0.106,
+                'min_ratio': 0.0,
+                'rep_len': 10,
+                'text_key': 'text'
+            }
+        }]
+        empty_fused_process = [{
+            'general_fused_op': {
+                'batch_size': 2,
+                'fused_op_list': None,
+            }
+        }]
+        fused_process = [{
+            'general_fused_op': {
+                'batch_size': 2,
+                'fused_op_list': original_process,
+            }
+        }]
+        # empty fused process
+        fused_op = load_ops(empty_fused_process)[0]
+        self.assertEqual(len(fused_op.fused_ops), 0)
+        res = fused_op.run(self.dataset)
+        self.assertDatasetEqual(res, self.dataset)
+        # unsupported fused op
+        with self.assertRaises(NotImplementedError):
+            fused_op = load_ops([{
+                'general_fused_op': {
+                    'batch_size': 2,
+                    'fused_op_list': [{
+                        'document_deduplicator': {}
+                    }],
+                }
+            }])[0]
+            fused_op.process_batched(self.dataset.to_dict())
+
+
 if __name__ == '__main__':
     unittest.main()
