
Commit 228c5ab

support depth_estimation & doc_qa (#1800)

1 parent e3dc096 · commit 228c5ab
18 files changed: +495 −745 lines

mindnlp/core/ops/reduction.py

Lines changed: 1 addition & 2 deletions

@@ -38,8 +38,7 @@ def all(input, dim=None, keepdim=False, *, dtype=None):
 def any(input, dim=None, keepdim=False):
     if use_pyboost():
         return mindspore.mint.any(input, dim, keepdim)
-    any_ = _get_cache_prim(ops.ReduceAny)(keepdim)
-    return any_(input, dim)
+    return ops.any(input, dim)

 # max
 def max(input, dim=None, keepdim=False):
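
The rewrite drops the hand-built `ReduceAny` primitive in favor of MindSpore's functional API, which handles the `dim=None` (reduce everything) case itself. A minimal sketch of the behavior (assuming a MindSpore 2.x install; mindnlp's `use_pyboost` switch is not shown):

```python
from mindspore import Tensor, ops

x = Tensor([[True, False], [False, False]])

# Functional reduction now used by mindnlp's `any` wrapper.
print(ops.any(x, 1))  # per-row reduction: [ True False]
print(ops.any(x))     # reduce over all axes: True
```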

mindnlp/transformers/models/auto/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@
     MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
     MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
     MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
     MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
@@ -147,6 +148,7 @@
     'MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING',
     'MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING',
     'MODEL_FOR_VISION_2_SEQ_MAPPING',
+    'MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING',
     'MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING',
     'MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING',
     'MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING',

mindnlp/transformers/models/auto/configuration_auto.py

Lines changed: 1 addition & 0 deletions

@@ -154,6 +154,7 @@
     ("musicgen_melody", "MusicgenMelodyConfig"),
     ("mt5", "MT5Config"),
     ("mvp", "MvpConfig"),
+    ("nougat", "VisionEncoderDecoderConfig"),
     ("nystromformer", "NystromformerConfig"),
     ("olmo", "OlmoConfig"),
     ("oneformer", "OneFormerConfig"),

mindnlp/transformers/models/auto/feature_extraction_auto.py

Lines changed: 3 additions & 0 deletions

@@ -374,6 +374,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):

         config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
         feature_extractor_class = config_dict.get("feature_extractor_type", None)
+
         feature_extractor_auto_map = None
         if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
             feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
@@ -392,6 +393,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):

         if feature_extractor_class is not None:
             return feature_extractor_class.from_dict(config_dict, **kwargs)
+
+        print(feature_extractor_class)
         # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
         if type(config) in FEATURE_EXTRACTOR_MAPPING:
             feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
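
For orientation, `from_pretrained` resolves the extractor class in three steps: an explicit `feature_extractor_type` in the checkpoint's config dict, an `auto_map` entry pointing at custom code, and finally `FEATURE_EXTRACTOR_MAPPING` keyed by the config class. A hedged sketch of the common first path (the checkpoint name is an assumption; requires network access):

```python
from mindnlp.transformers import AutoFeatureExtractor

# Resolved via "feature_extractor_type" in the checkpoint's
# preprocessor_config.json, the first lookup described above.
extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
print(type(extractor).__name__)  # expected: Wav2Vec2FeatureExtractor
```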

mindnlp/transformers/models/auto/modeling_auto.py

Lines changed: 29 additions & 0 deletions

@@ -731,6 +731,32 @@
     ]
 )

+MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
+    [
+        ("blip", "BlipForConditionalGeneration"),
+        ("blip-2", "Blip2ForConditionalGeneration"),
+        ("chameleon", "ChameleonForConditionalGeneration"),
+        ("fuyu", "FuyuForCausalLM"),
+        ("git", "GitForCausalLM"),
+        ("idefics", "IdeficsForVisionText2Text"),
+        ("idefics2", "Idefics2ForConditionalGeneration"),
+        ("idefics3", "Idefics3ForConditionalGeneration"),
+        ("instructblip", "InstructBlipForConditionalGeneration"),
+        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
+        ("llava", "LlavaForConditionalGeneration"),
+        ("llava_next", "LlavaNextForConditionalGeneration"),
+        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+        ("mllama", "MllamaForConditionalGeneration"),
+        ("paligemma", "PaliGemmaForConditionalGeneration"),
+        ("pix2struct", "Pix2StructForConditionalGeneration"),
+        ("pixtral", "LlavaForConditionalGeneration"),
+        ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
+        ("udop", "UdopForConditionalGeneration"),
+        ("vipllava", "VipLlavaForConditionalGeneration"),
+        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
+    ]
+)
+
 MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Masked LM mapping
@@ -1397,6 +1423,9 @@
     CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
 )
 MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+)
 MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
 )
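
`_LazyAutoMapping` keeps these tables cheap: it stores config-name/class-name strings and imports a model class only when its key is first looked up. A self-contained sketch of that idea (not mindnlp's actual implementation, which also maps config classes via `CONFIG_MAPPING_NAMES`):

```python
import importlib
from collections import OrderedDict

# Model-type -> class-name strings, as in the mapping added above.
_NAMES = OrderedDict([("llava", "LlavaForConditionalGeneration")])

class LazyMapping:
    """Import a model class only on first lookup, caching the result."""

    def __init__(self, names, package="mindnlp.transformers"):
        self._names = names
        self._package = package
        self._cache = {}

    def __getitem__(self, model_type):
        if model_type not in self._cache:
            module = importlib.import_module(self._package)
            self._cache[model_type] = getattr(module, self._names[model_type])
        return self._cache[model_type]

mapping = LazyMapping(_NAMES)
# mapping["llava"] imports LlavaForConditionalGeneration on first access.
```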

mindnlp/transformers/models/dpt/image_processing_dpt.py

Lines changed: 39 additions & 1 deletion

@@ -43,7 +43,7 @@

 if is_mindspore_available():
     import mindspore
-    from mindspore.ops import interpolate
+    from mindnlp.core.nn.functional import interpolate

 if is_vision_available():
     from PIL import Image
@@ -484,6 +484,44 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[List[T

         return semantic_segmentation

+    def post_process_depth_estimation(
+        self,
+        outputs: "DepthEstimatorOutput",
+        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
+    ) -> List[Dict[str, TensorType]]:
+        """
+        Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
+        Only supports PyTorch.
+
+        Args:
+            outputs ([`DepthEstimatorOutput`]):
+                Raw outputs of the model.
+            target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                (height, width) of each image in the batch. If left to None, predictions will not be resized.
+
+        Returns:
+            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
+            predictions.
+        """
+        predicted_depth = outputs.predicted_depth
+
+        if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
+            raise ValueError(
+                "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
+            )
+
+        results = []
+        target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
+        for depth, target_size in zip(predicted_depth, target_sizes):
+            if target_size is not None:
+                depth = interpolate(
+                    depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
+                ).squeeze()
+
+            results.append({"predicted_depth": depth})
+
+        return results
+
 __all__ = [
     'DPTImageProcessor',
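
A hedged end-to-end sketch of the new hook (the checkpoint name, `DPTForDepthEstimation`, and the `return_tensors="ms"` convention are assumptions based on mindnlp's DPT port):

```python
from PIL import Image
from mindnlp.transformers import DPTForDepthEstimation, DPTImageProcessor

processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

image = Image.open("room.jpg")  # any RGB image
inputs = processor(images=image, return_tensors="ms")
outputs = model(**inputs)

# Bicubically resize each depth map back to its source (height, width).
results = processor.post_process_depth_estimation(
    outputs, target_sizes=[(image.height, image.width)]
)
depth = results[0]["predicted_depth"]  # one 2-D tensor per input image
```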
