Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get("image_key", "images")
self.audio_key = kwargs.get("audio_key", "audios")
self.video_key = kwargs.get("video_key", "videos")
self.lidar_key = kwargs.get("lidar_key", "lidar")

self.query_key = kwargs.get("query_key", "query")
self.response_key = kwargs.get("response_key", "response")
Expand Down
49 changes: 49 additions & 0 deletions data_juicer/ops/mapper/lidar_segmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

mmdet3d = LazyLoader("mmdet3d")

OP_NAME = "lidar_segmentation_mapper"


@OPERATORS.register_module(OP_NAME)
class LiDARSegmentationMapper(Mapper):
"""Mapper to do segmentation from LiDAR data."""

_batched_op = True
_accelerator = "cuda"

def __init__(self, model_name="cylinder3d", model_cfg_path="", model_path="", *args, **kwargs):
"""
Initialization method.
:param mode:
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)

self.model_name = model_name

if self.model_name == "cylinder3d":
self.model_cfg_path = model_cfg_path
self.model_path = model_path
else:
raise NotImplementedError(f'Only support "cylinder3d" for now, but got {self.model_name}')

self.model_key = prepare_model(
"mmlab", model_cfg=self.model_cfg_path, model_path=self.model_path, task="LiDARSegmentation"
)

def process_batched(self, samples, rank=None):
model = get_model(self.model_key, rank, self.use_cuda())

# lidars = []
# for temp_sample in samples[self.lidar_key]:
# lidars.append(dict(points=temp_sample))

results = [model(dict(points=lidar)) for lidar in samples[self.lidar_key]]
samples["lidar_segmentations"] = results

return samples
96 changes: 95 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import redirect_stderr
from functools import partial
from pickle import UnpicklingError
from typing import Optional, Union
from typing import List, Optional, Union

import httpx
import multiprocess as mp
Expand Down Expand Up @@ -38,6 +38,8 @@
ultralytics = LazyLoader("ultralytics")
tiktoken = LazyLoader("tiktoken")
dashscope = LazyLoader("dashscope")
mmdeploy = LazyLoader("mmdeploy")
mmdet3d = LazyLoader("mmdet3d")

MODEL_ZOO = {}

Expand Down Expand Up @@ -942,6 +944,97 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
return sampling_params


class MMLabModel(object):
"""
A wrapper for mmdeploy model.
It is used to load a mmdeploy model and run inference on given images.
"""

def __init__(self, model_cfg_path, deploy_cfg_path, backend_files, device):
self.model_cfg_path = model_cfg_path
self.deploy_cfg_path = deploy_cfg_path
self.backend_files = backend_files
self.device = device

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_input_shape, load_config

deploy_cfg, model_cfg = load_config(self.deploy_cfg_path, self.model_cfg_path)
self.task_processor = build_task_processor(model_cfg, deploy_cfg, self.device)

self.model = self.task_processor.build_backend_model(
self.backend_files, data_preprocessor_updater=self.task_processor.update_data_preprocessor
)

self.input_shape = get_input_shape(deploy_cfg)

def __call__(self, image):
model_inputs, _ = self.task_processor.create_input(image, self.input_shape)

with torch.no_grad():
result = self.model.test_step(model_inputs)

return result


class MMLabInferencer(object):
"""
A wrapper for mmdet3d Inferencer.
It is used to load a mmdet3d Inferencer and run inference on given images.
"""

def __init__(self, model_cfg_path, model_path, device):
self.model_cfg_path = model_cfg_path
self.model_path = model_path
self.device = device

from mmdet3d.apis import LidarSeg3DInferencer

self.model = LidarSeg3DInferencer(model=self.model_cfg_path, weights=self.model_path, device=self.device)

def __call__(self, lidar_bin_files):
result = self.model(lidar_bin_files, show=False)["predictions"]

return result


def prepare_mmlab_model(
model_cfg: str,
deploy_cfg: str = "",
backend_files: List[str] = [],
model_path: str = "",
device: str = "cpu",
task: str = "LiDARDetection",
):
"""Prepare and load a model using mmdeploy.

