Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,11 @@ process:
- text_length_filter: # filter text with length out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- text_pair_similarity_filter: # filter samples according to the similarity score between the text pair.
hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface
min_score: 0.1 # the min similarity score of filter range
max_score: 1.0 # the max similarity score of filter range
any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition
- token_num_filter: # filter text with total token number out of specific range
hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer
min_num: 10 # the min number of filter range
Expand Down
16 changes: 9 additions & 7 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
specified_field_filter, specified_numeric_field_filter,
stopwords_filter, suffix_filter, text_action_filter,
text_entity_dependency_filter, text_length_filter,
token_num_filter, video_aesthetics_filter,
video_aspect_ratio_filter, video_duration_filter,
video_frames_text_similarity_filter, video_motion_score_filter,
video_nsfw_filter, video_ocr_area_ratio_filter,
video_resolution_filter, video_tagging_from_frames_filter,
video_watermark_filter, word_repetition_filter,
words_num_filter)
text_pair_similarity_filter, token_num_filter,
video_aesthetics_filter, video_aspect_ratio_filter,
video_duration_filter, video_frames_text_similarity_filter,
video_motion_score_filter, video_nsfw_filter,
video_ocr_area_ratio_filter, video_resolution_filter,
video_tagging_from_frames_filter, video_watermark_filter,
word_repetition_filter, words_num_filter)
from .alphanumeric_filter import AlphanumericFilter
from .audio_duration_filter import AudioDurationFilter
from .audio_nmf_snr_filter import AudioNMFSNRFilter
Expand Down Expand Up @@ -47,6 +47,7 @@
from .text_action_filter import TextActionFilter
from .text_entity_dependency_filter import TextEntityDependencyFilter
from .text_length_filter import TextLengthFilter
from .text_pair_similarity_filter import TextPairSimilarityFilter
from .token_num_filter import TokenNumFilter
from .video_aesthetics_filter import VideoAestheticsFilter
from .video_aspect_ratio_filter import VideoAspectRatioFilter
Expand Down Expand Up @@ -104,6 +105,7 @@
'FlaggedWordFilter',
'WordRepetitionFilter',
'VideoMotionScoreFilter',
'TextPairSimilarityFilter'
]

# yapf: enable
105 changes: 105 additions & 0 deletions data_juicer/ops/filter/text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.ops.base_op import OPERATORS, Filter
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'text_pair_similarity_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):

import torch
import transformers # noqa: F401

# avoid hanging when calling clip in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
class TextPairSimilarityFilter(Filter):
"""Filter to keep text pairs with similarities between texts
within a specific range."""

_accelerator = 'cuda'

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.text_pair_similarity in sample[Fields.stats]:
return sample

# there is no text in this sample
if (self.text_key not in sample or 'target_text' not in sample
or len(sample[self.text_key]) == 0
or len(sample['target_text']) == 0):
sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array(
[], dtype=np.float64)
return sample

model, processor = get_model(self.model_key, rank, self.use_cuda())

text1 = sample[self.text_key]
text2 = sample['target_text']

text_tensors = processor([text1, text2],
padding=True,
return_tensors='pt').to(model.device)
text_features = model.get_text_features(**text_tensors)

similarity = torch.cosine_similarity(text_features[0],
text_features[1],
dim=0)
sample[Fields.stats][StatsKeys.text_pair_similarity] = [similarity]

return sample

def process(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.text_pair_similarity]
if len(similarity) <= 0:
return True

keep_bools = np.array([
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class StatsKeysConstant(object):
special_char_ratio = 'special_char_ratio'
stopwords_ratio = 'stopwords_ratio'
text_len = 'text_len'
text_pair_similarity = 'text_pair_similarity'
num_action = 'num_action'
num_dependency_edges = 'num_dependency_edges'
num_token = 'num_token'
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 43 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 42 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |

Expand Down Expand Up @@ -127,6 +127,7 @@ All the specific operators are listed below, each featured with several capabili
| text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts |
| text_entity_dependency_filter | General | en, zh | Keeps samples containing dependency edges for an entity in the dependency tree of the texts |
| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range |
| text_pair_similarity_filter | General | en, zh | Keeps text pairs with text feature cosine similarity within the specified range based on a CLIP model |
| token_num_filter | General | en, zh | Keeps samples with token count within the specified range |
| video_aesthetics_filter | Video | - | Keeps samples whose specified frames have aesthetics scores within the specified range |
| video_aspect_ratio_filter | Video | - | Keeps samples containing videos with aspect ratios within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 43 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 42 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -125,6 +125,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| text_action_filter | General | en, zh | 保留文本部分包含动作的样本 |
| text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 |
| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 |
| text_pair_similarity_filter | General | en, zh | 保留文本特征余弦相似度(基于CLIP模型)在指定范围内的样本 |
| token_num_filter | General | en, zh | 保留token数在指定范围内的样本 |
| video_aspect_ratio_filter | Video | - | 保留包含视频的宽高比在指定范围内的样本 |
| video_duration_filter | Video | - | 保留包含视频的时长在指定范围内的样本 |
Expand Down
62 changes: 62 additions & 0 deletions tests/ops/filter/test_text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import unittest

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.filter.text_pair_similarity_filter import TextPairSimilarityFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)


class TextPairSimilarityFilterTest(DataJuicerTestCaseBase):

hf_clip = 'openai/clip-vit-base-patch32'


@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass(cls.hf_clip)

def _run_filter(self, dataset: Dataset, op, num_proc=1):

if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)

dataset = dataset.map(op.compute_stats,
num_proc=num_proc,
with_rank=True)
dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.select_columns(column_names=['text', 'target_text'])
res_list = dataset.to_list()
print(res_list)

def test_no_eoc_special_token(self):

ds_list = [{
'target_text': 'a lovely cat',
'text': 'a lovely cat',
}, {
'target_text': 'a lovely cat',
'text': 'a cute cat',
}, {
'target_text': 'a lovely cat',
'text': 'a black dog',
}]


dataset = Dataset.from_list(ds_list)
op = TextPairSimilarityFilter(hf_clip=self.hf_clip,
any_or_all='any',
min_score=0.1,
max_score=0.85)
self._run_filter(dataset, op)


if __name__ == '__main__':
unittest.main()
Loading