From 216bf507c775b5501ef3d7e95ae88508d63c8ecb Mon Sep 17 00:00:00 2001
From: lingzhq
Date: Tue, 22 Jul 2025 15:39:35 +0800
Subject: [PATCH] Add an op GroupDiversityFilter

---
 configs/config_all.yaml                       |  13 +-
 data_juicer/ops/filter/__init__.py            |   4 +-
 .../ops/filter/group_diversity_filter.py      | 140 ++++++++++++++++++
 data_juicer/utils/constant.py                 |   2 +
 docs/Operators.md                             |  11 +-
 .../ops/filter/test_group_diversity_filter.py |  76 ++++++++++
 6 files changed, 238 insertions(+), 8 deletions(-)
 create mode 100644 data_juicer/ops/filter/group_diversity_filter.py
 create mode 100644 tests/ops/filter/test_group_diversity_filter.py

diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 73f39e2bd3..7a1694a8c1 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -649,8 +649,17 @@ process:
       use_words_aug: false                  # whether to augment words, especially for Chinese and Vietnamese
       words_aug_group_sizes: [2]            # the group size of words to augment
       words_aug_join_char: ""               # the join char between words to augment
-  - general_field_filter: # Filter to keep samples based on a general field filter condition.
-      filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
+  - general_field_filter:                   # Filter to keep samples based on a general field filter condition.
+      filter_condition: ""                  # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
+  - group_diversity_filter:                 # filter samples based on their semantic diversity within a group.
+      api_or_hf_model: "text-embedding-v3"  # API or HuggingFace embedding model name.
+      is_hf_model: false                    # indicates whether the model is from HuggingFace.
+      api_endpoint: "/embeddings"           # embedding URL endpoint for the API.
+      response_path: "data.0.embedding"     # path to extract content from the API response.
+      ebd_dim: 512                          # dimension of the embeddings returned by the API.
+      min_score: 0.0                        # the min score of the filter range.
+      max_score: 1.0                        # the max score of the filter range.
+      norm_ratio: 0.5                       # scaling ratio applied when normalizing the score.
   - image_aesthetics_filter:                # filter samples according to the aesthetics score of images.
       hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE  # Huggingface model name for the aesthetics predictor
       min_score: 0.3                        # the min aesthetics score of filter range

diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index 253a0ee96f..5915884bf2 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -6,6 +6,7 @@
 from .character_repetition_filter import CharacterRepetitionFilter
 from .flagged_words_filter import FlaggedWordFilter
 from .general_field_filter import GeneralFieldFilter
+from .group_diversity_filter import GroupDiversityFilter
 from .image_aesthetics_filter import ImageAestheticsFilter
 from .image_aspect_ratio_filter import ImageAspectRatioFilter
 from .image_face_count_filter import ImageFaceCountFilter
@@ -56,6 +57,8 @@
     "AverageLineLengthFilter",
     "CharacterRepetitionFilter",
     "FlaggedWordFilter",
+    "GeneralFieldFilter",
+    "GroupDiversityFilter",
     "ImageAestheticsFilter",
     "ImageAspectRatioFilter",
     "ImageFaceCountFilter",
@@ -97,7 +100,6 @@
     "VideoWatermarkFilter",
     "WordRepetitionFilter",
     "WordsNumFilter",