:param model_cfg: Path to the model config.
:param deploy_cfg: Path to the deployment config.
:param backend_files: Path to the backend model files.
:param device: Device to use.
:param task: Current task. Only support ["LiDARDetection", "LiDARSegmentation"] for now.
"""

if task == "LiDARDetection":
model = MMLabModel(
check_model(model_cfg),
check_model(deploy_cfg),
[check_model(backend_file) for backend_file in backend_files],
device,
)
elif task == "LiDARSegmentation":
model = MMLabInferencer(
model_cfg,
model_path,
device,
)

else:
NotImplementedError(f'Only support task name ["LiDARDetection", "LiDARSegmentation"] for now, but got {task}')

return model


MODEL_FUNCTION_MAPPING = {
"api": prepare_api_model,
"diffusion": prepare_diffusion_model,
Expand All @@ -960,6 +1053,7 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
"video_blip": prepare_video_blip_model,
"vllm": prepare_vllm_model,
"embedding": prepare_embedding_model,
"mmlab": prepare_mmlab_model,
}

_MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"}
Expand Down
7 changes: 4 additions & 3 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Data-Juicer 中的算子分为以下 7 种类型。
| [filter](#filter) | 49 | Filters out low-quality samples. 过滤低质量样本。 |
| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
| [mapper](#mapper) | 81 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [mapper](#mapper) | 82 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 |

All the specific operators are listed below, each featured with several capability tags.
Expand Down Expand Up @@ -97,7 +97,7 @@ All the specific operators are listed below, each featured with several capabili
| flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio less than a specific max value. 过滤以保持标记词比率小于特定最大值的样本。 | [code](../data_juicer/ops/filter/flagged_words_filter.py) | [tests](../tests/ops/filter/test_flagged_words_filter.py) |
| general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [code](../data_juicer/ops/filter/general_field_filter.py) | [tests](../tests/ops/filter/test_general_field_filter.py) |
| image_aesthetics_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [code](../data_juicer/ops/filter/image_aesthetics_filter.py) | [tests](../tests/ops/filter/test_image_aesthetics_filter.py) |
| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持特定范围内的图像长宽比的样本。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) |
| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [code](../data_juicer/ops/filter/image_aspect_ratio_filter.py) | [tests](../tests/ops/filter/test_image_aspect_ratio_filter.py) |
| image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [code](../data_juicer/ops/filter/image_face_count_filter.py) | [tests](../tests/ops/filter/test_image_face_count_filter.py) |
| image_face_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with face area ratios within a specific range. 过滤以保持面面积比在特定范围内的样本。 | [code](../data_juicer/ops/filter/image_face_ratio_filter.py) | [tests](../tests/ops/filter/test_image_face_ratio_filter.py) |
| image_nsfw_filter | 🏞Image 💻CPU 🧩HF 🟢Stable | Filter to keep samples whose images have low nsfw scores. 过滤器保留图像具有低nsfw分数的样本。 | [code](../data_juicer/ops/filter/image_nsfw_filter.py) | [tests](../tests/ops/filter/test_image_nsfw_filter.py) |
Expand Down Expand Up @@ -196,10 +196,11 @@ All the specific operators are listed below, each featured with several capabili
| image_diffusion_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Generate image by diffusion model. 通过扩散模型生成图像。 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) |
| image_face_blur_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to blur faces detected in images. 映射器模糊图像中检测到的人脸。 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) |
| image_remove_background_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to remove background of images. 映射器删除图像的背景。 | [code](../data_juicer/ops/mapper/image_remove_background_mapper.py) | [tests](../tests/ops/mapper/test_image_remove_background_mapper.py) |
| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 对图像执行segment-任何操作并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
| image_segment_mapper | 🏞Image 💻CPU 🟢Stable | Perform segment-anything on images and return the bounding boxes. 在图像上执行segment-anything并返回边界框。 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) |
| image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) |
| imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) |
| imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) |
| lidar_segmentation_mapper | 💻CPU 🟡Beta | Mapper to do segmentation from LiDAR data. 映射器从激光雷达数据中进行分割。 | [code](../data_juicer/ops/mapper/lidar_segmentation_mapper.py) | [tests](../tests/ops/mapper/test_lidar_segmentation_mapper.py) |
| mllm_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [code](../data_juicer/ops/mapper/mllm_mapper.py) | [tests](../tests/ops/mapper/test_mllm_mapper.py) |
| nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in English based on nlpaug library. 映射器基于nlpaug库简单地增加英语样本。 | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) |
| nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in Chinese based on nlpcda library. 基于nlpcda库的映射器可以简单地增加中文样本。 | [code](../data_juicer/ops/mapper/nlpcda_zh_mapper.py) | [tests](../tests/ops/mapper/test_nlpcda_zh_mapper.py) |
Expand Down
Binary file added tests/ops/data/lidar_test1.bin
Binary file not shown.
Binary file added tests/ops/data/lidar_test2.bin
Binary file not shown.
Binary file added tests/ops/data/lidar_test3.bin
Binary file not shown.
71 changes: 71 additions & 0 deletions tests/ops/mapper/test_lidar_segmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import unittest
import os

from data_juicer.core import NestedDataset as Dataset
from data_juicer.ops.mapper.lidar_segmentation_mapper import LiDARSegmentationMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class LiDARSegmentationMapperTest(DataJuicerTestCaseBase):
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
'data')
lidar_test1 = os.path.join(data_path, 'lidar_test1.bin')
lidar_test2 = os.path.join(data_path, 'lidar_test2.bin')
lidar_test3 = os.path.join(data_path, 'lidar_test3.bin')

model_cfg_path = "cylinder3d_8xb2-laser-polar-mix-3x_semantickitti.py"
model_path = "cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth"


def test_cpu(self):
source = [
{
'lidar': self.lidar_test1
},
{
'lidar': self.lidar_test2
},
{
'lidar': self.lidar_test3
}
]

op = LiDARSegmentationMapper(
model_name="cylinder3d",
model_cfg_path=self.model_cfg_path,
model_path=self.model_path,
)

dataset = Dataset.from_list(source)
dataset = dataset.map(op.process, batch_size=2, with_rank=False)

print(dataset)


def test_cuda(self):
source = [
{
'lidar': self.lidar_test1
},
{
'lidar': self.lidar_test2
},
{
'lidar': self.lidar_test3
}
]

op = LiDARSegmentationMapper(
model_name="cylinder3d",
model_cfg_path=self.model_cfg_path,
model_path=self.model_path,
)

dataset = Dataset.from_list(source)
dataset = dataset.map(op.process, batch_size=2, with_rank=True)

print(dataset)


if __name__ == "__main__":
unittest.main()
Loading