diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 4b3a17d4b52d..a09d8667bbf0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1031,6 +1031,8 @@ title: PaliGemma - local: model_doc/perceiver title: Perceiver + - local: model_doc/perception_lm + title: PerceptionLM - local: model_doc/phi4_multimodal title: Phi4 Multimodal - local: model_doc/pix2struct diff --git a/docs/source/en/model_doc/perception_lm.md b/docs/source/en/model_doc/perception_lm.md new file mode 100644 index 000000000000..3982d521b949 --- /dev/null +++ b/docs/source/en/model_doc/perception_lm.md @@ -0,0 +1,68 @@

# PerceptionLM

## Overview

The PerceptionLM model was proposed in [PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding/) by Jang Hyun Cho et al. It is a fully open and reproducible model for transparent research in image and video understanding. PLM combines a vision encoder with a small-scale (<8B parameters) LLM decoder.

The abstract from the paper is the following:

*Vision-language models are integral to computer vision research, yet many high-performing models remain closed-source, obscuring their data, design and training recipe. The research community has responded by using distillation from black-box models to label training data, achieving strong benchmark results, at the cost of measurable scientific progress. However, without knowing the details of the teacher model and its data sources, scientific progress remains difficult to measure. In this paper, we study building a Perception Language Model (PLM) in a fully open and reproducible framework for transparent research in image and video understanding. We analyze standard training pipelines without distillation from proprietary models and explore large-scale synthetic data to identify critical data gaps, particularly in detailed video understanding. To bridge these gaps, we release 2.8M human-labeled instances of fine-grained video question-answer pairs and spatio-temporally grounded video captions. Additionally, we introduce PLM–VideoBench, a suite for evaluating challenging video understanding tasks focusing on the ability to reason about “what”, “where”, “when”, and “how” of a video. We make our work fully reproducible by providing data, training recipes, code & models.*

This model was contributed by [shumingh](https://huggingface.co/shumingh).
The original code can be found [here](https://github.com/facebookresearch/perception_models).
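## Usage example

The snippet below is a minimal sketch of single-image chat inference. The checkpoint id mirrors the sizes listed in the configuration docstring (1B/3B/8B) and the message format follows the chat template written by the conversion script; both are assumptions here, so adapt them to the checkpoint you actually use.

```python
import torch
from transformers import AutoProcessor, PerceptionLMForConditionalGeneration

model_id = "facebook/Perception-LM-1B"  # assumed checkpoint id, see the configuration docstring
processor = AutoProcessor.from_pretrained(model_id)
model = PerceptionLMForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# One user turn containing an image followed by a text instruction.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "Describe this image in one sentence."},
        ],
    }
]

# Tokenize the chat and preprocess the image (thumbnail + tiling is handled by the image processor).
inputs = processor.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```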
+ + +## PerceptionLMConfig + +[[autodoc]] PerceptionLMConfig + +## PerceptionLMProcessor + +[[autodoc]] PerceptionLMProcessor + +## PerceptionLMImageProcessorFast + +[[autodoc]] PerceptionLMImageProcessorFast + +## PerceptionLMVideoProcessor + +[[autodoc]] PerceptionLMVideoProcessor + +## PerceptionLMModel + +[[autodoc]] PerceptionLMModel + +## PerceptionLMForConditionalGeneration + +[[autodoc]] PerceptionLMForConditionalGeneration + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7b2332d89f40..b786170ef205 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -234,6 +234,7 @@ from .pegasus import * from .pegasus_x import * from .perceiver import * + from .perception_lm import * from .persimmon import * from .phi import * from .phi3 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7bf78fefe87..7c6c7a8ec779 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -267,6 +267,8 @@ ("pegasus", "PegasusConfig"), ("pegasus_x", "PegasusXConfig"), ("perceiver", "PerceiverConfig"), + ("perception_encoder", "TimmWrapperConfig"), + ("perception_lm", "PerceptionLMConfig"), ("persimmon", "PersimmonConfig"), ("phi", "PhiConfig"), ("phi3", "Phi3Config"), @@ -663,6 +665,8 @@ ("pegasus", "Pegasus"), ("pegasus_x", "PEGASUS-X"), ("perceiver", "Perceiver"), + ("perception_encoder", "PerceptionEncoder"), + ("perception_lm", "PerceptionLM"), ("persimmon", "Persimmon"), ("phi", "Phi"), ("phi3", "Phi3"), @@ -869,6 +873,7 @@ ("llama4_text", "llama4"), ("blip_2_qformer", "blip_2"), ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"), + ("perception_encoder", "perception_lm"), ] ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 64666456075e..4765da007888 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -132,6 +132,7 @@ ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), + ("perception_lm", ("PerceptionLMImageProcessorFast",)), ("phi4_multimodal", ("Phi4MultimodalImageProcessorFast",)), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), @@ -597,7 +598,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): raise ValueError( "This image processor cannot be instantiated. Please make sure you have `Pillow` installed." ) - raise ValueError( f"Unrecognized image processor in {pretrained_model_name_or_path}. 
Should have a " f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following " diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 075e8e31f15b..7c2a7d130997 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -255,6 +255,8 @@ ("pegasus", "PegasusModel"), ("pegasus_x", "PegasusXModel"), ("perceiver", "PerceiverModel"), + ("perception_encoder", "PerceptionEncoder"), + ("perception_lm", "PerceptionLMModel"), ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), @@ -933,6 +935,7 @@ ("mistral3", "Mistral3ForConditionalGeneration"), ("mllama", "MllamaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), + ("perception_lm", "PerceptionLMForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("pixtral", "LlavaForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index e5bd673f6390..3d8c54e0d6f8 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -100,6 +100,7 @@ ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), ("paligemma", "PaliGemmaProcessor"), + ("perception_lm", "PerceptionLMProcessor"), ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", "Pix2StructProcessor"), ("pixtral", "PixtralProcessor"), diff --git a/src/transformers/models/perception_lm/__init__.py b/src/transformers/models/perception_lm/__init__.py new file mode 100644 index 000000000000..81c3ba93bcf4 --- /dev/null +++ b/src/transformers/models/perception_lm/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_perception_lm import * + from .image_processing_perception_lm_fast import * + from .modeling_perception_lm import * + from .processing_perception_lm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/perception_lm/configuration_perception_lm.py b/src/transformers/models/perception_lm/configuration_perception_lm.py new file mode 100644 index 000000000000..12352967d7c7 --- /dev/null +++ b/src/transformers/models/perception_lm/configuration_perception_lm.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PerceptionLM model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig +from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig + + +logger = logging.get_logger(__name__) + + +class PerceptionLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PerceptionLMForConditionalGeneration`]. It is used to instantiate an + PerceptionLM model according to the specified arguments, defining the model architecture. + + Example models: + - [facebook/Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B). + - [facebook/Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B). + - [facebook/Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[TimmWrapperConfig, dict]`, *optional*, defaults to `TimmWrapperConfig()`): + The config object or dictionary of the vision backbone. + text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `LlamaConfig()`): + The config object or dictionary of the text backbone. + vision_use_cls_token (`bool`, *optional*, defaults to `True`): + Whether CLS token is used in the vision backbone. If used, we remove CLS token embedding from vision output. + projector_pooling_ratio (`int`, *optional*, defaults to 1): + The pooling ratio used in the multimodal projector. + image_token_id (`int`, *optional*, defaults to 128002): + The image token index to encode the image prompt. + video_token_id (`int`, *optional*, defaults to 128003): + The video token index to encode the video prompt. 
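    Example (a minimal sketch; the sub-config values below are illustrative and do not correspond to a released checkpoint):

    ```python
    >>> from transformers import PerceptionLMConfig

    >>> # The vision backbone is described by a TimmWrapperConfig dict; `embed_dim` inside
    >>> # `model_args` is what the multimodal projector reads (values here are assumptions).
    >>> vision_config = {"architecture": "vit_pe_core_large_patch14_336", "model_args": {"embed_dim": 1024}}
    >>> text_config = {"model_type": "llama", "hidden_size": 2048, "num_hidden_layers": 16}

    >>> # Initializing a PerceptionLM configuration
    >>> configuration = PerceptionLMConfig(vision_config=vision_config, text_config=text_config)

    >>> # The sub-configs are exposed as attributes
    >>> configuration.text_config.hidden_size
    2048
    ```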
+ """ + + model_type = "perception_lm" + sub_configs = {"text_config": AutoConfig, "vision_config": TimmWrapperConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + vision_use_cls_token=True, + projector_pooling_ratio=1, + image_token_id=128002, + video_token_id=128003, + **kwargs, + ): + self.image_token_id = image_token_id + self.video_token_id = video_token_id + if isinstance(vision_config, dict): + vision_config = TimmWrapperConfig(**vision_config) + elif isinstance(vision_config, TimmWrapperConfig): + vision_config = vision_config + elif vision_config is None: + vision_config = TimmWrapperConfig() + self.vision_config = vision_config + self.vision_use_cls_token = vision_use_cls_token + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.projector_pooling_ratio = projector_pooling_ratio + super().__init__(**kwargs) + + +__all__ = ["PerceptionLMConfig"] diff --git a/src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py b/src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py new file mode 100644 index 000000000000..ee96c86876dd --- /dev/null +++ b/src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py @@ -0,0 +1,615 @@ +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import tempfile +import warnings + +import torch +from timm.models.eva import checkpoint_filter_fn +from tokenizers import AddedToken, processors + +from transformers import ( + GenerationConfig, + LlamaConfig, + LlamaTokenizer, + PreTrainedTokenizerFast, +) +from transformers.convert_slow_tokenizer import TikTokenConverter +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.perception_lm.configuration_perception_lm import ( + PerceptionLMConfig, +) +from transformers.models.perception_lm.image_processing_perception_lm_fast import ( + PerceptionLMImageProcessorFast, +) +from transformers.models.perception_lm.modeling_perception_lm import ( + PerceptionLMForConditionalGeneration, +) +from transformers.models.perception_lm.processing_perception_lm import ( + PerceptionLMProcessor, +) +from transformers.models.perception_lm.video_processing_perception_lm import ( + PerceptionLMVideoProcessor, +) +from transformers.models.timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig + + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. 
To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py \ + --input_dir /path/to/downloaded/perception_lm/model_path --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). + +If you want your tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor: + +```py +from tokenizers import processors +bos = "<|begin_of_text|>" +tokenizer._tokenizers.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single=f"{bos}:0 $A:0", + pair=f"{bos}:0 $A:0 {bos}:1 $B:1", + special_tokens=[ + (bos, tokenizer.encode(bos)), + ], + ), + ] +) +``` + +""" + +BOS_ADDED_TOKEN = AddedToken( + "<|begin_of_text|>", + single_word=False, + lstrip=False, + rstrip=False, + normalized=False, + special=True, +) +EOS_ADDED_TOKEN = AddedToken( + "<|end_of_text|>", + single_word=False, + lstrip=False, + rstrip=False, + normalized=False, + special=True, +) +EOT_ADDED_TOKEN = AddedToken( + "<|eot_id|>", + single_word=False, + lstrip=False, + rstrip=False, + normalized=False, + special=True, +) + +DEFAULT_SPECIAL_TOKENS = { + "perception_lm": [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|image|>", + "<|video|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # End of turn + ] + + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)] +} + +CHAT_TEMPLATE = ( + "{{- bos_token }}" + "{%- if messages[0]['role'] == 'system' -%}" + " {%- set system_message = messages[0]['content']|trim %}\n" + " {%- set messages = messages[1:] %}\n" + "{%- else %}" + " {%- set system_message = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' 
%}" + "{%- endif %}" + "{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}" + "{{- system_message }}" + "{{- '<|eot_id|>' }}" + "{%- for message in messages %}" + "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}" + "{{ '<|image|>' }}" + "{%- endfor %}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}" + "{{ '<|video|>' }}" + "{%- endfor %}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}" + "{{- content['text'] | trim }}" + "{%- endfor %}" + "{{'<|eot_id|>' }}" + "{%- endfor %}" + "{%- if add_generation_prompt %}" + "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}" + "{%- endif %}" +) + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_weights(state_dict, index_dict, param_count, filename): + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, filename) + print(f"Saved {filename}") + return param_count + + +def write_model( + model_path, + input_base_path, + params, + image_token_id, + safe_serialization=True, + tokenizer=None, + num_shards=None, + push_to_hub=False, +): + print("Converting the model.") + num_shards = 1 + model_params = params.get("model", params) + n_layers = model_params["n_layers"] + n_heads = model_params["n_heads"] + dim = model_params["dim"] + dims_per_head = dim // n_heads + base = model_params.get("rope_theta", 10000.0) + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + context_length = model_params["max_seqlen"] + max_position_embeddings = context_length + tie_word_embeddings = model_params.get("weight_tying", False) + projector_pooling_ratio = model_params.get("pooling_ratio", 1) + + if model_params.get("n_kv_heads", None) is not None: + num_key_value_heads = model_params["n_kv_heads"] # for GQA / MQA + key_value_dim = dims_per_head * num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + with tempfile.TemporaryDirectory() as tmp_model_path: + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if num_shards == 1: + # Not sharded + # (The sharded implementation would also work, but this is simpler.) 
+ loaded = torch.load( + os.path.join(input_base_path, "consolidated.pth"), + map_location="cpu", + weights_only=True, + ) + else: + # Sharded + checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")]) + print("Loading in order:", checkpoint_list) + loaded = [ + torch.load( + os.path.join(input_base_path, file), + map_location="cpu", + weights_only=True, + ) + for file in checkpoint_list + ] + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 2}.bin" + assert num_shards == 1, "PerceptionLM does not support sharded weights" + state_dict = { + f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads + ), + f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wk.weight"], + n_heads=num_key_value_heads, + dim1=key_value_dim, + ), + f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight": loaded[ + f"layers.{layer_i}.attention.wv.weight" + ], + f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight": loaded[ + f"layers.{layer_i}.attention.wo.weight" + ], + f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight": loaded[ + f"layers.{layer_i}.feed_forward.w1.weight" + ], + f"model.language_model.layers.{layer_i}.mlp.down_proj.weight": loaded[ + f"layers.{layer_i}.feed_forward.w2.weight" + ], + f"model.language_model.layers.{layer_i}.mlp.up_proj.weight": loaded[ + f"layers.{layer_i}.feed_forward.w3.weight" + ], + f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[ + f"layers.{layer_i}.attention_norm.weight" + ], + f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"layers.{layer_i}.ffn_norm.weight" + ], + } + state_dict[f"model.language_model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f"Saved {filename}") + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 2}.bin" + + state_dict = { + "model.language_model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.language_model.norm.weight": loaded["norm.weight"], + "model.multi_modal_projector.linear_1.weight": loaded["vision_projector.projector.0.weight"], + "model.multi_modal_projector.linear_2.weight": loaded["vision_projector.projector.2.weight"], + "model.multi_modal_projector.linear_1.bias": loaded["vision_projector.projector.0.bias"], + "model.multi_modal_projector.linear_2.bias": loaded["vision_projector.projector.2.bias"], + } + if not tie_word_embeddings: + state_dict["lm_head.weight"] = loaded["output.weight"] + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f"Saved {filename}") + + filename = f"pytorch_model-{n_layers + 2}-of-{n_layers + 2}.bin" + state_dict = {k.replace("vision_model.", ""): v for k, v in loaded.items() if "vision_model" in k} + vision_params = model_params["vision_model"] + if vision_params["layers"] == 23 and vision_params["width"] == 1024: + architecture = "vit_pe_core_large_patch14_336" + elif vision_params["layers"] == 47 and vision_params["width"] == 1536: + architecture = "vit_pe_core_gigantic_patch14_448" + else: + raise ValueError( 
+ f"Unsupported PE config: {vision_params['layers']} layers and {vision_params['width']} width" + ) + + vision_config = TimmWrapperConfig.from_pretrained( + f"timm/{architecture}.fb", + model_args={ + "embed_dim": vision_params["width"], + "depth": vision_params["layers"], + "img_size": (vision_params["image_size"], vision_params["image_size"]), + "global_pool": "", + "use_post_transformer_norm": vision_params["use_ln_post"], + "init_values": vision_params["ls_init_value"], + "ref_feat_shape": ( + vision_params["image_size"] // vision_params["patch_size"], + vision_params["image_size"] // vision_params["patch_size"], + ), + }, + ) + + perception_encoder = AutoModel.from_config(vision_config) + state_dict = checkpoint_filter_fn(state_dict, perception_encoder) + state_dict = {"model.vision_tower.timm_model." + k: v for k, v in state_dict.items()} + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f"Saved {filename}") + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + ffn_dim_multiplier = model_params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in model_params else 1 + multiple_of = model_params["multiple_of"] if "multiple_of" in model_params else 256 + + bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>") + eos_token_id = [tokenizer.convert_tokens_to_ids(t) for t in ["<|end_of_text|>", "<|eot_id|>"]] + + use_scaled_rope = model_params["use_scaled_rope"] + if use_scaled_rope: + rope_scaling = { + "factor": model_params["rope_scale_factor"] * 1.0, + "low_freq_factor": model_params.get("low_freq_factor", 1.0) * 1.0, + "high_freq_factor": model_params.get("high_freq_factor", 4.0) * 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + } + else: + rope_scaling = None + + text_config = LlamaConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=model_params["n_heads"], + num_hidden_layers=model_params["n_layers"], + rms_norm_eps=model_params["norm_eps"], + num_key_value_heads=num_key_value_heads, + vocab_size=len(tokenizer), + rope_theta=base, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + ) + + config = PerceptionLMConfig( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + projector_pooling_ratio=projector_pooling_ratio, + vision_use_cls_token=vision_params["use_cls_token"], + image_token_id=tokenizer.image_token_id, + video_token_id=tokenizer.video_token_id, + ) + + config.save_pretrained(tmp_model_path) + + generation_config = GenerationConfig( + do_sample=False, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + ) + generation_config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. 
+ del state_dict + # output_weight = loaded.get("output.weight", None) + del loaded + gc.collect() + + print("Loading the checkpoint in a PerceptionLM model.") + model = PerceptionLMForConditionalGeneration.from_pretrained( + tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + # if not tie_word_embeddings: + # if output_weight is None: + # raise ValueError("Output weight/lm_head is not found in the checkpoint.") + # model.lm_head.load_state_dict({"weight": output_weight}) + + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.bfloat16 + + print("Saving in the Transformers format.") + if push_to_hub: + print("Pushing to the hub.") + model.push_to_hub( + model_path, + safe_serialization=safe_serialization, + private=True, + use_temp_dir=True, + ) + else: + print("Saving to disk.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + + +class Llama3Converter(TikTokenConverter): + def __init__( + self, + vocab_file, + special_tokens=None, + context_length=11520, + **kwargs, + ): + super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs) + tokenizer = self.converted() + + self.converted_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + bos_token="<|begin_of_text|>", + eos_token="<|eot_id|>", + model_input_names=["input_ids", "attention_mask"], + model_max_length=context_length, + clean_up_tokenization_spaces=True, + extra_special_tokens={ + "image_token": "<|image|>", + "video_token": "<|video|>", + "pad_token": "<|end_of_text|>", + }, + ) + self.converted_tokenizer.image_token_id = self.converted_tokenizer.encode( + self.converted_tokenizer.image_token, add_special_tokens=False + )[0] + self.converted_tokenizer.video_token_id = self.converted_tokenizer.encode( + self.converted_tokenizer.video_token, add_special_tokens=False + )[0] + self.update_post_processor(self.converted_tokenizer) + # finer special_tokens_map.json + self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN + self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN + + # We can't do this while building the tokenizer because we have no easy access to the bos token id + def update_post_processor(self, tokenizer): + tokenizer._tokenizer.post_processor = processors.Sequence( + [ + processors.ByteLevel(trim_offsets=False), + processors.TemplateProcessing( + single="<|begin_of_text|> $A", + pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1", + special_tokens=[ + ( + "<|begin_of_text|>", + tokenizer.convert_tokens_to_ids("<|begin_of_text|>"), + ), + ], + ), + ] + ) + + +def write_tokenizer( + tokenizer_path, + input_tokenizer_path, + special_tokens=None, + params=None, + push_to_hub=False, +): + print("Converting the tokenizer.") + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + context_length = params["model"]["max_seqlen"] + tokenizer = Llama3Converter( + input_tokenizer_path, + special_tokens, + context_length, + ).converted_tokenizer + + tokenizer.image_token_id = tokenizer.encode(tokenizer.image_token, add_special_tokens=False)[0] + processor_config = { + "pooling_ratio": params["model"]["pooling_ratio"], + "patch_size": params["model"]["vision_model"]["patch_size"], + "processor_class": "PerceptionLMProcessor", + } + tile_size = params["model"]["vision_model"]["image_size"] + + image_preprocessor_config = { + "image_processor_type": "PerceptionLMImageProcessorFast", + "vision_input_type": params["data"]["vision_input_type"], + "tile_size": 
tile_size, + "max_num_tiles": params["data"]["max_num_tiles"], + "max_frame_tiles": 1, + "size": {"height": tile_size, "width": tile_size}, + "do_resize": True, + "do_rescale": True, + "do_normalize": True, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], + } + image_preprocessor = PerceptionLMImageProcessorFast(**image_preprocessor_config) + video_preprocessor_config = { + "video_processor_type": "PerceptionLMVideoProcessor", + "size": {"height": tile_size, "width": tile_size}, + } + video_preprocessor = PerceptionLMVideoProcessor(**video_preprocessor_config) + processor = PerceptionLMProcessor( + image_processor=image_preprocessor, + video_processor=video_preprocessor, + tokenizer=tokenizer, + chat_template=CHAT_TEMPLATE, + **processor_config, + ) + + if push_to_hub: + print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.") + processor.push_to_hub(tokenizer_path, private=True, use_temp_dir=True) + else: + print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") + processor.save_pretrained(tokenizer_path) + return tokenizer + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of Llama weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--push_to_hub", + help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.", + action="store_true", + default=False, + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + default=True, + help="Whether or not to save using `safetensors`.", + ) + parser.add_argument( + "--num_shards", + default=None, + type=int, + help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth", + ) + parser.add_argument( + "--special_tokens", + default=None, + type=list[str], + help="The list of special tokens that should be added to the model.", + ) + args = parser.parse_args() + if args.special_tokens is None: + # no special tokens by default + args.special_tokens = DEFAULT_SPECIAL_TOKENS.get("perception_lm", []) + + params = read_json(os.path.join(args.input_dir, "params.json")) + + spm_path = os.path.join(args.input_dir, "tokenizer.model") + tokenizer = write_tokenizer( + args.output_dir, + spm_path, + special_tokens=args.special_tokens, + params=params, + push_to_hub=args.push_to_hub, + ) + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + params=params, + image_token_id=tokenizer.image_token_id, + safe_serialization=args.safe_serialization, + tokenizer=tokenizer, + num_shards=args.num_shards, + push_to_hub=args.push_to_hub, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py new file mode 100644 index 000000000000..8a5fedc50b42 --- /dev/null +++ b/src/transformers/models/perception_lm/image_processing_perception_lm_fast.py @@ -0,0 +1,306 @@ +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for PerceptionLM.""" + +import math +from functools import reduce +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import ( + BatchFeature, +) +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + get_image_size, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + PILImageResampling, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + from torchvision.transforms import functional as F + + +class PerceptionLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + vision_input_type: str = "thumb+tile" + tile_size: int = 448 + max_num_tiles: int = 36 + + +@add_start_docstrings( + "Constructs a fast PerceptionLM image processor.", +) +class PerceptionLMImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = True + do_convert_rgb = True + size = {"width": 448, "height": 448} # for backward compatibility in tests + valid_kwargs = PerceptionLMFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + @staticmethod + def _factors(n: int): + """Return all factors of a number.""" + return set( + reduce( + list.__add__, + ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), + ) + ) + + def _find_supported_aspect_ratios(self): + """ + This function computes all the allowed aspect ratios for a fixed + number of input chunks. The order of returned items matters for the result of `_fit_image_to_canvas` function. + If tie exists in `_fit_image_to_canvas`, the latter in `_find_supported_aspect_ratios` wins. + + For example, with `num_tiles=5`, it will return: + { + 0.2: [(1, 5)], + 5.0: [(5, 1)], + 0.25: [(1, 4)], + 1.0: [(2, 2), (1, 1)], + 4.0: [(4, 1)], + 0.3333333333333333: [(1, 3)], + 3.0: [(3, 1)], + 0.5: [(1, 2)], + 2.0: [(2, 1)] + } + """ + asp_dict = {} + for chunk_size in range(self.max_num_tiles, 0, -1): + _factors = sorted(self._factors(chunk_size)) + _asp_ratios = [(x, chunk_size // x) for x in _factors] + for ratio in _asp_ratios: + k = ratio[0] / ratio[1] + if k not in asp_dict: + asp_dict[k] = [ratio] + else: + asp_dict[k].append(ratio) + return asp_dict + + def _get_image_height_width( + self, image_width: int, image_height: int, target_width: int, target_height: int + ) -> tuple[int, int]: + """ + Given image width, height and target width, height for the canvas, return the dimensions of how the image would be resized + with aspect ratio preservation. 
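        For example (an illustrative walk-through of the logic below), an 800x600 image targeted at a 448x448 canvas is resized to roughly 448x336, preserving its 4:3 aspect ratio.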
+ """ + scale = image_width / image_height + + if scale > 1.0: + # Width is larger than height + + # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas. + rescaling_factor = min(target_width / image_width, target_height / image_height) + + # Set new width to target width and height to the rescaled height. + new_w = rescaling_factor * image_width + new_h = math.floor(new_w / scale) + + else: + # Height is larger than width + + # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas. + rescaling_factor = min(target_width / image_width, target_height / image_height) + + # Set new height to target height and width to the rescaled width. + new_h = rescaling_factor * image_height + new_w = math.floor(new_h * scale) + + return new_w, new_h + + def _fit_image_to_canvas(self, img_width: int, img_height: int, tile_size: int): + """ + Given an image width, height and target number of chunks this function will see if the image + can be fit into any of the canvases that can be build from arranging the tiles in a grid. + If the image can be fit onto several canvases, it will return the canvas where the shorter edge + of the image will be largest. + """ + # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None. + optimal_canvas = None + optimal_image_width_height = None + + scale = img_width / img_height + + # Gather all potential supported image resolutions and iterate through them to find best match + potential_arrangements = [ + item for sublist in self._find_supported_aspect_ratios().values() for item in sublist + ] + for n_w, n_h in potential_arrangements: + # Compute the canvas size + canvas_width, canvas_height = n_w * tile_size, n_h * tile_size + + # Check if image can fit into the canvas without downsampling + if canvas_width >= img_width and canvas_height >= img_height: + # If we did not find a good canvas yet, we will use the current one + if optimal_canvas is None: + # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling + optimal_canvas = (n_w, n_h) + optimal_image_width_height = self._get_image_height_width( + image_width=img_width, + image_height=img_height, + target_width=n_w * tile_size, + target_height=n_h * tile_size, + ) + else: + # If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas. + # This means we can potentially upsample the image resolution which is beneficial to performance. + image_width_height = self._get_image_height_width( + image_width=img_width, + image_height=img_height, + target_width=n_w * tile_size, + target_height=n_h * tile_size, + ) + # Llama3V dynamic tiling. Priortize biggest canvas. + if (scale < 1.0 and (image_width_height[0] >= optimal_image_width_height[0])) or ( + scale >= 1.0 and (image_width_height[1] >= optimal_image_width_height[1]) + ): + optimal_canvas = (n_w, n_h) + optimal_image_width_height = image_width_height + return optimal_canvas + + def _find_closest_aspect_ratio(self, img_width: int, img_height: int, tile_size: int) -> tuple: + """ + Given an image width, height and target number of chunks + this function will find the closest supported aspect ratio. 
+ """ + target_aspect_ratio = img_width / img_height + asp_dict = self._find_supported_aspect_ratios() + closest_aspect_ratio = None + if target_aspect_ratio >= 1: + closest_aspect_ratio = min( + [k for k in asp_dict.keys() if k <= target_aspect_ratio], + key=lambda x: abs(x - target_aspect_ratio), + ) + tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio] + # select largest width + return max(tiles_given_aspect_ratio, key=lambda x: x[0]) + else: + closest_aspect_ratio = min( + [k for k in asp_dict.keys() if k > target_aspect_ratio], + key=lambda x: abs(1 / x - 1 / target_aspect_ratio), + ) + tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio] + # select largest height + return max(tiles_given_aspect_ratio, key=lambda x: x[1]) + + def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: + # Split image into number of required tiles (width x height) + batch_size, num_channels, height, width = image.size() + image = image.view(batch_size, num_channels, nch, height // nch, ncw, width // ncw) + # Permute dimensions to reorder the axes + image = image.permute(0, 2, 4, 1, 3, 5).contiguous() + # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) + image = image.view(batch_size, ncw * nch, num_channels, height // nch, width // ncw) + return image + + def resize( + self, + image: np.ndarray, + tile_size: int, + max_num_tiles: int, + resample: PILImageResampling = PILImageResampling.BICUBIC, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + height, width = get_image_size(image, channel_dim=input_data_format) + if max_num_tiles > 1: + aspect_ratio = self._fit_image_to_canvas(img_width=width, img_height=height, tile_size=tile_size) + if aspect_ratio is None: + # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image + aspect_ratio = self._find_closest_aspect_ratio(img_width=width, img_height=height, tile_size=tile_size) + else: + aspect_ratio = (1, 1) + new_width, new_height = aspect_ratio[0] * tile_size, aspect_ratio[1] * tile_size + image = F.resize(image, (new_height, new_width), interpolation=resample) + return image, aspect_ratio + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + do_rescale: Optional[bool], + rescale_factor: Optional[Union[int, float]], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + tile_size: int, + max_num_tiles: int, + return_tensors: Optional[Union[str, TensorType]], + disable_grouping: bool, + **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs], + ) -> BatchFeature: + # Group images by size for batched transformation + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + if self.vision_input_type == "thumb+tile": + thumbnails, _ = self.resize(stacked_images, tile_size, max_num_tiles=1) + images_for_tiling, (tiles_w, tiles_h) = self.resize( + stacked_images, tile_size, max_num_tiles=max_num_tiles + ) + image_tiles = self._split(images_for_tiling, tiles_w, tiles_h) + stacked_images = torch.cat([thumbnails.unsqueeze(1), image_tiles], dim=1) + else: # vanilla single tile for low memory devices + stacked_images, _ = self.resize(stacked_images, tile_size, max_num_tiles=1) + + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, 
grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, + do_rescale, + rescale_factor, + do_normalize, + image_mean, + image_std, + ) + processed_images_grouped[shape] = stacked_images + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["PerceptionLMImageProcessorFast"] diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py new file mode 100644 index 000000000000..b7507b1343fa --- /dev/null +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -0,0 +1,505 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/perception_lm/modular_perception_lm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_perception_lm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...generation.utils import GenerationMixin +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, can_return_tuple +from ..auto import AutoModel +from .configuration_perception_lm import PerceptionLMConfig + + +class PerceptionLMAdaptiveAvgPooling(nn.Module): + def __init__(self, pooling_ratio=2): + super().__init__() + self.pooling_ratio = pooling_ratio + + def forward(self, hidden_states): + b, num_tokens, c = hidden_states.shape + h = int(math.sqrt(num_tokens)) + if h * h != num_tokens: + raise ValueError(f"num_tokens {num_tokens} is expected to be a square number") + + shape = (h // self.pooling_ratio, h // self.pooling_ratio) + hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h) + hidden_states = F.adaptive_avg_pool2d(hidden_states, shape) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + return hidden_states + + +class PerceptionLMMultiModalProjector(nn.Module): + def __init__(self, config: PerceptionLMConfig): + super().__init__() + input_size = config.vision_config.model_args["embed_dim"] + output_size = config.text_config.hidden_size + self.linear_1 = nn.Linear( + in_features=input_size, + out_features=output_size, + bias=True, + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + in_features=output_size, + out_features=output_size, + bias=True, + ) + self.pooling = ( + PerceptionLMAdaptiveAvgPooling(config.projector_pooling_ratio) + if config.projector_pooling_ratio > 1 + else nn.Identity() + ) + + def forward(self, features): + features = features.permute(1, 0, 2) # NLD -> LND + features = self.linear_1(features) + features = self.gelu(features) + features = self.linear_2(features) + features = features.permute(1, 0, 2) # LND -> NLD + features = self.pooling(features) + return features + + +@auto_docstring +class PerceptionLMPreTrainedModel(PreTrainedModel): + config_class = PerceptionLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of PerceptionLM isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/PerceptionLM/tree/main/perception_lm should serve for that purpose + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PerceptionLM outputs, with hidden states and attentions. 
+ """ +) +class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PerceptionLM causal language model (or autoregressive) outputs. + """ +) +class PerceptionLMCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring +class PerceptionLMModel(PerceptionLMPreTrainedModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: PerceptionLMConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = PerceptionLMMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_tiles, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values.flatten(0, 1)) + image_outputs = image_outputs.last_hidden_state + if self.config.vision_use_cls_token: + image_outputs = image_outputs[:, 1:, :] + image_features = self.multi_modal_projector(image_outputs) + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[tuple, PerceptionLMModelOutputWithPast]: + """ + Forward pass of the PerceptionLM model. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Indices of input sequence tokens in the vocabulary. + pixel_values (`torch.FloatTensor`, *optional*): + Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + attention_mask (`torch.Tensor`, *optional*): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor`, *optional*): + Indices of positions of each input sequence token in the position embeddings. + past_key_values (`list[torch.FloatTensor]`, *optional*): + Precomputed key and value hidden states for fast autoregressive generation. + inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. + use_cache (`bool`, *optional*): + Whether or not to use past key values to speed up decoding. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + cache_position (`torch.LongTensor`, *optional*): + Position indices for caching. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + Number of logits to keep. + **lm_kwargs: + Additional keyword arguments for the language model. + + Returns: + [`PerceptionLMModelOutputWithPast`] or `tuple`: + Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values.to(inputs_embeds), + ) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + self.check_mask_feature_size_match(special_image_mask, image_features) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + video_features = None + if pixel_values_videos is not None: + video_features = self.get_image_features( + pixel_values=pixel_values_videos.to(inputs_embeds), + ) + special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + self.check_mask_feature_size_match(special_video_mask, video_features) + special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + video_features = video_features.to(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + return PerceptionLMModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=(video_features if pixel_values_videos is not None else None), + ) + + def check_mask_feature_size_match(self, media_mask, media_features): + media_token_count = media_mask.sum() + media_feature_size = media_features.size()[:-1].numel() + if media_token_count != media_feature_size: + raise ValueError( + f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. 
Features shape: {media_features.shape})" + ) + + +@auto_docstring +class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PerceptionLMConfig, **super_kwargs): + super().__init__(config, **super_kwargs) + self.model = PerceptionLMModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + # Make modules available throught conditional class for BC with test_sdpa_can_dispatch_composite_models + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + def set_input_embeddings(self, new_embeddings): + self.model.set_input_embeddings(new_embeddings) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_values_videos=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + return model_inputs + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: + """ + Forward pass for the PerceptionLMForConditionalGeneration model. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Indices of input sequence tokens in the vocabulary. + pixel_values (`torch.FloatTensor`, *optional*): + Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + attention_mask (`torch.Tensor`, *optional*): + Mask to avoid performing attention on padding token indices. 
+ position_ids (`torch.LongTensor`, *optional*): + Indices of positions of each input sequence token in the position embeddings. + past_key_values (`list[torch.FloatTensor]`, *optional*): + Precomputed key and value hidden states for fast autoregressive generation. + inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. + labels (`torch.LongTensor`, *optional*): + Labels for computing the language modeling loss. + use_cache (`bool`, *optional*): + Whether or not to use past key values to speed up decoding. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + cache_position (`torch.LongTensor`, *optional*): + Position indices for caching. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + Number of logits to keep. + **lm_kwargs: + Additional keyword arguments for the language model. + + Returns: + [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: + Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.text_config.vocab_size, + **lm_kwargs, + ) + + return PerceptionLMCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, + ) + + +__all__ = ["PerceptionLMForConditionalGeneration", "PerceptionLMPreTrainedModel", "PerceptionLMModel"] diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py new file mode 100644 index 000000000000..1b001fea6752 --- /dev/null +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -0,0 +1,430 @@ +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch PerceptionLM model.""" + +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...generation.utils import GenerationMixin +from ...utils import ( + auto_docstring, + can_return_tuple, + logging, +) +from ..auto import AutoModel +from ..llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaModel, + LlavaModelOutputWithPast, + LlavaPreTrainedModel, +) +from .configuration_perception_lm import PerceptionLMConfig + + +logger = logging.get_logger(__name__) + + +class PerceptionLMAdaptiveAvgPooling(nn.Module): + def __init__(self, pooling_ratio=2): + super().__init__() + self.pooling_ratio = pooling_ratio + + def forward(self, hidden_states): + b, num_tokens, c = hidden_states.shape + h = int(math.sqrt(num_tokens)) + if h * h != num_tokens: + raise ValueError(f"num_tokens {num_tokens} is expected to be a square number") + + shape = (h // self.pooling_ratio, h // self.pooling_ratio) + hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h) + hidden_states = F.adaptive_avg_pool2d(hidden_states, shape) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + return hidden_states + + +class PerceptionLMMultiModalProjector(nn.Module): + def __init__(self, config: PerceptionLMConfig): + super().__init__() + input_size = config.vision_config.model_args["embed_dim"] + output_size = config.text_config.hidden_size + self.linear_1 = nn.Linear( + in_features=input_size, + out_features=output_size, + bias=True, + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + in_features=output_size, + out_features=output_size, + bias=True, + ) + self.pooling = ( + PerceptionLMAdaptiveAvgPooling(config.projector_pooling_ratio) + if config.projector_pooling_ratio > 1 + else nn.Identity() + ) + + def forward(self, features): + features = features.permute(1, 0, 2) # NLD -> LND + features = self.linear_1(features) + features = self.gelu(features) + features = self.linear_2(features) + features = features.permute(1, 0, 2) # LND -> NLD + features = self.pooling(features) + return features + + +class PerceptionLMPreTrainedModel(LlavaPreTrainedModel): + base_model_prefix = "model" + + +class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast): + video_hidden_states: Optional[torch.FloatTensor] = None + + +class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + video_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring +class PerceptionLMModel(LlavaModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: PerceptionLMConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = PerceptionLMMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_tiles, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`). 
+ """ + image_outputs = self.vision_tower(pixel_values.flatten(0, 1)) + image_outputs = image_outputs.last_hidden_state + if self.config.vision_use_cls_token: + image_outputs = image_outputs[:, 1:, :] + image_features = self.multi_modal_projector(image_outputs) + return image_features + + def check_mask_feature_size_match(self, media_mask, media_features): + media_token_count = media_mask.sum() + media_feature_size = media_features.size()[:-1].numel() + if media_token_count != media_feature_size: + raise ValueError( + f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})" + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[tuple, PerceptionLMModelOutputWithPast]: + """ + Forward pass of the PerceptionLM model. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Indices of input sequence tokens in the vocabulary. + pixel_values (`torch.FloatTensor`, *optional*): + Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + attention_mask (`torch.Tensor`, *optional*): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor`, *optional*): + Indices of positions of each input sequence token in the position embeddings. + past_key_values (`list[torch.FloatTensor]`, *optional*): + Precomputed key and value hidden states for fast autoregressive generation. + inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. + use_cache (`bool`, *optional*): + Whether or not to use past key values to speed up decoding. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + cache_position (`torch.LongTensor`, *optional*): + Position indices for caching. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + Number of logits to keep. + **lm_kwargs: + Additional keyword arguments for the language model. + + Returns: + [`PerceptionLMModelOutputWithPast`] or `tuple`: + Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: + raise ValueError( + "You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values.to(inputs_embeds), + ) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + self.check_mask_feature_size_match(special_image_mask, image_features) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + video_features = None + if pixel_values_videos is not None: + video_features = self.get_image_features( + pixel_values=pixel_values_videos.to(inputs_embeds), + ) + special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + self.check_mask_feature_size_match(special_video_mask, video_features) + special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + video_features = video_features.to(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + return PerceptionLMModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + video_hidden_states=(video_features if pixel_values_videos is not None else None), + ) + + +@auto_docstring +class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PerceptionLMConfig, **super_kwargs): + super().__init__(config, **super_kwargs) + self.model = PerceptionLMModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + # Make modules available throught conditional class for BC with test_sdpa_can_dispatch_composite_models + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + def set_input_embeddings(self, new_embeddings): + self.model.set_input_embeddings(new_embeddings) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings): + 
self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_values_videos=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + return model_inputs + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]: + """ + Forward pass for the PerceptionLMForConditionalGeneration model. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Indices of input sequence tokens in the vocabulary. + pixel_values (`torch.FloatTensor`, *optional*): + Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`. + pixel_values_videos (`torch.FloatTensor`, *optional*): + Input video tensor of shape `(batch_size, num_frames, channels, height, width)`. + attention_mask (`torch.Tensor`, *optional*): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor`, *optional*): + Indices of positions of each input sequence token in the position embeddings. + past_key_values (`list[torch.FloatTensor]`, *optional*): + Precomputed key and value hidden states for fast autoregressive generation. + inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation. + labels (`torch.LongTensor`, *optional*): + Labels for computing the language modeling loss. + use_cache (`bool`, *optional*): + Whether or not to use past key values to speed up decoding. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + cache_position (`torch.LongTensor`, *optional*): + Position indices for caching. + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): + Number of logits to keep. 
+ **lm_kwargs: + Additional keyword arguments for the language model. + + Returns: + [`PerceptionLMCausalLMOutputWithPast`] or `tuple`: + Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.text_config.vocab_size, + **lm_kwargs, + ) + + return PerceptionLMCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + video_hidden_states=outputs.video_hidden_states, + ) + + +__all__ = [ + "PerceptionLMForConditionalGeneration", + "PerceptionLMPreTrainedModel", + "PerceptionLMModel", +] diff --git a/src/transformers/models/perception_lm/processing_perception_lm.py b/src/transformers/models/perception_lm/processing_perception_lm.py new file mode 100644 index 000000000000..7dc1dc1ea371 --- /dev/null +++ b/src/transformers/models/perception_lm/processing_perception_lm.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for PerceptionLM. +""" + +from typing import Iterable, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, get_image_size, to_numpy_array +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + } + + +class PerceptionLMProcessor(ProcessorMixin): + r""" + Constructs a PerceptionLM processor which wraps a PerceptionLM image processor, a PerceptionLM video processor, and a tokenizer into a single processor. + + [`PerceptionLMProcessor`] offers all the functionalities of [`PerceptionLMImageProcessorFast`], [`PerceptionLMVideoProcessor`], and the tokenizer (e.g. [`LlamaTokenizerFast`]). 
See the + [`~PerceptionLMProcessor.__call__`] and [`~PerceptionLMProcessor.decode`] for more information. + + Args: + video_processor ([`PerceptionLMVideoProcessor`], *optional*): + The video processor to process video inputs. + image_processor ([`PerceptionLMImageProcessorFast`], *optional*): + The image processor to process image inputs. + tokenizer ([`LlamaTokenizerFast`] or similar, *optional*): + The tokenizer to process text inputs. + patch_size (`int`, *optional*): + Patch size from the vision tower. + chat_template (`str`, *optional*): + A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. + pooling_ratio (`int`, *optional*, defaults to 2): + Pooling ratio for vision tokens. If not 1, 2D adaptive pooling is applied over projected vision tokens. + """ + + attributes = ["video_processor", "image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + video_processor=None, + image_processor=None, + tokenizer=None, + patch_size=None, + chat_template=None, + pooling_ratio=2, + **kwargs, + ): + self.patch_size = patch_size + self.pooling_ratio = pooling_ratio + self.image_token = tokenizer.image_token + self.video_token = tokenizer.video_token + self.image_token_id = tokenizer.image_token_id + self.video_token_id = tokenizer.video_token_id + super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + audio=None, + videos: VideoInput = None, + **kwargs: Unpack[PerceptionLMProcessorKwargs], + ) -> BatchFeature: + """ + Prepares a batch containing one or more sequences of text and/or images and/or videos. + + If `text` is provided, it is tokenized using the tokenizer. + If `images` is provided, they are processed using the image processor. + If `videos` is provided, they are processed using the video processor. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): + The image or batch of images to be processed. Each image can be a PIL image, NumPy array, or PyTorch tensor. + Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, *optional*): + The sequence or batch of sequences to be tokenized. Each sequence can be a string. + videos (`Any`, *optional*): + The video or batch of videos to be processed. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is provided. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is provided). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is provided. + - **pixel_values_videos** -- Video pixel values to be fed to a model. 
Returned when `videos` is provided. + """ + if text is None: + raise ValueError( + "You have to specify at least `text` input. Optionally, you can also specify `images` or `videos`." + ) + + output_kwargs = self._merge_kwargs( + PerceptionLMProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if videos is not None: + videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"]) + else: + videos_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = [] + + pixel_values = iter(image_inputs.get("pixel_values", [])) + pixel_values_videos = iter(videos_inputs.get("pixel_values_videos", [])) + for sample in text: + # Replace the media token with the expanded media token sequence + sample = self._expand_media_tokens(sample, self.tokenizer.image_token, pixel_values) + sample = self._expand_media_tokens(sample, self.tokenizer.video_token, pixel_values_videos) + prompt_strings.append(sample) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"]) + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable): + media_count = sample.count(media_token) + if media_count > 0: + media_list = [next(media_iter) for _ in range(media_count)] + sample_splits = sample.split(media_token) + media_token_list = [] + for media in media_list: + height, width = get_image_size(to_numpy_array(media)) + num_tiles = media.shape[0] + num_media_tokens = ( + (height // self.patch_size // self.pooling_ratio) + * (width // self.patch_size // self.pooling_ratio) + * num_tiles + ) + media_token_list.append(num_media_tokens) + sample = "" + for i, num_media_tokens in enumerate(media_token_list): + sample += sample_splits[i] + sample += media_token * num_media_tokens + sample += sample_splits[-1] + return sample + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PerceptionLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PerceptionLMTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["PerceptionLMProcessor"] diff --git a/src/transformers/models/perception_lm/video_processing_perception_lm.py b/src/transformers/models/perception_lm/video_processing_perception_lm.py new file mode 100644 index 000000000000..7381045c1d7c --- /dev/null +++ b/src/transformers/models/perception_lm/video_processing_perception_lm.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Video processor class for PerceptionLM.""" + +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, +) +from ...processing_utils import Unpack, VideosKwargs +from ...utils import is_vision_available +from ...utils.import_utils import requires +from ...video_processing_utils import ( + BaseVideoProcessor, +) + + +if is_vision_available(): + from ...image_utils import PILImageResampling + + +class PerceptionLMFastVideoProcessorInitKwargs(VideosKwargs): ... + + +@requires(backends=("torchvision",)) +class PerceptionLMVideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 448, "width": 448} + do_resize = True + do_center_crop = False + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = PerceptionLMFastVideoProcessorInitKwargs + model_input_names = ["pixel_values_videos"] + + def __init__(self, **kwargs: Unpack[PerceptionLMFastVideoProcessorInitKwargs]): + super().__init__(**kwargs) + + +__all__ = ["PerceptionLMVideoProcessor"] diff --git a/tests/models/perception_lm/__init__.py b/tests/models/perception_lm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/perception_lm/test_image_processing_perception_lm.py b/tests/models/perception_lm/test_image_processing_perception_lm.py new file mode 100644 index 000000000000..8d6d95e89dc0 --- /dev/null +++ b/tests/models/perception_lm/test_image_processing_perception_lm.py @@ -0,0 +1,224 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + if is_torchvision_available(): + from transformers import PerceptionLMImageProcessorFast + + +class PerceptionLMImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + tile_size=16, + do_normalize=True, + image_mean=IMAGENET_STANDARD_MEAN, + image_std=IMAGENET_STANDARD_STD, + do_convert_rgb=True, + max_num_tiles=4, + vision_input_type="thumb+tile", + resample=Image.Resampling.BICUBIC, # dummy value + size={"shortest_edge": 20}, # dummy value + ): + super().__init__() + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.tile_size = tile_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.max_num_tiles = max_num_tiles + self.vision_input_type = vision_input_type + self.resample = resample + self.size = size + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "tile_size": self.tile_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "max_num_tiles": self.max_num_tiles, + "vision_input_type": self.vision_input_type, + "resample": self.resample, + "size": self.size, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.crop_size["height"], self.crop_size["width"] + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class PerceptionLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + fast_image_processing_class = PerceptionLMImageProcessorFast if is_torchvision_available() else None + test_slow_image_processor = False + + def setUp(self): + super().setUp() + self.image_processor_tester = PerceptionLMImageProcessingTester(self) + + @property + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "tile_size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + 
self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "max_num_tiles")) + self.assertTrue(hasattr(image_processing, "vision_input_type")) + + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.tile_size, 16) + self.assertEqual(image_processor.max_num_tiles, 4) + self.assertEqual(image_processor.vision_input_type, "thumb+tile") + + image_processor = image_processing_class.from_dict( + self.image_processor_dict, tile_size=42, max_num_tiles=9 + ) + self.assertEqual(image_processor.tile_size, 42) + self.assertEqual(image_processor.max_num_tiles, 9) + self.assertEqual(image_processor.vision_input_type, "thumb+tile") + + def test_call_pil(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_numpy(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_pytorch(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 5, 3, 16, 16) + self.assertEqual(tuple(encoded_images.shape), 
expected_output_image_shape)
+
+    @unittest.skip(reason="PerceptionLMImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")
+    def test_call_numpy_4_channels(self):
+        pass
+
+    def test_nested_input(self):
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
+
+            # Test batched as a list of images
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = (7, 5, 3, 16, 16)
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched as a nested list of images, where each sublist is one batch
+            image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
+            encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
+            expected_output_image_shape = (7, 5, 3, 16, 16)
+            self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
+
+            # Image processor should return same pixel values, independently of input format
+            self.assertTrue((encoded_images_nested == encoded_images).all())
diff --git a/tests/models/perception_lm/test_modeling_perception_lm.py b/tests/models/perception_lm/test_modeling_perception_lm.py
new file mode 100644
index 000000000000..16f521d70f60
--- /dev/null
+++ b/tests/models/perception_lm/test_modeling_perception_lm.py
@@ -0,0 +1,474 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch PerceptionLM model.""" + +import unittest + +from huggingface_hub import hf_hub_download + +from transformers import ( + AutoProcessor, + PerceptionLMConfig, + PerceptionLMForConditionalGeneration, + PerceptionLMModel, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_bitsandbytes, + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + +class PerceptionLMVisionText2TextModelTester: + def __init__( + self, + parent, + image_token_id=0, + video_token_id=2, + seq_length=7, + tie_word_embeddings=True, + projector_pooling_ratio=1, + text_config={ + "model_type": "llama", + "seq_length": 7, + "is_training": True, + "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 1, + }, + is_training=True, + vision_config={ + "architecture": "vit_pe_core_large_patch14_336", + "model_args": { + "embed_dim": 64, + "img_size": (14, 14), + "depth": 2, + "global_pool": "", + "use_post_transformer_norm": False, + "init_values": 0.1, + "ref_feat_shape": (1, 1), + }, + }, + ): + self.parent = parent + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.text_config = text_config + self.vision_config = vision_config + self.pad_token_id = text_config["pad_token_id"] + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + self.tie_word_embeddings = tie_word_embeddings + + self.batch_size = 3 + self.num_tiles = 1 + self.num_frames = 1 + self.num_channels = 3 + self.image_size = self.vision_config["model_args"]["img_size"][0] + self.num_image_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2 + self.num_video_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2 + self.seq_length = seq_length + self.num_image_tokens + self.encoder_seq_length = self.seq_length + + def get_config(self): + return PerceptionLMConfig( + text_config=self.text_config, + vision_config=self.vision_config, + vision_use_cls_token=True, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.num_tiles, + self.num_channels, + self.vision_config["model_args"]["img_size"][0], + self.vision_config["model_args"]["img_size"][1], + ] + ) + pixel_values_videos = floats_tensor( + [ + self.batch_size, + self.num_frames, + self.num_channels, + self.vision_config["model_args"]["img_size"][0], + self.vision_config["model_args"]["img_size"][1], + ] + ) + config = self.get_config() + + return config, pixel_values, pixel_values_videos + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, 
pixel_values_videos = self.prepare_config_and_inputs() + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2 + attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) + input_ids[input_ids == config.image_token_id] = self.pad_token_id + input_ids[input_ids == config.video_token_id] = self.pad_token_id + input_ids[:, : self.num_image_tokens] = config.image_token_id + input_ids[:, self.num_image_tokens : self.num_video_tokens + self.num_image_tokens] = config.video_token_id + + inputs_dict = { + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class PerceptionLMForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `PerceptionLMForConditionalGeneration`. + """ + + all_model_classes = ( + ( + PerceptionLMModel, + PerceptionLMForConditionalGeneration, + ) + if is_torch_available() + else () + ) + test_pruning = False + test_head_masking = False + _is_composite = True + + def setUp(self): + self.model_tester = PerceptionLMVisionText2TextModelTester(self) + common_properties = [ + "image_token_id", + "video_token_id", + ] + self.config_tester = ConfigTester( + self, + config_class=PerceptionLMConfig, + has_text_modality=False, + common_properties=common_properties, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + del inputs["pixel_values_videos"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + del inputs["pixel_values_videos"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs through an error with explicit message saying what is wrong + when number of images doesn't match number of image tokens in the text. + Also we need to test multi-image cases when one prompr has multiple image tokens. 
+ """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + if model_class == PerceptionLMModel: + continue + model = model_class(config).to(torch_device) + _ = model(**input_dict) # successful forward with no modifications + + # remove one image but leave the image token in text + input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + with self.assertRaises(ValueError): + _ = model(**input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = input_dict["input_ids"][:1] + pixel_values = input_dict["pixel_values"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + def test_training(self): + self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else () + super().test_training() + + def test_training_gradient_checkpointing(self): + self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else () + super().test_training_gradient_checkpointing() + + def test_training_gradient_checkpointing_use_reentrant(self): + self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else () + super().test_training_gradient_checkpointing_use_reentrant() + + def test_training_gradient_checkpointing_use_reentrant_false(self): + self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else () + super().test_training_gradient_checkpointing_use_reentrant_false() + + @unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights") + def test_can_init_all_missing_weights(self): + pass + + @unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights") + def test_initialization(self): + pass + + @unittest.skip( + reason="PE/TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation." + ) + def test_flash_attn_2_can_dispatch_composite_models(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("ViT PE / TimmWrapperModel cannot be tested with meta device") + def test_can_be_initialized_on_meta(self): + pass + + @unittest.skip("ViT PE / TimmWrapperModel cannot be tested with meta device") + def test_can_load_with_meta_device_context_manager(self): + pass + + @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM") + def test_generate_from_inputs_embeds_0_greedy(self): + pass + + @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM") + def test_generate_from_inputs_embeds_1_beam_search(self): + pass + + @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + ## Skip flash attention releated tests below + ## correct configuration: + ## from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"} + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_from_config(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_sdpa_generate_with_dynamic_cache(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_eager_matches_sdpa_generate(self): + pass + + @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.") + def test_flash_attn_2_inference_equivalence(self): + pass + + +TEST_MODEL_PATH = "shumingh/plm_1b_hf" + + +@require_torch +class PerceptionLMForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH) + self.image_file = hf_hub_download( + repo_id="shumingh/perception_lm_test_images", + filename="14496_0.PNG", + repo_type="dataset", + ) + self.video_file = hf_hub_download( + repo_id="shumingh/perception_lm_test_videos", + filename="GUWR5TyiY-M_000012_000022.mp4", + repo_type="dataset", + ) + self.conversation1 = [ + { + "role": "user", + "content": [ + {"type": "image", "url": self.image_file}, + {"type": "text", "text": "Describe the bar plot in the image."}, + ], + } + ] + self.conversation2 = [ + { + "role": "user", + "content": [ + { + "type": "video", + "url": self.video_file, + }, + {"type": "text", "text": "Can you describe the video in detail?"}, + ], + } + ] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + @require_bitsandbytes + def test_small_model_integration_test(self): + model = PerceptionLMForConditionalGeneration.from_pretrained( + TEST_MODEL_PATH, load_in_4bit=True, cache_dir="./" + ) + + inputs = 
self.processor.apply_chat_template( + [self.conversation1], + num_frames=32, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + video_load_backend="decord", + padding=True, + padding_side="left", + ).to(torch_device) + + generate_ids = model.generate(**inputs, max_new_tokens=18) + input_length = inputs["input_ids"].shape[1] + generate_ids_without_inputs = generate_ids[:, input_length:] + + EXPECTED_DECODED_TEXT = "The bar plot displays the values of four categories: step, horror, mood, and lumber" # fmt: skip + + self.assertEqual( + self.processor.decode(generate_ids_without_inputs[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batched(self): + model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH) + inputs = processor.apply_chat_template( + [self.conversation1, self.conversation2], + num_frames=32, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + video_load_backend="decord", + padding=True, + padding_side="left", + ).to(torch_device) + + generate_ids = model.generate(**inputs, max_new_tokens=18) + input_length = inputs["input_ids"].shape[1] + generate_ids_without_inputs = generate_ids[:, input_length:] + + EXPECTED_DECODED_TEXT = ['The bar plot displays the values of four categories: step, horror, mood, and lumber', 'The video shows a group of people in green shirts and white shorts performing a jump rope routine'] # fmt: skip + + self.assertEqual( + processor.batch_decode(generate_ids_without_inputs, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_generation_no_images(self): + # model_id = "facebook/Perception-LM-1B" + model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH) + + # Prepare inputs with no images + inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py new file mode 100644 index 000000000000..0d9f6b4162f6 --- /dev/null +++ b/tests/models/perception_lm/test_processor_perception_lm.py @@ -0,0 +1,145 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import shutil +import tempfile +import unittest + +from transformers import ( + AutoProcessor, + AutoTokenizer, + PerceptionLMProcessor, +) +from transformers.testing_utils import require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import PerceptionLMImageProcessorFast, PerceptionLMVideoProcessor + +if is_torch_available(): + import torch + + +TEST_MODEL_PATH = "shumingh/plm_1b_hf" + + +@require_vision +class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = PerceptionLMProcessor + + @classmethod + def setUpClass(cls): + cls.tmpdirname = tempfile.mkdtemp() + + image_processor = PerceptionLMImageProcessorFast( + tile_size=448, max_num_tiles=4, vision_input_type="thumb+tile" + ) + video_processor = PerceptionLMVideoProcessor() + tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_PATH) + tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>", "<|video|>"]}) + processor_kwargs = cls.prepare_processor_dict() + processor = PerceptionLMProcessor( + image_processor=image_processor, video_processor=video_processor, tokenizer=tokenizer, **processor_kwargs + ) + processor.save_pretrained(cls.tmpdirname) + cls.image_token_id = processor.image_token_id + cls.video_token_id = processor.video_token_id + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname, ignore_errors=True) + + @staticmethod + def prepare_processor_dict(): + return { + "chat_template": CHAT_TEMPLATE, + "patch_size": 14, + "pooling_ratio": 2, + } # fmt: skip + + def test_chat_template_is_saved(self): + processor_loaded = self.processor_class.from_pretrained(self.tmpdirname) + processor_dict_loaded = json.loads(processor_loaded.to_json_string()) + # chat templates aren't serialized to json in processors + self.assertFalse("chat_template" in processor_dict_loaded.keys()) + + # they have to be saved as separate file and loaded back from that file + # so we check if the same template is loaded + processor_dict = self.prepare_processor_dict() + self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None)) + + def test_image_token_filling(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) + # Important to check with non square image + image = torch.randn((1, 3, 450, 500)) + # 5 tiles (thumbnail tile + 4 tiles) + # 448/patch_size/pooling_ratio = 16 => 16*16 tokens per tile + expected_image_tokens = 16 * 16 * 5 + image_token_index = processor.image_token_id + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = processor( + text=[processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens) + + +CHAT_TEMPLATE = ( + "{{- bos_token }}" + "{%- if messages[0]['role'] == 'system' -%}" + " {%- set system_message = messages[0]['content']|trim %}\n" + " {%- set messages = messages[1:] %}\n" + "{%- else %}" + " {%- set system_message = 'You are a helpful language and vision 
assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' %}" + "{%- endif %}" + "{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}" + "{{- system_message }}" + "{{- '<|eot_id|>' }}" + "{%- for message in messages %}" + "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}" + "{{ '<|image|>' }}" + "{%- endfor %}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}" + "{{ '<|video|>' }}" + "{%- endfor %}" + "{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}" + "{{- content['text'] | trim }}" + "{%- endfor %}" + "{{'<|eot_id|>' }}" + "{%- endfor %}" + "{%- if add_generation_prompt %}" + "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}" + "{%- endif %}" +) diff --git a/tests/models/perception_lm/test_video_processing_perception_lm.py b/tests/models/perception_lm/test_video_processing_perception_lm.py new file mode 100644 index 000000000000..f411bc8bc85c --- /dev/null +++ b/tests/models/perception_lm/test_video_processing_perception_lm.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs + + +if is_torch_available(): + pass + +if is_vision_available(): + if is_torchvision_available(): + from transformers import PerceptionLMVideoProcessor + + +class PerceptionLMVideoProcessingTester: + def __init__( + self, + parent, + batch_size=5, + num_frames=8, + num_channels=3, + min_resolution=30, + max_resolution=80, + do_resize=True, + size=None, + do_center_crop=True, + crop_size=None, + do_normalize=True, + image_mean=IMAGENET_STANDARD_MEAN, + image_std=IMAGENET_STANDARD_STD, + do_convert_rgb=True, + ): + size = size if size is not None else {"height": 20, "width": 20} + crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_frames = num_frames + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_video_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + def expected_output_video_shape(self, images): + return self.num_frames, self.num_channels, self.crop_size["height"], self.crop_size["width"] + + def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"): + videos = prepare_video_inputs( + batch_size=self.batch_size, + num_frames=self.num_frames, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + return_tensors=return_tensors, + ) + return videos + + +@require_torch +@require_vision +class PerceptionLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase): + fast_video_processing_class = PerceptionLMVideoProcessor if is_torchvision_available() else None + + def setUp(self): + super().setUp() + self.video_processor_tester = PerceptionLMVideoProcessingTester(self) + + @property + def video_processor_dict(self): + return self.video_processor_tester.prepare_video_processor_dict() + + def test_video_processor_properties(self): + video_processing = self.fast_video_processing_class(**self.video_processor_dict) + self.assertTrue(hasattr(video_processing, "do_resize")) + self.assertTrue(hasattr(video_processing, "size")) + self.assertTrue(hasattr(video_processing, "do_center_crop")) + self.assertTrue(hasattr(video_processing, "center_crop")) + self.assertTrue(hasattr(video_processing, "do_normalize")) + self.assertTrue(hasattr(video_processing, "image_mean")) + self.assertTrue(hasattr(video_processing, "image_std")) + self.assertTrue(hasattr(video_processing, "do_convert_rgb")) + + def test_video_processor_from_dict_with_kwargs(self): + video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict) + 
self.assertEqual(video_processor.size, {"height": 20, "width": 20}) + self.assertEqual(video_processor.crop_size, {"height": 18, "width": 18}) + + video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42, crop_size=84) + self.assertEqual(video_processor.size, {"height": 42, "width": 42}) + self.assertEqual(video_processor.crop_size, {"height": 84, "width": 84})
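
For reference, the slow integration tests above translate into the following end-to-end usage. This is a minimal sketch, not part of the patch, assuming the PerceptionLM classes introduced in this PR are installed and that the test checkpoint shumingh/plm_1b_hf and the shumingh/perception_lm_test_images dataset are reachable on the Hub; it skips the 4-bit quantization the tests use so that bitsandbytes is not required.

# Minimal inference sketch mirroring test_small_model_integration_test.
import torch
from huggingface_hub import hf_hub_download

from transformers import AutoProcessor, PerceptionLMForConditionalGeneration

model_id = "shumingh/plm_1b_hf"
processor = AutoProcessor.from_pretrained(model_id)
model = PerceptionLMForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Fetch the same test image the integration test uses.
image_file = hf_hub_download(
    repo_id="shumingh/perception_lm_test_images", filename="14496_0.PNG", repo_type="dataset"
)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": image_file},
            {"type": "text", "text": "Describe the bar plot in the image."},
        ],
    }
]

# The chat template inserts <|image|> placeholders, which the processor expands into
# per-tile vision tokens before tokenization.
inputs = processor.apply_chat_template(
    [conversation],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=32)
# Drop the prompt tokens before decoding, as the integration tests do.
new_tokens = generate_ids[:, inputs["input_ids"].shape[1] :]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])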