-    "GeneralFieldFilter",
 ]
 
 NON_STATS_FILTERS = [

diff --git a/data_juicer/ops/filter/group_diversity_filter.py b/data_juicer/ops/filter/group_diversity_filter.py
new file mode 100644
index 0000000000..0fd81f67b8
--- /dev/null
+++ b/data_juicer/ops/filter/group_diversity_filter.py
@@ -0,0 +1,140 @@
+from typing import Dict, List
+
+import numpy as np
+from jsonargparse.typing import NonNegativeFloat, PositiveInt
+from loguru import logger
+from tqdm import tqdm
+
+from data_juicer.utils.constant import Fields, StatsKeys
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Filter
+
+# Lazy load torch to improve startup time
+torch = LazyLoader("torch")
+
+
+@OPERATORS.register_module("group_diversity_filter")
+class GroupDiversityFilter(Filter):
+    """
+    Filter samples based on their semantic diversity within a group.
+
+    Each text in a batch is embedded, its cosine similarity to the mean
+    embedding of the batch is computed, and the similarities are min-max
+    normalized and scaled by norm_ratio, so that samples farther from the
+    group centroid receive higher diversity scores.
+    """
+
+    _accelerator = "cuda"
+    _batched_op = True
+
+    def __init__(
+        self,
+        api_or_hf_model: str = "text-embedding-v3",
+        is_hf_model: bool = False,
+        api_endpoint: str = "/embeddings",
+        response_path: str = "data.0.embedding",
+        model_params: Dict = {},
+        ebd_dim: PositiveInt = 512,
+        min_score: NonNegativeFloat = 0.0,
+        max_score: NonNegativeFloat = 1.0,
+        norm_ratio: NonNegativeFloat = 0.5,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialization method.
+
+        :param api_or_hf_model: API or HuggingFace embedding model name.
+        :param is_hf_model: Indicates whether the model is from HuggingFace.
+        :param api_endpoint: Embedding URL endpoint for the API.
+        :param response_path: Path to extract content from the API response.
+            Defaults to 'data.0.embedding' for embedding models.
+        :param model_params: Parameters for initializing the API model.
+        :param ebd_dim: Dimension of the embeddings returned by the API.
+        :param min_score: Minimum score for filtering.
+        :param max_score: Maximum score for filtering.
+        :param norm_ratio: Scaling ratio applied when normalizing the score.
+        :param args: extra args
+        :param kwargs: extra args
+        """
+        kwargs.setdefault("mem_required", "20GB")
+        super().__init__(*args, **kwargs)
+
+        self.min_score = min_score
+        self.max_score = max_score
+        self.norm_ratio = norm_ratio
+        self.is_hf_model = is_hf_model
+        self.ebd_dim = ebd_dim
+
+        if self.is_hf_model:
+            self.model_key = prepare_model(model_type="embedding", model_path=api_or_hf_model, **model_params)
+        else:
+            self.model_key = prepare_model(
+                model_type="api",
+                model=api_or_hf_model,
+                endpoint=api_endpoint,
+                response_path=response_path,
+                **model_params,
+            )
+
+    def _embed_texts(self, texts: List[str], rank: int) -> np.ndarray:
+        # Embed a list of texts using the initialized model
+        embeddings = []
+        model = get_model(self.model_key, rank, self.use_cuda())
+
+        for text in tqdm(texts, desc="Embedding texts", leave=False):
+            try:
+                if self.is_hf_model:
+                    embedding = model.encode(text)
+                else:
+                    embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
+                embeddings.append(np.array(embedding, dtype=np.float32))
+            except Exception as e:
+                # fall back to a zero vector so a single failed request does not break the batch
+                dim = model.get_sentence_embedding_dimension() if self.is_hf_model else self.ebd_dim
+                embeddings.append(np.zeros(dim, dtype=np.float32))
+                logger.warning(f"Failed to embed text: '{text}'. Error: {e}. Using zero vector.")
+
+        return np.array(embeddings)
+
+    def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict:
+        stats_list = samples[Fields.stats]
+        # skip if the stats have already been computed
+        if stats_list and StatsKeys.text_ebd_diversity_score in stats_list[0]:
+            return samples
+
+        texts_to_embed = samples[self.text_key]
+        if not texts_to_embed:
+            for stat in stats_list:
+                stat[StatsKeys.text_ebd_diversity] = 0.0
+                stat[StatsKeys.text_ebd_diversity_score] = 0.0
+            return samples
+
+        embeddings_array = self._embed_texts(texts_to_embed, rank=rank)
+
+        # cosine similarity of each sample to the group centroid
+        avg_embedding = np.mean(embeddings_array, axis=0)
+        cos_sims = (
+            torch.nn.functional.cosine_similarity(
+                torch.from_numpy(embeddings_array), torch.from_numpy(avg_embedding).unsqueeze(0), dim=1
+            )
+            .cpu()
+            .numpy()
+            .tolist()
+        )
+
+        min_sim, max_sim = min(cos_sims), max(cos_sims)
+        range_sim = max_sim - min_sim
+
+        # min-max normalize so that samples far from the centroid score higher
+        if range_sim < 1e-8:
+            normalized_scores = [0.0] * len(cos_sims)
+        else:
+            normalized_scores = [self.norm_ratio * (max_sim - sim) / range_sim for sim in cos_sims]
+
+        for i, stat in enumerate(stats_list):
+            stat[StatsKeys.text_ebd_diversity] = cos_sims[i]
+            stat[StatsKeys.text_ebd_diversity_score] = normalized_scores[i]
+
+        return samples
+
+    def process_batched(self, samples: Dict) -> List[bool]:
+        stats_list = samples[Fields.stats]
+        return [self.min_score <= stat[StatsKeys.text_ebd_diversity_score] <= self.max_score for stat in stats_list]

diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index e3244de246..d658533410 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -230,6 +230,8 @@ class StatsKeysConstant(object):
     llm_quality_record = "llm_quality_record"
     llm_difficulty_score = "llm_difficulty_score"
     llm_difficulty_record = "llm_difficulty_record"
+    text_ebd_diversity = "text_ebd_diversity"
+    text_ebd_diversity_score = "text_ebd_diversity_score"
 
     # === image ===
     aspect_ratios = "aspect_ratios"

diff --git a/docs/Operators.md b/docs/Operators.md
index d1abbbb782..cfe96c47ff 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -34,7 +34,7 @@ Data-Juicer 中的算子分为以下 7 种类型。
 |------|:------:|-------------|
 | [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 |
 | [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 |
-| [filter](#filter) | 49 | Filters out low-quality samples. 过滤低质量样本。 |
+| [filter](#filter) | 50 | Filters out low-quality samples. 过滤低质量样本。 |
 | [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
 | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
 | [mapper](#mapper) | 81 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
@@ -96,8 +96,9 @@ All the specific operators are listed below, each featured with several capabili
 | character_repetition_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with char-level n-gram repetition ratio within a specific range. 过滤器将具有char级n-gram重复比率的样本保持在特定范围内。 | [code](../data_juicer/ops/filter/character_repetition_filter.py) | [tests](../tests/ops/filter/test_character_repetition_filter.py) |
 | flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio less than a specific max value. 过滤以保持标记词比率小于特定最大值的样本。 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) |
 | general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [code](../data_juicer/ops/filter/general_field_filter.py) | [tests](../tests/ops/filter/test_general_field_filter.py) |
+| group_diversity_filter | 🔤Text 💻CPU 🔗API 🟡Beta | Filter samples based on their semantic diversity within a group. 基于样本在组内的语义多样性来过滤样本。 | [code](../data_juicer/ops/filter/group_diversity_filter.py) | [tests](../tests/ops/filter/test_group_diversity_filter.py) |
 | image_aesthetics_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) |
-| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持特定范围内的图像长宽比的样本。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) |
+| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) |
 | image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) |
 | image_face_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with face area ratios within a specific range. 过滤以保持面面积比在特定范围内的样本。 | [code](../data_juicer/ops/filter/image_face_ratio_filter.py) | [tests](../tests/ops/filter/test_image_face_ratio_filter.py) |
 | image_nsfw_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples whose images have low nsfw scores. 过滤器保留图像具有低nsfw分数的样本。 | [code](../data_juicer/ops/filter/image_nsfw_filter.py) | [tests](../tests/ops/filter/test_image_nsfw_filter.py) |
@@ -148,7 +149,7 @@ All the specific operators are listed below, each featured with several capabili
 | local_formatter | 🟢Stable | The class is used to load a dataset from local files or local directory. 类用于从本地文件或本地目录加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) |
 | parquet_formatter | 🟢Stable | The class is used to load and format parquet-type files. 该类用于加载和格式化镶木地板类型的文件。 | [code](../data_juicer/format/parquet_formatter.py) | [tests](../tests/format/test_parquet_formatter.py) |
 | remote_formatter | 🟢Stable | The class is used to load a dataset from repository of huggingface hub. 该类用于从huggingface hub的存储库加载数据集。 | [code](../data_juicer/format/formatter.py) | [tests](../tests/format/test_unify_format.py) |
-| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型文件。 | [code](../data_juicer/format/text_formatter.py) | - |
+| text_formatter | 🔴Alpha | The class is used to load and format text-type files. 类用于加载和格式化文本类型的文件。 | [code](../data_juicer/format/text_formatter.py) | - |
 | tsv_formatter | 🟢Stable | The class is used to load and format tsv-type files. 该类用于加载和格式化tsv类型的文件。 | [code](../data_juicer/format/tsv_formatter.py) | [tests](../tests/format/test_tsv_formatter.py) |
 
 ## grouper
@@ -165,7 +166,7 @@
 |----------|------|-------------|-------------|------------|
 | audio_add_gaussian_noise_mapper | 📣Audio 💻CPU 🟡Beta | Mapper to add gaussian noise to audio. 映射器向音频添加高斯噪声。 | [code](../data_juicer/ops/mapper/audio_add_gaussian_noise_mapper.py) | [tests](../tests/ops/mapper/test_audio_add_gaussian_noise_mapper.py) |
 | audio_ffmpeg_wrapped_mapper | 📣Audio 💻CPU 🟢Stable | Simple wrapper for FFmpeg audio filters. FFmpeg音频滤波器的简单包装。 | [code](../data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py) |
-| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问题-答案对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) |
+| calibrate_qa_mapper | 🔤Text 💻CPU 🔗API 🟢Stable | Mapper to calibrate question-answer pairs based on reference text. 映射器基于参考文本校准问答对。 | [code](../data_juicer/ops/mapper/calibrate_qa_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_qa_mapper.py) |
 | calibrate_query_mapper | 💻CPU 🟢Stable | Mapper to calibrate query in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的查询。 | [code](../data_juicer/ops/mapper/calibrate_query_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_query_mapper.py) |
 | calibrate_response_mapper | 💻CPU 🟢Stable | Mapper to calibrate response in question-answer pairs based on reference text. 映射器基于参考文本校准问答对中的响应。 | [code](../data_juicer/ops/mapper/calibrate_response_mapper.py) | [tests](../tests/ops/mapper/test_calibrate_response_mapper.py) |
 | chinese_convert_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. 映射器在繁体中文,简体中文和日语汉字之间转换中文。 | [code](../data_juicer/ops/mapper/chinese_convert_mapper.py) | [tests](../tests/ops/mapper/test_chinese_convert_mapper.py) |
@@ -196,7 +197,7 @@
 | image_diffusion_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Generate image by diffusion model. 通过扩散模型生成图像。 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) |
 | image_face_blur_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to blur faces detected in images. 映射器模糊图像中检测到的人脸。 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) |
 | image_remove_background_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to remove background of images. 映射器删除图像的背景。 | [code](../data_juicer/ops/mapper/image_remove_background_mapper.py) | [tests](../tests/ops/mapper/test_image_remove_background_mapper.py) |
-| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
+| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 对图像执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
 | image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) |
 | imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) |
 | imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) |

diff --git a/tests/ops/filter/test_group_diversity_filter.py b/tests/ops/filter/test_group_diversity_filter.py
new file mode 100644
index 0000000000..94e520ebe4
--- /dev/null
+++ b/tests/ops/filter/test_group_diversity_filter.py
@@ -0,0 +1,76 @@
+import os
+import unittest
+
+from loguru import logger
+
+from data_juicer.core.data import NestedDataset as Dataset
+from data_juicer.ops.filter.group_diversity_filter import GroupDiversityFilter
+from data_juicer.utils.constant import Fields, StatsKeys
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, FROM_FORK
+
+
+@unittest.skipIf(FROM_FORK, "Skipping the test because running from a fork repo")
+class GroupDiversityFilterTest(DataJuicerTestCaseBase):
+    # Before running this test, set the environment variables below:
+    # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/
+    # export OPENAI_API_KEY=your_dashscope_key
+    api_model_name = 'text-embedding-v3'
+    api_ebd_dim = 512
+
+    # for the local Hugging Face model test
+    hf_model_path = 'iic/gte_Qwen2-1.5B-instruct'
+
+    def setUp(self):
+        self.ds_list = [{
+            'text': "A cute cat is playing in the garden."
+        }, {
+            'text': "A lovely dog is running on the grass."
+        }, {
+            'text': "A beautiful bird is singing on the tree."
+        }, {
+            'text': "Quantum computing is a complex field of physics."  # the outlier
+        }]
+        self.dataset = Dataset.from_list(self.ds_list)
+
+    def test_api_based_diversity_logic(self):
+        if not os.getenv('OPENAI_API_KEY'):
+            self.skipTest("OPENAI_API_KEY environment variable is not set. "
+                          "Skipping API-based integration test.")
" + "Skipping API-based integration test.") + + logger.info(f"Running diversity test with API model: {self.api_model_name}") + op = GroupDiversityFilter( + api_or_hf_model=self.api_model_name, + is_hf_model=False, + ebd_dim=self.api_ebd_dim + ) + self._run_and_assert_diversity(op) + + def test_hf_based_diversity_logic(self): + logger.info(f"Running diversity test with HF model: {self.hf_model_path}") + op = GroupDiversityFilter( + api_or_hf_model=self.hf_model_path, + is_hf_model=True, + ) + self._run_and_assert_diversity(op) + + def _run_and_assert_diversity(self, op: GroupDiversityFilter): + dataset = self.dataset.add_column(name=Fields.stats, column=[{}] * len(self.dataset)) + dataset = dataset.map(op.compute_stats_batched, + with_rank=True, + batched=True, + batch_size=len(self.dataset)) + + stats_list = dataset.to_list() + for sample in stats_list: + logger.info(f"Text: '{sample['text']}', " + f"Score: {sample[Fields.stats].get(StatsKeys.text_ebd_diversity_score, 'N/A')}") + + scores = [d[Fields.stats][StatsKeys.text_ebd_diversity_score] for d in stats_list] + + outlier_score = scores[-1] + other_scores = scores[:-1] + + self.assertTrue(all(outlier_score > score for score in other_scores), + "The outlier sample did not receive the highest diversity score.") + + logger.info("Test passed: The outlier sample correctly received the highest diversity score.") + + +if __name__ == '__main__': + unittest.main()